Skip to content

Commit

Permalink
[FunctionCalling] Support TagDispatch (#146)
Browse files Browse the repository at this point in the history
This PR supports TagDispatch. Its grammar is:
```
rule ::= TagDispatch(("tag1", rule1), ("tag2", rule2), ...)
```
Its semantics are as follows: any input is allowed at the start. When the
input matches tag1, the following input must match rule1 until rule1 finishes.
The same applies to tag2 and rule2. When the rule finishes, matching returns to
the start, allows any input, and checks for a tag again.

Backend support for TagDispatch is not implemented yet.
  • Loading branch information
Ubospica authored Jan 10, 2025
1 parent b8c033a commit 154fda1
Show file tree
Hide file tree
Showing 14 changed files with 284 additions and 60 deletions.
2 changes: 1 addition & 1 deletion cpp/compiled_grammar_data_structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <vector>

// matcher_data_structure.h is included to use StackElement
#include "grammar_matcher_data_structure.h"
#include "persistent_stack.h"
#include "support/dynamic_bitset.h"
#include "support/utils.h"

Expand Down
15 changes: 15 additions & 0 deletions cpp/grammar_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ class GrammarBuilder {
);
}

/*!
 * \brief Create a tag dispatch RuleExpr from a list of (tag, rule) pairs.
 * \param tag_dispatch_list Pairs of (tag_expr_id, rule_id), kept in the given order.
 * \return The value returned by AddRuleExpr for the newly added RuleExpr.
 */
int32_t AddTagDispatch(const std::vector<std::pair<int32_t, int32_t>>& tag_dispatch_list) {
  // Flatten the pairs into [tag_expr_id0, rule_id0, tag_expr_id1, rule_id1, ...].
  std::vector<int32_t> flattened;
  flattened.reserve(tag_dispatch_list.size() * 2);
  for (const auto& entry : tag_dispatch_list) {
    flattened.push_back(entry.first);
    flattened.push_back(entry.second);
  }
  const auto flattened_len = static_cast<int32_t>(flattened.size());
  return AddRuleExpr({RuleExprType::kTagDispatch, flattened.data(), flattened_len});
}

size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); }
/*! \brief Get the rule_expr with the given id. */
RuleExpr GetRuleExpr(int32_t rule_expr_id) { return grammar_->GetRuleExpr(rule_expr_id); }
Expand Down
3 changes: 3 additions & 0 deletions cpp/grammar_data_structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class Grammar::Impl {
kSequence,
// data format: [rule_expr_id0, rule_expr_id1, ...]
kChoices,
// data format: [tag_expr0, rule_id0, tag_expr1, rule_id1, ...]
// tag_expr should be a byte string, and rule_id should be a rule id
kTagDispatch,
};

/*! \brief The object representing a rule expr. */
Expand Down
34 changes: 20 additions & 14 deletions cpp/grammar_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SingleElementExprEliminator : public GrammarMutator {
if (lookahead_assertion_id == -1) {
return -1;
}
auto rule_expr = grammar_->GetRuleExpr(lookahead_assertion_id);
auto rule_expr = old_grammar_->GetRuleExpr(lookahead_assertion_id);
XGRAMMAR_CHECK(rule_expr.type == RuleExprType::kSequence);

std::vector<int32_t> sequence_ids;
Expand Down Expand Up @@ -97,26 +97,26 @@ class NestedRuleUnwrapper : public GrammarMutator {

Grammar Apply(const Grammar& grammar) final {
Init(grammar);
for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
builder_.AddEmptyRule(grammar_->GetRule(i).name);
for (int i = 0; i < static_cast<int>(old_grammar_->NumRules()); ++i) {
builder_.AddEmptyRule(old_grammar_->GetRule(i).name);
}
for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
auto rule = grammar_->GetRule(i);
auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id);
for (int i = 0; i < static_cast<int>(old_grammar_->NumRules()); ++i) {
auto rule = old_grammar_->GetRule(i);
auto rule_expr = old_grammar_->GetRuleExpr(rule.body_expr_id);
cur_rule_name_ = rule.name;
auto new_body_expr_id = VisitRuleBody(rule_expr);
builder_.UpdateRuleBody(i, new_body_expr_id);
builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
}
return builder_.Get(grammar_->GetRootRule().name);
return builder_.Get(old_grammar_->GetRootRule().name);
}

private:
int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final {
if (lookahead_assertion_id == -1) {
return -1;
}
auto assertion_expr = grammar_->GetRuleExpr(lookahead_assertion_id);
auto assertion_expr = old_grammar_->GetRuleExpr(lookahead_assertion_id);
return builder_.AddSequence(VisitSequence_(assertion_expr));
}

Expand All @@ -134,6 +134,8 @@ class NestedRuleUnwrapper : public GrammarMutator {
case RuleExprType::kCharacterClassStar:
case RuleExprType::kRuleRef:
return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})});
case RuleExprType::kTagDispatch:
return VisitTagDispatch(rule_expr);
default:
XGRAMMAR_LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type);
}
Expand All @@ -147,7 +149,7 @@ class NestedRuleUnwrapper : public GrammarMutator {
std::vector<int32_t> new_choice_ids;
bool found_empty = false;
for (auto i : rule_expr) {
auto choice_expr = grammar_->GetRuleExpr(i);
auto choice_expr = old_grammar_->GetRuleExpr(i);
switch (choice_expr.type) {
case RuleExprType::kSequence:
VisitSequenceInChoices(choice_expr, &new_choice_ids, &found_empty);
Expand All @@ -164,6 +166,8 @@ class NestedRuleUnwrapper : public GrammarMutator {
case RuleExprType::kRuleRef:
VisitElementInChoices(choice_expr, &new_choice_ids);
break;
case RuleExprType::kTagDispatch:
XGRAMMAR_LOG(FATAL) << "TagDispatch should not be in choices";
default:
XGRAMMAR_LOG(FATAL) << "Unexpected choice type: " << static_cast<int>(choice_expr.type);
}
Expand Down Expand Up @@ -216,7 +220,7 @@ class NestedRuleUnwrapper : public GrammarMutator {
std::vector<int32_t> VisitSequence_(const RuleExpr& rule_expr) {
std::vector<int32_t> new_sequence_ids;
for (auto i : rule_expr) {
auto element_expr = grammar_->GetRuleExpr(i);
auto element_expr = old_grammar_->GetRuleExpr(i);
switch (element_expr.type) {
case RuleExprType::kSequence:
VisitSequenceInSequence(element_expr, &new_sequence_ids);
Expand All @@ -232,6 +236,8 @@ class NestedRuleUnwrapper : public GrammarMutator {
case RuleExprType::kRuleRef:
VisitElementInSequence(element_expr, &new_sequence_ids);
break;
case RuleExprType::kTagDispatch:
XGRAMMAR_LOG(FATAL) << "TagDispatch should not be in sequence";
default:
XGRAMMAR_LOG(FATAL) << "Unexpected sequence type: "
<< static_cast<int>(element_expr.type);
Expand Down Expand Up @@ -285,7 +291,7 @@ class ByteStringFuser : public GrammarMutator {
std::vector<int32_t> new_sequence_ids;
std::vector<int32_t> cur_byte_string;
for (auto i : rule_expr) {
auto element_expr = grammar_->GetRuleExpr(i);
auto element_expr = old_grammar_->GetRuleExpr(i);
if (element_expr.type == RuleExprType::kByteString) {
cur_byte_string.insert(cur_byte_string.end(), element_expr.begin(), element_expr.end());
continue;
Expand Down Expand Up @@ -315,11 +321,11 @@ std::vector<std::unique_ptr<GrammarMutator>> GrammarNormalizer::GetNormalizerLis

Grammar GrammarNormalizer::Apply(const Grammar& grammar) {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators = GetNormalizerList();
grammar_ = grammar;
old_grammar_ = grammar;
for (auto& mutator : normalizer_mutators) {
grammar_ = mutator->Apply(grammar_);
old_grammar_ = mutator->Apply(old_grammar_);
}
return grammar_;
return old_grammar_;
}

} // namespace xgrammar
38 changes: 28 additions & 10 deletions cpp/grammar_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class GrammarFunctor {
virtual ReturnType Apply(const Grammar& grammar) {
Init(grammar);
if constexpr (std::is_same<T, void>::value) {
for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
auto rule = grammar_->GetRule(i);
for (int i = 0; i < static_cast<int>(old_grammar_->NumRules()); ++i) {
auto rule = old_grammar_->GetRule(i);
cur_rule_name_ = rule.name;
VisitExpr(rule.body_expr_id);
VisitLookaheadAssertion(rule.lookahead_assertion_id);
Expand All @@ -52,18 +52,18 @@ class GrammarFunctor {
std::is_same<ReturnType, Grammar>::value) {
// First add empty rules to ensure the new rule ids are the same as the old ones, then update
// the rule bodies
for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
builder_.AddEmptyRule(grammar_->GetRule(i).name);
for (int i = 0; i < static_cast<int>(old_grammar_->NumRules()); ++i) {
builder_.AddEmptyRule(old_grammar_->GetRule(i).name);
}
for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) {
auto rule = grammar_->GetRule(i);
for (int i = 0; i < static_cast<int>(old_grammar_->NumRules()); ++i) {
auto rule = old_grammar_->GetRule(i);
cur_rule_name_ = rule.name;
auto new_body_expr_id = VisitExpr(rule.body_expr_id);
builder_.UpdateRuleBody(i, new_body_expr_id);
// Handle lookahead assertion
builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id));
}
return builder_.Get(grammar_->GetRootRule().name);
return builder_.Get(old_grammar_->GetRootRule().name);
} else {
return ReturnType();
}
Expand All @@ -79,7 +79,7 @@ class GrammarFunctor {

/*! \brief Initialize the functor. Should be called at the beginning of Apply(). */
virtual void Init(const Grammar& grammar) {
grammar_ = grammar;
old_grammar_ = grammar;
builder_ = GrammarBuilder();
}

Expand All @@ -93,7 +93,7 @@ class GrammarFunctor {

/*! \brief Visit a RuleExpr by id. */
virtual T VisitExpr(int32_t old_rule_expr_id) {
return VisitExpr(grammar_->GetRuleExpr(old_rule_expr_id));
return VisitExpr(old_grammar_->GetRuleExpr(old_rule_expr_id));
}

/*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */
Expand All @@ -113,6 +113,8 @@ class GrammarFunctor {
return VisitCharacterClassStar(rule_expr);
case RuleExprType::kRuleRef:
return VisitRuleRef(rule_expr);
case RuleExprType::kTagDispatch:
return VisitTagDispatch(rule_expr);
default:
XGRAMMAR_LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type);
}
Expand Down Expand Up @@ -152,6 +154,22 @@ class GrammarFunctor {
}
}

/*!
 * \brief Visit a tag dispatch RuleExpr, whose data is read as consecutive
 * (tag_expr, rule_id) entries.
 *
 * - T == int32_t: every tag expr is visited, its result is paired with the original rule id
 *   (copied unchanged), and a new tag dispatch expr is added through the builder.
 * - T == void: every tag expr is visited; the rule ids are not touched.
 * - Any other T: nothing is visited and a default-constructed T is returned.
 */
virtual T VisitTagDispatch(const RuleExpr& rule_expr) {
  if constexpr (std::is_same<T, int32_t>::value) {
    const auto data_len = rule_expr.size();
    std::vector<std::pair<int32_t, int32_t>> visited_pairs;
    for (int tag_pos = 0; tag_pos < data_len; tag_pos += 2) {
      // Visit the tag first, then read the rule id stored right after it.
      const auto new_tag_expr_id = VisitExpr(rule_expr[tag_pos]);
      const auto rule_id = rule_expr[tag_pos + 1];
      visited_pairs.push_back({new_tag_expr_id, rule_id});
    }
    return builder_.AddTagDispatch(visited_pairs);
  } else if constexpr (std::is_same<T, void>::value) {
    const auto data_len = rule_expr.size();
    // Tags sit at the even positions; odd positions hold rule ids and are skipped.
    for (int tag_pos = 0; tag_pos < data_len; tag_pos += 2) {
      VisitExpr(rule_expr[tag_pos]);
    }
  } else {
    return T();
  }
}

/*! \brief Visit an element RuleExpr, including empty string, character class, and rule ref. */
virtual T VisitElement(const RuleExpr& rule_expr) {
if constexpr (std::is_same<T, void>::value) {
Expand Down Expand Up @@ -179,7 +197,7 @@ class GrammarFunctor {
virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); }

/*! \brief The grammar to visit or mutate. */
Grammar grammar_;
Grammar old_grammar_;
/*!
* \brief The builder to build the new grammar. It is empty when the mutator is constructed, and
* can be used to build a new grammar in subclasses.
Expand Down
4 changes: 2 additions & 2 deletions cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/matcher.cc
* \file xgrammar/grammar_matcher.cc
* \brief This source file implements the matcher class, especially the logic related to LLM tokens,
* like accepting tokens, leveraging the token mask cache to generate the mask, etc. matcher_base.cc
* implements the basic matching algorithm from strings to grammar.
Expand All @@ -14,8 +14,8 @@
#include "compiled_grammar_data_structure.h"
#include "grammar_data_structure.h"
#include "grammar_matcher_base.h"
#include "grammar_matcher_data_structure.h"
#include "grammar_serializer.h"
#include "persistent_stack.h"
#include "support/dynamic_bitset.h"
#include "support/encoding.h"
#include "support/int_set.h"
Expand Down
4 changes: 2 additions & 2 deletions cpp/grammar_matcher_base.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/matcher_base.cc
* \file xgrammar/grammar_matcher_base.cc
* \brief This source file implements the basic matching algorithm from strings to grammar.
* matcher.cc will handle the logic related to LLM tokens, like accepting tokens, leveraging the
* token mask cache to generate the mask, etc.
Expand All @@ -12,7 +12,7 @@
#include <vector>

#include "grammar_data_structure.h"
#include "grammar_matcher_data_structure.h"
#include "persistent_stack.h"
#include "support/encoding.h"

namespace xgrammar {
Expand Down
8 changes: 4 additions & 4 deletions cpp/grammar_matcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file xgrammar/grammar_matcher_base.h
* \brief The base class of GrammarMatcher. It implements a character-based matching automata.
*/
#ifndef XGRAMMAR_MATCHER_BASE_H_
#define XGRAMMAR_MATCHER_BASE_H_
#ifndef XGRAMMAR_GRAMMAR_MATCHER_BASE_H_
#define XGRAMMAR_GRAMMAR_MATCHER_BASE_H_

#include <xgrammar/grammar.h>

Expand All @@ -13,7 +13,7 @@
#include <vector>

#include "grammar_data_structure.h"
#include "grammar_matcher_data_structure.h"
#include "persistent_stack.h"

namespace xgrammar {

Expand Down Expand Up @@ -125,4 +125,4 @@ class GrammarMatcherBase {

} // namespace xgrammar

#endif // XGRAMMAR_MATCHER_BASE_H_
#endif // XGRAMMAR_GRAMMAR_MATCHER_BASE_H_
Loading

0 comments on commit 154fda1

Please sign in to comment.