From adb76d20f2958a88e388b18b291940c28ce4c687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Fri, 6 Dec 2024 07:18:06 +0100 Subject: [PATCH] Simple pattern support + integer ranges (#106) This PR introduces simple regex pattern support + integer ranges. --- cpp/json_schema_converter.cc | 163 ++++++++++++++++++++- tests/python/test_json_schema_converter.py | 53 ++++++- 2 files changed, 209 insertions(+), 7 deletions(-) diff --git a/cpp/json_schema_converter.cc b/cpp/json_schema_converter.cc index 4a3612a5..b6db7ec9 100644 --- a/cpp/json_schema_converter.cc +++ b/cpp/json_schema_converter.cc @@ -11,11 +11,13 @@ #include #include #include +#include #include #include #include #include +#include "regex_converter.h" #include "support/logging.h" namespace xgrammar { @@ -652,6 +654,122 @@ std::string JSONSchemaConverter::VisitAny( kBasicArray + " | " + kBasicObject; } +std::string generateRangeRegex(std::optional start, std::optional end) { + if (!start && !end) { + return "^\\d+$"; // Match any positive number if no start or end is specified + } + + std::vector positiveParts; + std::vector negativeParts; + + auto generateGroup = [](int s, int e) -> std::string { + std::ostringstream oss; + + if (s == e) { + return std::to_string(s); + } + + std::string startStr = std::to_string(s); + std::string endStr = std::to_string(e); + + size_t commonPrefix = 0; + while (commonPrefix < startStr.size() && startStr[commonPrefix] == endStr[commonPrefix]) { + ++commonPrefix; + } + + if (commonPrefix > 0) { + oss << startStr.substr(0, commonPrefix); + } + + if (commonPrefix < startStr.size()) { + oss << "["; + oss << startStr[commonPrefix]; + if (startStr[commonPrefix] != endStr[commonPrefix]) { + oss << "-" << endStr[commonPrefix]; + } + oss << "]"; + + // Add trailing zero ranges + if (commonPrefix + 1 < startStr.size()) { + oss << "\\d{" << startStr.size() - commonPrefix - 1 << "}"; + } + } + + return oss.str(); + }; + + if (start && end) { + int rangeStart = start.value(); + int rangeEnd = end.value(); + + // Handle negative part of the range + if (rangeStart < 0) { + int negativeEnd = std::min(rangeEnd, -1); + while (rangeStart <= negativeEnd) { + int nextRangeEnd = (rangeStart / 10 - 1) * 10 + 9; // Handle negative tens group + if (nextRangeEnd < negativeEnd) { + nextRangeEnd = negativeEnd; + } + negativeParts.push_back("-" + generateGroup(-nextRangeEnd, -rangeStart)); + rangeStart = nextRangeEnd + 1; + } + } + + // Handle positive part of the range + if (rangeEnd >= 0) { + rangeStart = std::max(rangeStart, 0); + while (rangeStart <= rangeEnd) { + int nextRangeEnd = (rangeStart / 10 + 1) * 10 - 1; // Handle positive tens group + if (nextRangeEnd > rangeEnd) { + nextRangeEnd = rangeEnd; + } + positiveParts.push_back(generateGroup(rangeStart, nextRangeEnd)); + rangeStart = nextRangeEnd + 1; + } + } + } else if (start) { + if (start.value() < 0) { + negativeParts.push_back("-" + std::to_string(-start.value()) + "\\d*"); + } else { + positiveParts.push_back(std::to_string(start.value()) + "\\d*"); + } + } else if (end) { + if (end.value() < 0) { + negativeParts.push_back("-" + std::to_string(-end.value())); + } else { + positiveParts.push_back(std::to_string(end.value())); + } + } + + std::ostringstream result; + result << "^("; + if (!negativeParts.empty()) { + result << "("; + for (size_t i = 0; i < negativeParts.size(); ++i) { + if (i > 0) { + result << "|"; + } + result << negativeParts[i]; + } + result << ")"; + if (!positiveParts.empty()) { + result << "|"; + } + } + if (!positiveParts.empty()) { + result << "("; + for (size_t i = 0; i < positiveParts.size(); ++i) { + if (i > 0) { + result << "|"; + } + result << positiveParts[i]; + } + result << ")"; + } + result << ")$"; + return result.str(); +} + std::string JSONSchemaConverter::VisitInteger( const picojson::object& schema, const std::string& rule_name ) { @@ -661,12 +779,38 @@ std::string JSONSchemaConverter::VisitInteger( schema, { "multipleOf", - "minimum", - "maximum", - "exclusiveMinimum", - "exclusiveMaximum", } ); + std::string range_regex = ""; + try { + if (schema.count("minimum") || schema.count("maximum") || schema.count("exclusiveMinimum") || + schema.count("exclusiveMaximum")) { + std::optional start, end; + if (schema.count("minimum")) { + double start_double = schema.at("minimum").get(); + start = static_cast(start_double); + } + if (schema.count("exclusiveMinimum")) { + double start_double = schema.at("exclusiveMinimum").get(); + start = static_cast(start_double); + } + if (schema.count("maximum")) { + double end_double = schema.at("maximum").get(); + end = static_cast(end_double); + } + if (schema.count("exclusiveMaximum")) { + double end_double = schema.at("exclusiveMaximum").get(); + end = static_cast(end_double); + } + range_regex = generateRangeRegex(start, end); + } + if (!range_regex.empty()) { + std::string converted_regex = RegexToEBNF(range_regex, false); + return converted_regex; // not " " for numbers + } + } catch (const std::exception& e) { + XGRAMMAR_LOG(WARNING) << "Failed to convert range for integer schema"; + } return "(\"0\" | \"-\"? [1-9] [0-9]*)"; } @@ -698,10 +842,19 @@ std::string JSONSchemaConverter::VisitString( { "minLength", "maxLength", - "pattern", "format", } ); + if (schema.count("pattern")) { + try { + std::string regex_pattern = schema.at("pattern").get(); + std::string converted_regex = RegexToEBNF(regex_pattern, false); + return "\"\\\"\" " + converted_regex + " \"\\\"\""; + } catch (const std::exception& e) { + XGRAMMAR_LOG(WARNING) << "Failed to convert regex pattern " + << schema.at("pattern").get(); + } + } return "[\"] " + kBasicStringSub; } diff --git a/tests/python/test_json_schema_converter.py b/tests/python/test_json_schema_converter.py index 31fb1cd7..4234706e 100644 --- a/tests/python/test_json_schema_converter.py +++ b/tests/python/test_json_schema_converter.py @@ -1,10 +1,10 @@ import json import sys from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union import pytest -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WithJsonSchema, create_model import xgrammar as xgr from xgrammar.testing import _json_schema_to_ebnf, _match_grammar_with_string @@ -481,5 +481,54 @@ class MainModelSpace(BaseModel): check_schema_with_json(MainModelSpace.model_json_schema(by_alias=True), instance_space_str) +def test_restricted_string() -> None: + class MainModel(BaseModel): + restricted_string: str = Field(..., pattern=r"[a-f]") + + instance = MainModel(restricted_string="a") + instance_str = json.dumps(instance.model_dump(mode="json")) + check_schema_with_json(MainModel.model_json_schema(), instance_str) + + check_schema_with_json( + MainModel.model_json_schema(), '{"restricted_string": "j"}', check_accepted=False + ) + + +def test_complex_restrictions() -> None: + + string_without_quotes = Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[^\"]*"})] + + class RestrictedModel(BaseModel): + restricted_string: string_without_quotes + restricted_value: Annotated[int, Field(strict=True, ge=0, lt=44)] + + # working instance + instance = RestrictedModel(restricted_string="a", restricted_value=42) + instance_str = json.dumps(instance.model_dump(mode="json")) + check_schema_with_json(RestrictedModel.model_json_schema(), instance_str) + + check_schema_with_json( + RestrictedModel.model_json_schema(), + '{"restricted_string": "j", "restricted_value": 45}', + check_accepted=False, + ) + + +def test_dynamic_model() -> None: + class MainModel(BaseModel): + restricted_string: Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[a-f]"})] + + additional_fields = {} + additional_fields["restricted_string_dynamic"] = ( + Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[a-x]"})], + ..., + ) + + CompleteModel = create_model("CompleteModel", **additional_fields) + instance = CompleteModel(restricted_string="a", restricted_string_dynamic="j") + instance_str = json.dumps(instance.model_dump(mode="json")) + check_schema_with_json(CompleteModel.model_json_schema(), instance_str) + + if __name__ == "__main__": pytest.main(sys.argv)