Skip to content

Commit

Permalink
[Refactor][Breaking] Rename start->begin in StructuralTag (#221)
Browse files Browse the repository at this point in the history
This PR renames the field "start" to "begin" in StructuralTagItem. The
name "begin" better aligns with the naming convention in python.
  • Loading branch information
Ubospica authored Feb 26, 2025
1 parent 6996ded commit 3dac5ae
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cpp/grammar_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <>
struct hash<xgrammar::StructuralTagItem> {
size_t operator()(const xgrammar::StructuralTagItem& tag) const {
return xgrammar::HashCombine(
std::hash<std::string>{}(tag.start),
std::hash<std::string>{}(tag.begin),
std::hash<std::string>{}(tag.schema),
std::hash<std::string>{}(tag.end)
);
Expand Down
10 changes: 5 additions & 5 deletions cpp/grammar_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,11 @@ class StructuralTagGrammarCreatorImpl : public SubGrammarAdder {
std::vector<int32_t> seq_elements;
seq_elements.reserve(3);

// Add start suffix (everything after trigger)
XGRAMMAR_DCHECK(tag.start.size() >= triggers[i].size())
<< "Tag start must be at least as long as trigger";
if (tag.start.size() > triggers[i].size()) {
seq_elements.push_back(builder_.AddByteString(tag.start.substr(triggers[i].size())));
// Add begin suffix (everything after trigger)
XGRAMMAR_DCHECK(tag.begin.size() >= triggers[i].size())
<< "Tag begin must be at least as long as trigger";
if (tag.begin.size() > triggers[i].size()) {
seq_elements.push_back(builder_.AddByteString(tag.begin.substr(triggers[i].size())));
}

// Create and visit schema grammar for this tag
Expand Down
12 changes: 6 additions & 6 deletions cpp/structural_tag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,24 @@ Grammar StructuralTagToGrammar(
bool found = false;
for (int it_trigger = 0; it_trigger < static_cast<int>(sorted_triggers.size()); ++it_trigger) {
const auto& trigger = sorted_triggers[it_trigger];
if (trigger.size() <= tag.start.size() &&
std::string_view(tag.start).substr(0, trigger.size()) == trigger) {
if (trigger.size() <= tag.begin.size() &&
std::string_view(tag.begin).substr(0, trigger.size()) == trigger) {
tag_groups[it_trigger].push_back(std::make_pair(tag, schema_grammars[it_tag]));
found = true;
break;
}
}
XGRAMMAR_CHECK(found) << "Tag " << tag.start << " does not match any trigger";
XGRAMMAR_CHECK(found) << "Tag " << tag.begin << " does not match any trigger";
}

// Step 3: Combine the tags to form a grammar
// root ::= TagDispatch((trigger1, rule1), (trigger2, rule2), ...)
// Suppose tag1 and tag2 matches trigger1, then
// rule1 ::= (tag1.start[trigger1.size():] + ToEBNF(tag1.schema) + tag1.end) |
// (tag2.start[trigger1.size():] + ToEBNF(tag2.schema) + tag2.end) | ...
// rule1 ::= (tag1.begin[trigger1.size():] + ToEBNF(tag1.schema) + tag1.end) |
// (tag2.begin[trigger1.size():] + ToEBNF(tag2.schema) + tag2.end) | ...
//
// Suppose tag3 matches trigger2, then
// rule2 ::= (tag3.start[trigger2.size():] + ToEBNF(tag3.schema) + tag3.end)
// rule2 ::= (tag3.begin[trigger2.size():] + ToEBNF(tag3.schema) + tag3.end)
//
// ...
return StructuralTagGrammarCreator::Apply(sorted_triggers, tag_groups);
Expand Down
4 changes: 2 additions & 2 deletions include/xgrammar/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
namespace xgrammar {

struct StructuralTagItem {
std::string start;
std::string begin;
std::string schema;
std::string end;

bool operator==(const StructuralTagItem& other) const {
return start == other.start && schema == other.schema && end == other.end;
return begin == other.begin && schema == other.schema && end == other.end;
}
};

Expand Down
2 changes: 1 addition & 1 deletion python/xgrammar/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def compile_structural_tag(
compiled_grammar : CompiledGrammar
The compiled grammar.
"""
tags_tuple = [(tag.start, _handle_pydantic_schema(tag.schema_), tag.end) for tag in tags]
tags_tuple = [(tag.begin, _handle_pydantic_schema(tag.schema_), tag.end) for tag in tags]
return CompiledGrammar._create_from_handle(
self._handle.compile_structural_tag(tags_tuple, triggers)
)
Expand Down
18 changes: 9 additions & 9 deletions python/xgrammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class StructuralTagItem(BaseModel):
Attributes
----------
start : str
The start tag.
begin : str
The begin tag.
schema_ : Union[str, Type[BaseModel]]
The schema.
Expand All @@ -23,7 +23,7 @@ class StructuralTagItem(BaseModel):
The end tag.
"""

start: str
begin: str
schema_: Union[str, Type[BaseModel]] = Field(alias="schema")
end: str

Expand Down Expand Up @@ -190,14 +190,14 @@ def from_structural_tag(tags: List[StructuralTagItem], triggers: List[str]) -> "
The tags parameter is used to specify the output pattern. It is especially useful for LLM
function calling, where the pattern is:
<function=func_name>{"arg1": ..., "arg2": ...}</function>.
This pattern consists of three parts: a start tag (<function=func_name>), a parameter list
This pattern consists of three parts: a begin tag (<function=func_name>), a parameter list
according to some schema ({"arg1": ..., "arg2": ...}), and an end tag (</function>). This
pattern can be described in a StructuralTagItem with a start tag, a schema, and an end tag.
pattern can be described in a StructuralTagItem with a begin tag, a schema, and an end tag.
The structural tag is able to handle multiple such patterns by passing them into multiple
tags.
The triggers parameter is used to trigger the dispatching of different grammars. The trigger
should be a prefix of a provided start tag. When the trigger is encountered, the
should be a prefix of a provided begin tag. When the trigger is encountered, the
corresponding tag should be used to constrain the following output. There can be multiple
tags matching the same trigger. Then if the trigger is encountered, the following output
should match one of the tags. For example, in function calling, the triggers can be
Expand Down Expand Up @@ -235,13 +235,13 @@ def from_structural_tag(tags: List[StructuralTagItem], triggers: List[str]) -> "
>>> arg3: float
>>> arg4: List[str]
>>> tags = [
>>> StructuralTagItem(start="<function=f>", schema=Schema1, end="</function>"),
>>> StructuralTagItem(start="<function=g>", schema=Schema2, end="</function>"),
>>> StructuralTagItem(begin="<function=f>", schema=Schema1, end="</function>"),
>>> StructuralTagItem(begin="<function=g>", schema=Schema2, end="</function>"),
>>> ]
>>> triggers = ["<function="]
>>> grammar = Grammar.from_structural_tag(tags, triggers)
"""
tags_tuple = [(tag.start, _handle_pydantic_schema(tag.schema_), tag.end) for tag in tags]
tags_tuple = [(tag.begin, _handle_pydantic_schema(tag.schema_), tag.end) for tag in tags]
return Grammar._create_from_handle(_core.Grammar.from_structural_tag(tags_tuple, triggers))

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ class Schema2(BaseModel):
arg4: List[str]

tags = [
xgr.StructuralTagItem(start="<function=f1>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(start="<function=f2>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(start="<function=g>", schema=Schema2, end="</function>"),
xgr.StructuralTagItem(begin="<function=f1>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(begin="<function=f2>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(begin="<function=g>", schema=Schema2, end="</function>"),
]
# in real cases, we should use one trigger: "<function=" and dispatch to two tags
# but here we use two triggers for testing such cases
Expand Down Expand Up @@ -203,9 +203,9 @@ class Schema2(BaseModel):
arg4: List[str]

tags = [
xgr.StructuralTagItem(start="<function=f1>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(start="<function=f2>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(start="<function=g>", schema=Schema2, end="</function>"),
xgr.StructuralTagItem(begin="<function=f1>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(begin="<function=f2>", schema=Schema1, end="</function>"),
xgr.StructuralTagItem(begin="<function=g>", schema=Schema2, end="</function>"),
]

# in real cases, we should use one trigger: "<function=" and dispatch to two tags
Expand All @@ -232,10 +232,10 @@ class Schema2(BaseModel):
# Set up grammar from schemas
tags = [
xgr.StructuralTagItem(
start="<function=f>", schema=json.dumps(Schema1.model_json_schema()), end="</function>"
begin="<function=f>", schema=json.dumps(Schema1.model_json_schema()), end="</function>"
),
xgr.StructuralTagItem(
start="<function=g>", schema=json.dumps(Schema2.model_json_schema()), end="</function>"
begin="<function=g>", schema=json.dumps(Schema2.model_json_schema()), end="</function>"
),
]
triggers = ["<function=f", "<function=g"]
Expand Down

0 comments on commit 3dac5ae

Please sign in to comment.