[Feature] Multi-thread Grammar Compilation (#39)
This PR supports multi-threaded grammar compilation, significantly speeding up the preprocessing stage.

APIs:
```python
compiler = CachedGrammarCompiler(tokenizer_info, max_threads)
compiled_grammar = CompiledGrammar(grammar, tokenizer_info, max_threads)
```
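
For reference, a minimal end-to-end sketch of the new parameter (the model id and schema are placeholders; the calls follow the README examples changed below):

```python
import json

from transformers import AutoTokenizer
from xgrammar import CachedGrammarCompiler, GrammarMatcher, TokenizerInfo

# Once per model: wrap the tokenizer (placeholder model id).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)

# Once per model: max_threads caps the worker pool used to compile grammars.
compiler = CachedGrammarCompiler(tokenizer_info, max_threads=8)

# Once per request: the compiled grammar is cached per schema.
person_schema = {"type": "object", "properties": {"name": {"type": "string"}}}
compiled_grammar = compiler.compile_json_schema(json.dumps(person_schema))
matcher = GrammarMatcher(compiled_grammar)
```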
Ubospica authored Nov 15, 2024
1 parent f529e27 commit 2e43bb3
Showing 12 changed files with 376 additions and 144 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -78,6 +78,7 @@ json_schema_str = BuiltinGrammar.json_schema(json.dumps(person_schema))
```

#### Step 2: Compiling grammars
The compilation is multi-threaded and cached for every grammar.

```python
from xgrammar import TokenizerInfo, CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
@@ -91,7 +92,7 @@ tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
Method 1: Use CachedGrammarCompiler to avoid compiling a grammar multiple times
```python
# 2. Construct CachedGrammarCompiler (once per model)
compiler = CachedGrammarCompiler(tokenizer_info)
compiler = CachedGrammarCompiler(tokenizer_info, max_threads=8)

# 3. Fetch CompiledGrammar and construct GrammarMatcher (once per request)
compiled_grammar = compiler.compile_json_schema(json_schema_str)
@@ -101,7 +102,7 @@ matcher = GrammarMatcher(compiled_grammar)
Method 2: Compile grammar directly
```python
# 2. Construct CompiledGrammar directly (once per grammar)
compiled_grammar = CompiledGrammar(grammar, tokenizer_info)
compiled_grammar = CompiledGrammar(grammar, tokenizer_info, max_threads=8)

# 3. Construct GrammarMatcher (once per request)
matcher = GrammarMatcher(compiled_grammar)
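
# Note (not part of this PR): how large to make max_threads is left to the
# caller. A reasonable default is to bound it by the host's core count,
# continuing from `tokenizer_info` above:
import os

max_threads = min(8, os.cpu_count() or 1)
compiler = CachedGrammarCompiler(tokenizer_info, max_threads=max_threads)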
172 changes: 74 additions & 98 deletions cpp/grammar_matcher_preproc.h → cpp/grammar_cached_compiler.h
@@ -1,20 +1,22 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/grammar_matcher_preproc.h
* \brief The header for the preprocessing of the grammar matcher.
* \file xgrammar/grammar_cached_compiler.h
* \brief The header for the cached compiler of the grammar matcher.
*/
#ifndef XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_
#define XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_
#ifndef XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_
#define XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_

#include <xgrammar/xgrammar.h>

#include <sstream>
#include <unordered_set>
#include <vector>

#include "grammar_data_structure.h"
#include "grammar_matcher_base.h"
#include "support/dynamic_bitset.h"
#include "support/encoding.h"
#include "support/thread_pool.h"
#include "support/thread_safe_cache.h"
#include "support/utils.h"

@@ -74,30 +76,12 @@ struct CatagorizedTokens {
*/
class CompiledGrammar::Impl {
public:
Impl(const BNFGrammar& grammar, const std::vector<std::string>& decoded_vocab);
Impl(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info)
: Impl(grammar, tokenizer_info.GetDecodedVocab()) {}

/******************* Information about the tokenizer *******************/

/*! \brief The vocabulary size of the tokenizer. Special tokens are included. */
size_t vocab_size;
/*! \brief The vocabulary. Special tokens are included. */
std::vector<std::string> decoded_vocab;
/*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to
* maximize prefix reuse during matching. Special tokens and stop tokens are not included. */
std::vector<std::pair<int32_t, std::string>> sorted_decoded_vocab;
/*! \brief The stop tokens. When the GrammarMatcher can reach the end of the grammar,
* stop tokens can be accepted. */
std::vector<int32_t> detected_stop_token_ids;
/*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided
* generation. */
std::unordered_set<int32_t> special_token_ids;

/******************* Information about the grammar *******************/
Impl(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads);

/*! \brief The grammar for the GrammarMatcher. */
BNFGrammar grammar;
/*! \brief The tokenizer information. */
TokenizerInfo tokenizer_info;

/******************* Grammar-specific tokenizer information *******************/

@@ -127,12 +111,13 @@ class CompiledGrammar::Impl {

class CachedGrammarCompiler::Impl {
public:
Impl(const std::vector<std::string>& decoded_vocab)
: decoded_vocab_(decoded_vocab),
compiled_grammar_for_json_cache_([this]() {
return CompiledGrammar(BuiltinGrammar::JSON(), this->decoded_vocab_);
Impl(const TokenizerInfo& tokenizer_info, int max_threads)
: tokenizer_info_(tokenizer_info),
max_threads_(max_threads),
compiled_grammar_for_json_cache_([&]() {
return CompiledGrammar(BuiltinGrammar::JSON(), this->tokenizer_info_, this->max_threads_);
}),
compiled_grammar_for_schema_cache_([this](const SchemaKey& key) {
compiled_grammar_for_schema_cache_([&](const SchemaKey& key) {
return this->ComputeCompiledGrammarForJSONSchema(key);
}) {}

@@ -156,12 +141,16 @@ class CachedGrammarCompiler::Impl {
CompiledGrammar ComputeCompiledGrammarForJSONSchema(const SchemaKey& key) {
auto [schema, indent, separators, strict_mode] = key;
return CompiledGrammar(
BuiltinGrammar::JSONSchema(schema, indent, separators, strict_mode), decoded_vocab_
BuiltinGrammar::JSONSchema(schema, indent, separators, strict_mode),
tokenizer_info_,
max_threads_
);
}

/*! \brief The tokenizer information associated with this storage class. */
std::vector<std::string> decoded_vocab_;
TokenizerInfo tokenizer_info_;
/*! \brief The maximum number of threads to use. */
int max_threads_;
/*! \brief The cache for the compiled grammar for JSON. */
ThreadSafeCache<CompiledGrammar> compiled_grammar_for_json_cache_;
/*! \brief The cache for the compiled grammar of a JSON schema. */
@@ -372,50 +361,26 @@ inline CatagorizedTokens GrammarMatcherForCompiler::GetCatagorizedTokens(
/******************* CompiledGrammar *******************/

CompiledGrammar::Impl::Impl(
const BNFGrammar& grammar, const std::vector<std::string>& decoded_vocab
const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads
) {
using RuleExprType = BNFGrammar::Impl::RuleExprType;

this->grammar = grammar;
this->vocab_size = decoded_vocab.size();
this->decoded_vocab = decoded_vocab;
this->tokenizer_info = tokenizer_info;

if (this->vocab_size == 0) {
if (tokenizer_info.GetVocabSize() == 0) {
return;
}

for (int i = 0; i < static_cast<int>(decoded_vocab.size()); ++i) {
const auto& token = decoded_vocab[i];
// TODO(yixin): Now we detect stop tokens from the token string. We should be able to pass
// the stop token set in.
// LLaMA2: </s>
// LLaMA3: <|end_of_text|>, <|eot_id|>
// Phi-2: <|endoftext|>
// Gemma: <eos>, <end_of_turn>
if (token == "</s>" || token == "<|end_of_text|>" || token == "<|eot_id|>" ||
token == "<|endoftext|>" || token == "<eos>" || token == "<|eos|>" ||
token == "<end_of_turn>" || token == "<|end▁of▁sentence|>") {
this->detected_stop_token_ids.push_back(i);
} else if ((token[0] == '<' && token.back() == '>' && token.size() >= 3) ||
token == "[@BOS@]") {
// gemma treats [@BOS@] as a special token
this->special_token_ids.insert(i);
} else {
this->sorted_decoded_vocab.push_back({i, token});
}
}

auto f_compare_token = [](const std::pair<int32_t, std::string>& a,
const std::pair<int32_t, std::string>& b) {
return a.second < b.second;
};
std::sort(this->sorted_decoded_vocab.begin(), this->sorted_decoded_vocab.end(), f_compare_token);

// Find the corresponding catagorized tokens for:
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
// 2. All byte strings (with element_in_string=0, 1, 2, ...)
auto root_rule_id = grammar->GetMainRuleId();
for (int rule_id = 0; rule_id < static_cast<int>(grammar->NumRules()); ++rule_id) {

ThreadPool thread_pool(max_threads);
std::mutex catagorized_tokens_mutex;

auto root_rule_id = grammar->GetRootRuleId();
for (int32_t rule_id = 0; rule_id < static_cast<int>(grammar->NumRules()); ++rule_id) {
auto rule = grammar->GetRule(rule_id);
auto rule_body = grammar->GetRuleExpr(rule.body_expr_id);
XGRAMMAR_DCHECK(rule_body.type == RuleExprType::kChoices);
Expand All @@ -430,43 +395,52 @@ CompiledGrammar::Impl::Impl(
if (element.type == RuleExprType::kRuleRef) {
continue;
}

auto add_catagorized_tokens = [&](const RulePosition& rule_position) {
auto grammar_matcher = GrammarMatcherForCompiler(grammar, rule_position);
auto cur_catagorized_tokens_for_grammar = grammar_matcher.GetCatagorizedTokens(
this->vocab_size, this->sorted_decoded_vocab, rule_id != root_rule_id
);
this->catagorized_tokens_for_grammar[rule_position] = cur_catagorized_tokens_for_grammar;
};

auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id);
if (element.type == RuleExprType::kByteString) {
for (int idx = 0; idx < element.size(); ++idx) {
cur_rule_position.element_in_string = idx;
add_catagorized_tokens(cur_rule_position);
thread_pool.Execute([&, rule_id, sequence_id, element_id, element]() {
auto add_catagorized_tokens = [&](const RulePosition& rule_position) {
auto grammar_matcher = GrammarMatcherForCompiler(grammar, rule_position);
auto cur_catagorized_tokens_for_grammar = grammar_matcher.GetCatagorizedTokens(
tokenizer_info.GetVocabSize(),
tokenizer_info.GetSortedDecodedVocab(),
rule_id != root_rule_id
);
{
std::lock_guard<std::mutex> lock(catagorized_tokens_mutex);
this->catagorized_tokens_for_grammar[rule_position] =
cur_catagorized_tokens_for_grammar;
}
};

auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id);
if (element.type == RuleExprType::kByteString) {
for (int idx = 0; idx < element.size(); ++idx) {
cur_rule_position.element_in_string = idx;
add_catagorized_tokens(cur_rule_position);
}
} else {
XGRAMMAR_DCHECK(
element.type == RuleExprType::kCharacterClassStar ||
element.type == RuleExprType::kCharacterClass
);
for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) {
cur_rule_position.left_utf8_bytes = left_utf8_bytes;
add_catagorized_tokens(cur_rule_position);
}
}
} else {
XGRAMMAR_DCHECK(
element.type == RuleExprType::kCharacterClassStar ||
element.type == RuleExprType::kCharacterClass
);
for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) {
cur_rule_position.left_utf8_bytes = left_utf8_bytes;
add_catagorized_tokens(cur_rule_position);
}
}
});
}
}
}
}

CompiledGrammar::CompiledGrammar(
const BNFGrammar& grammar, const std::vector<std::string>& decoded_vocab
const BNFGrammar& grammar, const std::vector<std::string>& decoded_vocab, int max_threads
)
: pimpl_(std::make_shared<Impl>(grammar, decoded_vocab)) {}
: pimpl_(std::make_shared<Impl>(grammar, TokenizerInfo(decoded_vocab), max_threads)) {}

CompiledGrammar::CompiledGrammar(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info)
: pimpl_(std::make_shared<Impl>(grammar, tokenizer_info)) {}
CompiledGrammar::CompiledGrammar(
const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads
)
: pimpl_(std::make_shared<Impl>(grammar, tokenizer_info, max_threads)) {}

/******************* CachedGrammarCompiler *******************/

Expand All @@ -492,11 +466,13 @@ inline void CachedGrammarCompiler::Impl::Clear() {
compiled_grammar_for_schema_cache_.Clear();
}

CachedGrammarCompiler::CachedGrammarCompiler(const std::vector<std::string>& decoded_vocab)
: pimpl_(std::make_shared<Impl>(decoded_vocab)) {}
CachedGrammarCompiler::CachedGrammarCompiler(
const std::vector<std::string>& decoded_vocab, int max_threads
)
: pimpl_(std::make_shared<Impl>(TokenizerInfo(decoded_vocab), max_threads)) {}

CachedGrammarCompiler::CachedGrammarCompiler(const TokenizerInfo& tokenizer_info)
: pimpl_(std::make_shared<Impl>(tokenizer_info.GetDecodedVocab())) {}
CachedGrammarCompiler::CachedGrammarCompiler(const TokenizerInfo& tokenizer_info, int max_threads)
: pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads)) {}

CompiledGrammar CachedGrammarCompiler::CompileJSONGrammar() { return pimpl_->CompileJSONGrammar(); }

@@ -513,4 +489,4 @@ void CachedGrammarCompiler::Clear() { pimpl_->Clear(); }

} // namespace xgrammar

#endif // XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_
#endif // XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_
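
The constructor above parallelizes by enqueuing one task per grammar position on the ThreadPool and serializing writes to the shared map behind a mutex. A Python sketch of the same synchronization pattern (all names hypothetical, not the xgrammar API; the real parallelism lives in the C++ ThreadPool):

```python
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

def categorize_in_parallel(positions, categorize, max_threads=8):
    """Shape of CompiledGrammar::Impl's parallel loop, in Python.

    `positions` and `categorize` stand in for the RulePosition enumeration
    and GetCatagorizedTokens; `results` plays the role of
    catagorized_tokens_for_grammar, and the Lock mirrors the std::mutex.
    """
    results = {}
    lock = Lock()

    def worker(pos):
        tokens = categorize(pos)  # independent, CPU-bound work per position
        with lock:                # serialize writes to the shared map
            results[pos] = tokens

    # Leaving the `with` block joins all workers, like ThreadPool's destructor.
    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        for pos in positions:
            pool.submit(worker, pos)
    return results
```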
4 changes: 2 additions & 2 deletions cpp/grammar_data_structure.h
@@ -83,9 +83,9 @@ class BNFGrammar::Impl {
return rules_[rule_id];
}
/*! \brief Get the root rule id of the grammar. */
int32_t GetMainRuleId() const { return root_rule_id_; }
int32_t GetRootRuleId() const { return root_rule_id_; }
/*! \brief Get the root rule of the grammar. */
const Rule& GetMainRule() const {
const Rule& GetRootRule() const {
XGRAMMAR_DCHECK(root_rule_id_ >= 0 && root_rule_id_ < static_cast<int32_t>(rules_.size()))
<< "root_rule_id " << root_rule_id_ << " is out of bound";
return rules_[root_rule_id_];
2 changes: 1 addition & 1 deletion cpp/grammar_functor.cc
@@ -108,7 +108,7 @@ class NestedRuleUnwrapper : public BNFGrammarMutator {
builder_.UpdateRuleBody(i, new_body_expr_id);
builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
}
return builder_.Get(grammar_->GetMainRule().name);
return builder_.Get(grammar_->GetRootRule().name);
}

private:
2 changes: 1 addition & 1 deletion cpp/grammar_functor.h
@@ -63,7 +63,7 @@ class BNFGrammarFunctor {
// Handle lookahead assertion
builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
}
return builder_.Get(grammar_->GetMainRule().name);
return builder_.Get(grammar_->GetRootRule().name);
} else {
return ReturnType();
}