From 2e43bb34bafa2a572103e4ce8a53dc851dfc52c0 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Fri, 15 Nov 2024 08:40:05 -0500 Subject: [PATCH] [Feature] Multi-thread Grammar Compilation (#39) This PR supports multi-thread grammar compilation, which significantly speeds up the preprocessing stage. APIs: ``` compiler = CachedGrammarCompiler(tokenizer_info, max_threads) compiled_grammar = CompiledGrammar(grammar, tokenizer_info, max_threads) ``` --- README.md | 5 +- ...er_preproc.h => grammar_cached_compiler.h} | 172 ++++++++--------- cpp/grammar_data_structure.h | 4 +- cpp/grammar_functor.cc | 2 +- cpp/grammar_functor.h | 2 +- cpp/grammar_matcher.cc | 32 ++-- cpp/grammar_matcher_base.h | 2 +- cpp/pybind/pybind.cc | 5 +- cpp/support/thread_pool.h | 177 ++++++++++++++++++ cpp/tokenizer.cc | 72 ++++++- include/xgrammar/xgrammar.h | 15 +- python/xgrammar/xgrammar.py | 32 ++-- 12 files changed, 376 insertions(+), 144 deletions(-) rename cpp/{grammar_matcher_preproc.h => grammar_cached_compiler.h} (75%) create mode 100644 cpp/support/thread_pool.h diff --git a/README.md b/README.md index 5c4e6af..98d9701 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ json_schema_str = BuiltinGrammar.json_schema(json.dumps(person_schema)) ``` #### Step 2: Compiling grammars +The compilation is multi-threaded and cached for every grammar. ```python from xgrammar import TokenizerInfo, CachedGrammarCompiler, CompiledGrammar, GrammarMatcher @@ -91,7 +92,7 @@ tokenizer_info = TokenizerInfo.from_huggingface(tokenizer) Method 1: Use CachedGrammarCompiler to avoid compile grammar multiple times ```python # 2. Construct CachedGrammarCompiler (once per model) -compiler = CachedGrammarCompiler(tokenizer_info) +compiler = CachedGrammarCompiler(tokenizer_info, max_threads=8) # 3. 
Fetch CompiledGrammar and construct GrammarMatcher (once per request) compiled_grammar = compiler.compile_json_schema(json_schema_str) @@ -101,7 +102,7 @@ matcher = GrammarMatcher(compiled_grammar) Method 2: Compile grammar directly ```python # 2. Construct CompiledGrammar directly (once per grammar) -compiled_grammar = CompiledGrammar(grammar, tokenizer_info) +compiled_grammar = CompiledGrammar(grammar, tokenizer_info, max_threads=8) # 3. Construct GrammarMatcher (once per request) matcher = GrammarMatcher(compiled_grammar) diff --git a/cpp/grammar_matcher_preproc.h b/cpp/grammar_cached_compiler.h similarity index 75% rename from cpp/grammar_matcher_preproc.h rename to cpp/grammar_cached_compiler.h index 764d72b..e0ca0dc 100644 --- a/cpp/grammar_matcher_preproc.h +++ b/cpp/grammar_cached_compiler.h @@ -1,13 +1,14 @@ /*! * Copyright (c) 2024 by Contributors - * \file xgrammar/grammar_matcher_preproc.h - * \brief The header for the preprocessing of the grammar matcher. + * \file xgrammar/grammar_cached_compiler.h + * \brief The header for the cached compiler of the grammar matcher. */ -#ifndef XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_ -#define XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_ +#ifndef XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_ +#define XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_ #include +#include #include #include @@ -15,6 +16,7 @@ #include "grammar_matcher_base.h" #include "support/dynamic_bitset.h" #include "support/encoding.h" +#include "support/thread_pool.h" #include "support/thread_safe_cache.h" #include "support/utils.h" @@ -74,30 +76,12 @@ struct CatagorizedTokens { */ class CompiledGrammar::Impl { public: - Impl(const BNFGrammar& grammar, const std::vector& decoded_vocab); - Impl(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info) - : Impl(grammar, tokenizer_info.GetDecodedVocab()) {} - - /******************* Information about the tokenizer *******************/ - - /*! \brief The vocabulary size of the tokenizer. Special tokens are included. 
*/ - size_t vocab_size; - /*! \brief The vocabulary. Special tokens are included. */ - std::vector decoded_vocab; - /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to - * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */ - std::vector> sorted_decoded_vocab; - /*! \brief The stop tokens. When the GrammarMatcher can reach the end of the grammar, - * stop tokens can be accepted. */ - std::vector detected_stop_token_ids; - /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided - * generation. */ - std::unordered_set special_token_ids; - - /******************* Information about the grammar *******************/ + Impl(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads); /*! \brief The grammar for the GrammarMatcher. */ BNFGrammar grammar; + /*! \brief The tokenizer information. */ + TokenizerInfo tokenizer_info; /******************* Grammar-specific tokenizer information *******************/ @@ -127,12 +111,13 @@ class CompiledGrammar::Impl { class CachedGrammarCompiler::Impl { public: - Impl(const std::vector& decoded_vocab) - : decoded_vocab_(decoded_vocab), - compiled_grammar_for_json_cache_([this]() { - return CompiledGrammar(BuiltinGrammar::JSON(), this->decoded_vocab_); + Impl(const TokenizerInfo& tokenizer_info, int max_threads) + : tokenizer_info_(tokenizer_info), + max_threads_(max_threads), + compiled_grammar_for_json_cache_([&]() { + return CompiledGrammar(BuiltinGrammar::JSON(), this->tokenizer_info_, this->max_threads_); }), - compiled_grammar_for_schema_cache_([this](const SchemaKey& key) { + compiled_grammar_for_schema_cache_([&](const SchemaKey& key) { return this->ComputeCompiledGrammarForJSONSchema(key); }) {} @@ -156,12 +141,16 @@ class CachedGrammarCompiler::Impl { CompiledGrammar ComputeCompiledGrammarForJSONSchema(const SchemaKey& key) { auto [schema, indent, separators, strict_mode] = key; return 
CompiledGrammar( - BuiltinGrammar::JSONSchema(schema, indent, separators, strict_mode), decoded_vocab_ + BuiltinGrammar::JSONSchema(schema, indent, separators, strict_mode), + tokenizer_info_, + max_threads_ ); } /*! \brief The vocabulary associated with this storage class. */ - std::vector decoded_vocab_; + TokenizerInfo tokenizer_info_; + /*! \brief The maximum number of threads to use. */ + int max_threads_; /*! \brief The cache for the compiled grammar for JSON. */ ThreadSafeCache compiled_grammar_for_json_cache_; /*! \brief The cache for the compiled grammar of a JSON schema. */ @@ -372,50 +361,26 @@ inline CatagorizedTokens GrammarMatcherForCompiler::GetCatagorizedTokens( /******************* CompiledGrammar *******************/ CompiledGrammar::Impl::Impl( - const BNFGrammar& grammar, const std::vector& decoded_vocab + const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads ) { using RuleExprType = BNFGrammar::Impl::RuleExprType; this->grammar = grammar; - this->vocab_size = decoded_vocab.size(); - this->decoded_vocab = decoded_vocab; + this->tokenizer_info = tokenizer_info; - if (this->vocab_size == 0) { + if (tokenizer_info.GetVocabSize() == 0) { return; } - for (int i = 0; i < static_cast(decoded_vocab.size()); ++i) { - const auto& token = decoded_vocab[i]; - // TODO(yixin): Now we detect stop tokens from the token string. We should be able to pass - // the stop token set in. 
- // LLaMA2: - // LLaMA3: <|end_of_text|>, <|eot_id|> - // Phi-2: <|endoftext|> - // Gemma: , - if (token == "" || token == "<|end_of_text|>" || token == "<|eot_id|>" || - token == "<|endoftext|>" || token == "" || token == "<|eos|>" || - token == "" || token == "<|end▁of▁sentence|>") { - this->detected_stop_token_ids.push_back(i); - } else if ((token[0] == '<' && token.back() == '>' && token.size() >= 3) || - token == "[@BOS@]") { - // gemma treats [@BOS@] as a special token - this->special_token_ids.insert(i); - } else { - this->sorted_decoded_vocab.push_back({i, token}); - } - } - - auto f_compare_token = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(this->sorted_decoded_vocab.begin(), this->sorted_decoded_vocab.end(), f_compare_token); - // Find the corresponding catagorized tokens for: // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) // 2. All byte strings (with element_in_string=0, 1, 2, ...) - auto root_rule_id = grammar->GetMainRuleId(); - for (int rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { + + ThreadPool thread_pool(max_threads); + std::mutex catagorized_tokens_mutex; + + auto root_rule_id = grammar->GetRootRuleId(); + for (int32_t rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { auto rule = grammar->GetRule(rule_id); auto rule_body = grammar->GetRuleExpr(rule.body_expr_id); XGRAMMAR_DCHECK(rule_body.type == RuleExprType::kChoices); @@ -430,43 +395,52 @@ CompiledGrammar::Impl::Impl( if (element.type == RuleExprType::kRuleRef) { continue; } - - auto add_catagorized_tokens = [&](const RulePosition& rule_position) { - auto grammar_matcher = GrammarMatcherForCompiler(grammar, rule_position); - auto cur_catagorized_tokens_for_grammar = grammar_matcher.GetCatagorizedTokens( - this->vocab_size, this->sorted_decoded_vocab, rule_id != root_rule_id - ); - this->catagorized_tokens_for_grammar[rule_position] = 
cur_catagorized_tokens_for_grammar; - }; - - auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); - if (element.type == RuleExprType::kByteString) { - for (int idx = 0; idx < element.size(); ++idx) { - cur_rule_position.element_in_string = idx; - add_catagorized_tokens(cur_rule_position); + thread_pool.Execute([&, rule_id, sequence_id, element_id, element]() { + auto add_catagorized_tokens = [&](const RulePosition& rule_position) { + auto grammar_matcher = GrammarMatcherForCompiler(grammar, rule_position); + auto cur_catagorized_tokens_for_grammar = grammar_matcher.GetCatagorizedTokens( + tokenizer_info.GetVocabSize(), + tokenizer_info.GetSortedDecodedVocab(), + rule_id != root_rule_id + ); + { + std::lock_guard lock(catagorized_tokens_mutex); + this->catagorized_tokens_for_grammar[rule_position] = + cur_catagorized_tokens_for_grammar; + } + }; + + auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); + if (element.type == RuleExprType::kByteString) { + for (int idx = 0; idx < element.size(); ++idx) { + cur_rule_position.element_in_string = idx; + add_catagorized_tokens(cur_rule_position); + } + } else { + XGRAMMAR_DCHECK( + element.type == RuleExprType::kCharacterClassStar || + element.type == RuleExprType::kCharacterClass + ); + for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { + cur_rule_position.left_utf8_bytes = left_utf8_bytes; + add_catagorized_tokens(cur_rule_position); + } } - } else { - XGRAMMAR_DCHECK( - element.type == RuleExprType::kCharacterClassStar || - element.type == RuleExprType::kCharacterClass - ); - for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { - cur_rule_position.left_utf8_bytes = left_utf8_bytes; - add_catagorized_tokens(cur_rule_position); - } - } + }); } } } } CompiledGrammar::CompiledGrammar( - const BNFGrammar& grammar, const std::vector& decoded_vocab + const BNFGrammar& grammar, const std::vector& decoded_vocab, int max_threads ) - : 
pimpl_(std::make_shared(grammar, decoded_vocab)) {} + : pimpl_(std::make_shared(grammar, TokenizerInfo(decoded_vocab), max_threads)) {} -CompiledGrammar::CompiledGrammar(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info) - : pimpl_(std::make_shared(grammar, tokenizer_info)) {} +CompiledGrammar::CompiledGrammar( + const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads +) + : pimpl_(std::make_shared(grammar, tokenizer_info, max_threads)) {} /******************* CachedGrammarCompiler *******************/ @@ -492,11 +466,13 @@ inline void CachedGrammarCompiler::Impl::Clear() { compiled_grammar_for_schema_cache_.Clear(); } -CachedGrammarCompiler::CachedGrammarCompiler(const std::vector& decoded_vocab) - : pimpl_(std::make_shared(decoded_vocab)) {} +CachedGrammarCompiler::CachedGrammarCompiler( + const std::vector& decoded_vocab, int max_threads +) + : pimpl_(std::make_shared(TokenizerInfo(decoded_vocab), max_threads)) {} -CachedGrammarCompiler::CachedGrammarCompiler(const TokenizerInfo& tokenizer_info) - : pimpl_(std::make_shared(tokenizer_info.GetDecodedVocab())) {} +CachedGrammarCompiler::CachedGrammarCompiler(const TokenizerInfo& tokenizer_info, int max_threads) + : pimpl_(std::make_shared(tokenizer_info, max_threads)) {} CompiledGrammar CachedGrammarCompiler::CompileJSONGrammar() { return pimpl_->CompileJSONGrammar(); } @@ -513,4 +489,4 @@ void CachedGrammarCompiler::Clear() { pimpl_->Clear(); } } // namespace xgrammar -#endif // XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_ +#endif // XGRAMMAR_GRAMMAR_CACHED_COMPILER_H_ diff --git a/cpp/grammar_data_structure.h b/cpp/grammar_data_structure.h index 76d750d..243a225 100644 --- a/cpp/grammar_data_structure.h +++ b/cpp/grammar_data_structure.h @@ -83,9 +83,9 @@ class BNFGrammar::Impl { return rules_[rule_id]; } /*! \brief Get the root rule id of the grammar. */ - int32_t GetMainRuleId() const { return root_rule_id_; } + int32_t GetRootRuleId() const { return root_rule_id_; } /*! 
\brief Get the root rule of the grammar. */ - const Rule& GetMainRule() const { + const Rule& GetRootRule() const { XGRAMMAR_DCHECK(root_rule_id_ >= 0 && root_rule_id_ < static_cast(rules_.size())) << "root_rule_id " << root_rule_id_ << " is out of bound"; return rules_[root_rule_id_]; diff --git a/cpp/grammar_functor.cc b/cpp/grammar_functor.cc index 96bc038..dbbc9dc 100644 --- a/cpp/grammar_functor.cc +++ b/cpp/grammar_functor.cc @@ -108,7 +108,7 @@ class NestedRuleUnwrapper : public BNFGrammarMutator { builder_.UpdateRuleBody(i, new_body_expr_id); builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } - return builder_.Get(grammar_->GetMainRule().name); + return builder_.Get(grammar_->GetRootRule().name); } private: diff --git a/cpp/grammar_functor.h b/cpp/grammar_functor.h index aa90f67..8168d7a 100644 --- a/cpp/grammar_functor.h +++ b/cpp/grammar_functor.h @@ -63,7 +63,7 @@ class BNFGrammarFunctor { // Handle lookahead assertion builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } - return builder_.Get(grammar_->GetMainRule().name); + return builder_.Get(grammar_->GetRootRule().name); } else { return ReturnType(); } diff --git a/cpp/grammar_matcher.cc b/cpp/grammar_matcher.cc index cfcf71e..5f02bf8 100644 --- a/cpp/grammar_matcher.cc +++ b/cpp/grammar_matcher.cc @@ -7,9 +7,9 @@ #include #include +#include "grammar_cached_compiler.h" #include "grammar_data_structure.h" #include "grammar_matcher_base.h" -#include "grammar_matcher_preproc.h" #include "grammar_matcher_state.h" #include "grammar_serializer.h" #include "support/dynamic_bitset.h" @@ -129,9 +129,10 @@ class GrammarMatcher::Impl : public GrammarMatcherBase { ) : GrammarMatcherBase(compiled_grammar->grammar), compiled_grammar_(compiled_grammar), - stop_token_ids_(override_stop_tokens.value_or(compiled_grammar->detected_stop_token_ids)), + tokenizer_info_(compiled_grammar->tokenizer_info), + 
stop_token_ids_(override_stop_tokens.value_or(tokenizer_info_.GetStopTokenIds())), terminate_without_stop_token_(terminate_without_stop_token), - vocab_size_(vocab_size.value_or(compiled_grammar_->vocab_size)), + vocab_size_(vocab_size.value_or(tokenizer_info_.GetVocabSize())), max_rollback_tokens_(max_rollback_tokens), tmp_accepted_bitset_(vocab_size_) { XGRAMMAR_CHECK(!override_stop_tokens.has_value() || !override_stop_tokens->empty()) @@ -200,6 +201,7 @@ class GrammarMatcher::Impl : public GrammarMatcherBase { bool AcceptStopToken(); CompiledGrammar compiled_grammar_; + TokenizerInfo tokenizer_info_; std::vector stop_token_ids_; bool terminate_without_stop_token_; int vocab_size_; @@ -241,12 +243,12 @@ bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool verbose) { return false; } - XGRAMMAR_CHECK(token_id >= 0 && token_id < static_cast(compiled_grammar_->vocab_size)) + XGRAMMAR_CHECK(token_id >= 0 && token_id < vocab_size_) << "Invalid token id " << token_id << " for GrammarMatcher"; if (verbose) { XGRAMMAR_LOG(INFO) << "Accepting token id " << token_id << ", string: \"" - << PrintAsEscapedUTF8(compiled_grammar_->decoded_vocab[token_id]) + << PrintAsEscapedUTF8(tokenizer_info_.GetDecodedVocab()[token_id]) << "\", state state:\n" << PrintStackState(); } @@ -261,14 +263,16 @@ bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool verbose) { return accepted; } - if (compiled_grammar_->special_token_ids.count(token_id) > 0) { - XGRAMMAR_LOG(FATAL - ) << "Token id " - << token_id << ": " << compiled_grammar_->decoded_vocab[token_id] - << " is regarded as a special token, and cannot be accepted by the GrammarMatcher"; + const auto& special_token_ids = tokenizer_info_.GetSpecialTokenIds(); + if (std::find(special_token_ids.begin(), special_token_ids.end(), token_id) != + special_token_ids.end()) { + XGRAMMAR_LOG(FATAL) << "Token id " << token_id << ": " + << tokenizer_info_.GetDecodedVocab()[token_id] + << " is regarded as a special token, and cannot 
be accepted by the " + "GrammarMatcher"; } - const auto& token = compiled_grammar_->decoded_vocab[token_id]; + const auto& token = tokenizer_info_.GetDecodedVocab()[token_id]; int pos = 0; for (auto char_value : token) { if (!AcceptChar(char_value, verbose)) { @@ -342,7 +346,7 @@ void GrammarMatcher::Impl::FillNextTokenBitmask(DLTensor* next_token_bitmask) { ) << "GrammarMatcher has terminated after accepting the stop token, but is trying to " "find the next token mask"; CheckTokenBitmaskValidity(*next_token_bitmask, vocab_size_); - const auto& sorted_decoded_vocab = compiled_grammar_->sorted_decoded_vocab; + const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab(); const auto& catagorized_tokens_for_grammar = compiled_grammar_->catagorized_tokens_for_grammar; const auto& latest_stack_tops = stack_tops_history_.GetLatest(); @@ -568,7 +572,7 @@ void GrammarMatcher::Impl::SetTokenBitmask( DynamicBitset next_token_bitset( vocab_size_, reinterpret_cast(next_token_bitmask->data) ); - const auto& sorted_decoded_vocab = compiled_grammar_->sorted_decoded_vocab; + const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab(); if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { // If rejected_indices is the universal set, the final accepted token set is just @@ -592,7 +596,7 @@ void GrammarMatcher::Impl::SetTokenBitmask( } } - for (int id : compiled_grammar_->special_token_ids) { + for (int id : tokenizer_info_.GetSpecialTokenIds()) { next_token_bitset.Set(id, false); } if (!can_reach_end) { diff --git a/cpp/grammar_matcher_base.h b/cpp/grammar_matcher_base.h index a30bbcd..762673c 100644 --- a/cpp/grammar_matcher_base.h +++ b/cpp/grammar_matcher_base.h @@ -280,7 +280,7 @@ inline void GrammarMatcherBase::PushInitialState( ) { if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the root rule. 
- auto root_rule = grammar_->GetMainRule(); + auto root_rule = grammar_->GetRootRule(); auto root_rule_body = grammar_->GetRuleExpr(root_rule.body_expr_id); tmp_new_stack_tops_.clear(); for (auto i : root_rule_body) { diff --git a/cpp/pybind/pybind.cc b/cpp/pybind/pybind.cc index b15d2bc..f45d391 100644 --- a/cpp/pybind/pybind.cc +++ b/cpp/pybind/pybind.cc @@ -39,11 +39,10 @@ PYBIND11_MODULE(xgrammar_bindings, m) { .def_static("from_vocab_and_metadata", &TokenizerInfo::FromVocabAndMetadata); auto pyCompiledGrammar = py::class_(m, "CompiledGrammar"); - pyCompiledGrammar.def(py::init&>()) - .def(py::init()); + pyCompiledGrammar.def(py::init()); auto pyCachedGrammarCompiler = py::class_(m, "CachedGrammarCompiler"); - pyCachedGrammarCompiler.def(py::init()) + pyCachedGrammarCompiler.def(py::init()) .def( "compile_json_grammar", &CachedGrammarCompiler::CompileJSONGrammar, diff --git a/cpp/support/thread_pool.h b/cpp/support/thread_pool.h new file mode 100644 index 0000000..b4e7ddc --- /dev/null +++ b/cpp/support/thread_pool.h @@ -0,0 +1,177 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file support/thread_pool.h + * \brief Thread pool. + */ +#ifndef XGRAMMAR_SUPPORT_THREAD_POOL_H_ +#define XGRAMMAR_SUPPORT_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "logging.h" + +namespace xgrammar { + +/*! + * \brief A thread pool implementation for parallel task execution. + * + * ThreadPool manages a pool of worker threads that can execute tasks asynchronously. + * Tasks are submitted to a queue and executed by available threads from the pool. + * The pool automatically handles thread synchronization and task distribution. + */ +class ThreadPool { + public: + /*! + * \brief Construct a new thread pool with the specified number of threads. + * \param num_threads Number of worker threads to create. Defaults to hardware concurrency. + * \note The pool starts the worker threads immediately upon construction. 
+ */ + ThreadPool(size_t num_threads = std::thread::hardware_concurrency()) { + // Initialize thread pool with num_threads threads + for (size_t i = 0; i < num_threads; ++i) { + workers_.emplace_back([this] { + while (true) { + std::function task; + { + // Lock queue while waiting for new task + std::unique_lock lock(queue_mutex_); + queue_condition_.wait(lock, [this] { return shutdown_ || !task_queue_.empty(); }); + + // Exit thread if shutdown and queue is empty + if (shutdown_ && task_queue_.empty()) return; + + // Get task from queue + task = std::move(task_queue_.front()); + task_queue_.pop(); + } + // Execute task outside the lock to allow other threads to get new tasks + task(); + } + }); + } + } + + /*! + * \brief Add a new task to be executed by the thread pool. + * \tparam F Type of the function to execute + * \tparam Args Types of the arguments to pass to the function + * \param f Function to execute + * \param args Arguments to pass to the function + * \return std::shared_future containing the result of the function call + * \note Tasks are executed in FIFO order but may complete in any order. + */ + template + auto Submit(F&& f, Args&&... args) + -> std::shared_future::type> { + using return_type = typename std::result_of::type; + + // Package the task with its arguments into a shared pointer to allow safe capture in lambda + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::shared_future res = task->get_future().share(); + + { + std::unique_lock lock(queue_mutex_); + XGRAMMAR_CHECK(!shutdown_) << "Cannot submit task to stopped ThreadPool"; + + // Wrap task in lambda to allow type erasure via std::function + task_queue_.emplace([task]() { (*task)(); }); + } + queue_condition_.notify_one(); + return res; + } + + /*! + * \brief Add a new task to be executed by the thread pool without returning a future. 
+ * \tparam F Type of the function to execute + * \tparam Args Types of the arguments to pass to the function + * \param f Function to execute + * \param args Arguments to pass to the function + * \note Tasks are executed asynchronously by the worker threads. + */ + template + void Execute(F&& f, Args&&... args) { + { + std::unique_lock lock(queue_mutex_); + XGRAMMAR_CHECK(!shutdown_) << "Cannot execute task in stopped ThreadPool"; + + // Wrap the function and its arguments into a std::function + task_queue_.emplace(std::bind(std::forward(f), std::forward(args)...)); + } + queue_condition_.notify_one(); + } + + /*! + * \brief Destructor that ensures graceful shutdown of the thread pool. + * + * Sets shutdown flag and waits for all threads to complete their current tasks + * before destroying the pool. Any remaining tasks in the queue will be executed + * before shutdown completes. + */ + ~ThreadPool() { + { + std::unique_lock lock(queue_mutex_); + shutdown_ = true; + } + queue_condition_.notify_all(); // Wake up all threads so they can exit + for (std::thread& worker : workers_) { + if (worker.joinable()) worker.join(); // Wait for thread to finish + } + } + + // Prevent copying or moving of the thread pool + ThreadPool(const ThreadPool&) = delete; + ThreadPool(ThreadPool&&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; + ThreadPool& operator=(ThreadPool&&) = delete; + + private: + /*! \brief Thread container */ + std::vector workers_; + /*! \brief Task queue */ + std::queue> task_queue_; + /*! \brief Mutex to protect task queue */ + std::mutex queue_mutex_; + /*! \brief Condition variable for thread synchronization */ + std::condition_variable queue_condition_; + /*! 
\brief Flag to indicate thread pool shutdown */ + bool shutdown_ = false; +}; + +void ParallelFor(int low, int high, int num_threads, std::function f) { + if (high - low == 1) { + f(low); + return; + } + + ThreadPool pool(num_threads); + + int total = high - low; + int chunk_size = (total + num_threads - 1) / num_threads; + + for (int t = 0; t < num_threads; ++t) { + int start = low + t * chunk_size; + int end = std::min(start + chunk_size, high); + + if (start >= end) break; // No more iterations to process + + pool.Execute([f, start, end]() { + for (int i = start; i < end; ++i) { + f(i); + } + }); + } + // ThreadPool destructor will wait for all tasks to complete +} + +} // namespace xgrammar + +#endif // XGRAMMAR_SUPPORT_THREAD_POOL_H_ diff --git a/cpp/tokenizer.cc b/cpp/tokenizer.cc index 5280ed6..20b2e32 100644 --- a/cpp/tokenizer.cc +++ b/cpp/tokenizer.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include "support/encoding.h" #include "support/logging.h" @@ -24,16 +25,36 @@ class TokenizerInfo::Impl { bool prepend_space_in_tokenization ); - int GetVocabSize() const { return decoded_vocab_.size(); } + int GetVocabSize() const { return vocab_size_; } VocabType GetVocabType() const { return vocab_type_; } bool GetPrependSpaceInTokenization() const { return prepend_space_in_tokenization_; } const std::vector& GetDecodedVocab() { return decoded_vocab_; } + const std::vector& GetStopTokenIds() const { return stop_token_ids_; } + const std::vector& GetSpecialTokenIds() const { return special_token_ids_; } + const std::vector>& GetSortedDecodedVocab() const { + return sorted_decoded_vocab_; + } + std::string DumpMetadata() const; private: + /*! \brief The vocabulary type. */ VocabType vocab_type_; + /*! \brief Whether to prepend space in tokenization. */ bool prepend_space_in_tokenization_; + /*! \brief The vocabulary. Special tokens are included. */ std::vector decoded_vocab_; + /*! \brief The size of the vocabulary. */ + int vocab_size_; + /*! 
\brief All (id, token) pairs sorted in lexicographic order. This sorting is done to + * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */ + std::vector> sorted_decoded_vocab_; + /*! \brief The stop tokens. When the GrammarMatcher can reach the end of the grammar, + * stop tokens can be accepted. */ + std::vector stop_token_ids_; + /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided + * generation. */ + std::vector special_token_ids_; }; /************* Metadata detection from huggingface tokenizer.json *************/ @@ -166,9 +187,9 @@ inline std::string SpaceReplacerDecoder(const std::string& token) { // \u2581 is the unicode for "lower one eighth block" // UTF8 encoding for \u2581 is 0xE2 0x96 0x81 std::string result; - for (size_t i = 0; i < token.size(); ++i) { - if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) && - token[i + 2] == char(0x81)) { + for (int i = 0; i < static_cast(token.size()); ++i) { + if (i + 2 < static_cast(token.size()) && token[i] == char(0xE2) && + token[i + 1] == char(0x96) && token[i + 2] == char(0x81)) { result += ' '; i += 2; } else { @@ -246,11 +267,40 @@ TokenizerInfo::Impl::Impl( VocabType vocab_type, bool prepend_space_in_tokenization ) - : vocab_type_(vocab_type), prepend_space_in_tokenization_(prepend_space_in_tokenization) { - decoded_vocab_.reserve(encoded_vocab.size()); + : vocab_type_(vocab_type), + prepend_space_in_tokenization_(prepend_space_in_tokenization), + vocab_size_(encoded_vocab.size()) { + decoded_vocab_.reserve(vocab_size_); for (const auto& item : encoded_vocab) { decoded_vocab_.emplace_back(DecodeToken(item, vocab_type_)); } + + for (int i = 0; i < static_cast(decoded_vocab_.size()); ++i) { + const auto& token = decoded_vocab_[i]; + // TODO(yixin): Now we detect stop tokens from the token string. We should be able to pass + // the stop token set in. 
+ // LLaMA2: </s> + // LLaMA3: <|end_of_text|>, <|eot_id|> + // Phi-2: <|endoftext|> + // Gemma: <eos>, <end_of_turn> + if (token == "</s>" || token == "<|end_of_text|>" || token == "<|eot_id|>" || + token == "<|endoftext|>" || token == "<eos>" || token == "<|eos|>" || + token == "<end_of_turn>" || token == "<|end▁of▁sentence|>") { + stop_token_ids_.push_back(i); + } else if ((token[0] == '<' && token.back() == '>' && token.size() >= 3) || + token == "[@BOS@]") { + // gemma treats [@BOS@] as a special token + special_token_ids_.push_back(i); + } else { + sorted_decoded_vocab_.push_back({i, token}); + } + } + + auto f_compare_token = [](const std::pair<int32_t, std::string>& a, + const std::pair<int32_t, std::string>& b) { + return a.second < b.second; + }; + std::sort(sorted_decoded_vocab_.begin(), sorted_decoded_vocab_.end(), f_compare_token); } std::string TokenizerInfo::Impl::DumpMetadata() const { @@ -276,6 +326,16 @@ bool TokenizerInfo::GetPrependSpaceInTokenization() const { const std::vector<std::string>& TokenizerInfo::GetDecodedVocab() const { return pimpl_->GetDecodedVocab(); } +const std::vector<int32_t>& TokenizerInfo::GetStopTokenIds() const { + return pimpl_->GetStopTokenIds(); +} +const std::vector<int32_t>& TokenizerInfo::GetSpecialTokenIds() const { + return pimpl_->GetSpecialTokenIds(); +} +const std::vector<std::pair<int32_t, std::string>>& TokenizerInfo::GetSortedDecodedVocab() const { + return pimpl_->GetSortedDecodedVocab(); +} + std::string TokenizerInfo::DumpMetadata() const { return pimpl_->DumpMetadata(); } TokenizerInfo TokenizerInfo::FromVocabAndMetadata( diff --git a/include/xgrammar/xgrammar.h b/include/xgrammar/xgrammar.h index 15be530..28ebc85 100644 --- a/include/xgrammar/xgrammar.h +++ b/include/xgrammar/xgrammar.h @@ -171,6 +171,9 @@ class TokenizerInfo { VocabType GetVocabType() const; bool GetPrependSpaceInTokenization() const; const std::vector<std::string>& GetDecodedVocab() const; + const std::vector<int32_t>& GetStopTokenIds() const; + const std::vector<int32_t>& GetSpecialTokenIds() const; + const std::vector<std::pair<int32_t, std::string>>& GetSortedDecodedVocab() const; static TokenizerInfo FromHuggingFace( const std::vector<std::string>& encoded_vocab, const 
std::string& backend_str @@ -197,9 +200,13 @@ class CompiledGrammar { * \param grammar The grammar that the matcher follows. * \param decoded_vocab The tokens that the matcher requires for matching. */ - CompiledGrammar(const BNFGrammar& grammar, const std::vector& decoded_vocab); + CompiledGrammar( + const BNFGrammar& grammar, const std::vector& decoded_vocab, int max_threads = 8 + ); - CompiledGrammar(const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info); + CompiledGrammar( + const BNFGrammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads = 8 + ); XGRAMMAR_DEFINE_PIMPL_METHODS(CompiledGrammar); }; @@ -323,9 +330,9 @@ class CachedGrammarCompiler { * create grammar state compiled grammars with this vocabulary. * \param decoded_vocab The vocabulary that the grammar will use. */ - CachedGrammarCompiler(const std::vector& decoded_vocab); + CachedGrammarCompiler(const std::vector& decoded_vocab, int max_threads = 8); - CachedGrammarCompiler(const TokenizerInfo& tokenizer_info); + CachedGrammarCompiler(const TokenizerInfo& tokenizer_info, int max_threads = 8); /*! \brief Get the compiled grammar for pure JSON. */ CompiledGrammar CompileJSONGrammar(); diff --git a/python/xgrammar/xgrammar.py b/python/xgrammar/xgrammar.py index e679b89..e03022f 100644 --- a/python/xgrammar/xgrammar.py +++ b/python/xgrammar/xgrammar.py @@ -475,16 +475,19 @@ class CompiledGrammar(XGObject): grammar : BNFGrammar The BNF grammar to match. - tokenizer_or_vocab : Union[None, PreTrainedTokenizerBase, TokenizerInfo, List[Union[bytes, str]]], default: None - The tokenizer or the vocabulary. It can be None, a huggingface tokenizer, a tokenizer info, - or a list of raw tokens. + tokenizer_info : Optional[TokenizerInfo], default: None + The tokenizer info. If None, the grammar matcher can only handle string operations. - None means there is no vocabulary, then the grammar matcher can only handle string - operations. 
If a huggingface tokenizer or a list of raw tokens are provided, a TokenizerInfo - object will be constructed from the tokenizer or the vocabulary. + max_threads : int, default: 8 + The maximum number of threads used to compile the grammar. """ - def __init__(self, grammar: BNFGrammar, tokenizer_info: Optional[TokenizerInfo] = None) -> None: + def __init__( + self, + grammar: BNFGrammar, + tokenizer_info: Optional[TokenizerInfo] = None, + max_threads: int = 8, + ) -> None: if tokenizer_info is None: tokenizer_info = TokenizerInfo([]) elif not isinstance(tokenizer_info, TokenizerInfo): @@ -493,7 +496,9 @@ def __init__(self, grammar: BNFGrammar, tokenizer_info: Optional[TokenizerInfo] "to CompiledGrammar." ) - self.init_with_handle(_core.CompiledGrammar(grammar.handle, tokenizer_info.handle)) + self.init_with_handle( + _core.CompiledGrammar(grammar.handle, tokenizer_info.handle, max_threads) + ) class CachedGrammarCompiler(XGObject): @@ -503,18 +508,21 @@ class CachedGrammarCompiler(XGObject): Parameters ---------- - tokenizer_or_vocab : Union[PreTrainedTokenizerBase, TokenizerInfo, List[Union[bytes, str]]] - The tokenizer or the vocabulary. Its meaning is the same as in GrammarMatcher. + tokenizer_info : TokenizerInfo + The tokenizer info. + + max_threads : int, default: 8 + The maximum number of threads used to compile the grammar. """ - def __init__(self, tokenizer_info: TokenizerInfo): + def __init__(self, tokenizer_info: TokenizerInfo, max_threads: int = 8): if not isinstance(tokenizer_info, TokenizerInfo): raise ValueError( "Please convert the tokenizer to TokenizerInfo before passing it " "to CachedGrammarCompiler." ) - self.init_with_handle(_core.CachedGrammarCompiler(tokenizer_info.handle)) + self.init_with_handle(_core.CachedGrammarCompiler(tokenizer_info.handle, max_threads)) def compile_json_grammar(self) -> CompiledGrammar: """Get CompiledGrammar from the standard JSON.