From f1979d5e10352175880aeabaccc01000d70ded7c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 24 Feb 2025 22:05:44 -0800 Subject: [PATCH] regex ut part 1 wip uri and uri-reference ipv6 json-pointer uri-template refactor email --- cpp/json_schema_converter.cc | 173 ++++++- tests/python/test_json_schema_converter.py | 504 +++++++++++++++++++++ 2 files changed, 675 insertions(+), 2 deletions(-) diff --git a/cpp/json_schema_converter.cc b/cpp/json_schema_converter.cc index 45d090a5..93e0bb73 100644 --- a/cpp/json_schema_converter.cc +++ b/cpp/json_schema_converter.cc @@ -842,8 +842,7 @@ std::string JSONSchemaConverter::VisitSchema( } else if (schema_obj.count("properties") || schema_obj.count("additionalProperties") || schema_obj.count("unevaluatedProperties")) { return VisitObject(schema_obj, rule_name); - } else if (schema_obj.count("items") || schema_obj.count("prefixItems") || - schema_obj.count("unevaluatedItems")) { + } else if (schema_obj.count("items") || schema_obj.count("prefixItems") || schema_obj.count("unevaluatedItems")) { return VisitArray(schema_obj, rule_name); } @@ -1208,6 +1207,176 @@ std::string JSONSchemaConverter::VisitString( ) { XGRAMMAR_CHECK(schema.count("type")); XGRAMMAR_CHECK(schema.at("type").get() == "string"); + if (schema.count("format")) { + std::string format = schema.at("format").get(); + if (format == "email") { + // refer to RFC 5321 and RFC 5322, but skipping `address-literal` at + // RFC 5321 section 4.1.2 currently + std::string atext = "[\\w!#$%&'*+/=?^`{|}~-]"; + std::string dot_string = "(" + atext + "+(\\." + atext + "+)*)"; + std::string quoted_string = + "\\\\\"(\\\\[\\x20-\\x7E]|[\\x20\\x21\\x23-\\x5B\\x5D-\\x7E])*\\\\\""; + std::string domain = + "([A-Za-z0-9]([\\-A-Za-z0-9]*[A-Za-z0-9])?)((\\.[A-Za-z0-9][\\-A-Za-z0-9]*[A-Za-z0-9])*)"; + std::string email_regex_pattern = + "^(" + dot_string + "|" + quoted_string + ")@" + domain + "$"; + std::string email_ebnf = RegexToEBNF(email_regex_pattern, false); + return "\"\\\"\" " + email_ebnf + " \"\\\"\""; + } + if (format == "date") { + // refer to RFC 3339, section 5.6 + std::string date_regex_pattern = "^(\\d\\d\\d\\d-(0[1-9]|1[0-2])-(0[1-9]|[1-2]\\d|3[01]))$"; + std::string date_ebnf = RegexToEBNF(date_regex_pattern, false); + return "\"\\\"\" " + date_ebnf + " \"\\\"\""; + } + if (format == "time") { + // refer to RFC 3339, section 5.6 + std::string time_regex_pattern = + "^([01]\\d|2[0-3]):[0-5]\\d:([0-5]\\d|60)(\\.\\d+)?(Z|[+-]([01]\\d|2[0-3]):[0-5]\\d)$"; + std::string time_ebnf = RegexToEBNF(time_regex_pattern, false); + return "\"\\\"\" " + time_ebnf + " \"\\\"\""; + } + if (format == "date-time") { + // refer to RFC 3339, section 5.6 + std::string date_time_regex_pattern = + "^(\\d\\d\\d\\d-(0[1-9]|1[0-2])-(0[1-9]|[1-2]\\d|3[01]))T([01]\\d|2[0-3]):([0-5]\\d|60):[" + "0-5]\\d(\\.\\d+)?(Z|[+-]([01]\\d|2[0-3]):[0-5]\\d)$"; + std::string date_time_ebnf = RegexToEBNF(date_time_regex_pattern, false); + return "\"\\\"\" " + date_time_ebnf + " \"\\\"\""; + } + if (format == "duration") { + // refer to RFC 3339, Appendix A + std::string duration_regex_pattern = + "^P((\\d+D|\\d+M(\\d+D)?|\\d+Y(\\d+M(\\d+D)?)?)(T(\\d+S|\\d+M(\\d+S)?|\\d+H(\\d+M(\\d+S)?" + ")?))?|T(\\d+S|\\d+M(\\d+S)?|\\d+H(\\d+M(\\d+S)?)?)|\\d+W)$"; + std::string duration_ebnf = RegexToEBNF(duration_regex_pattern, false); + return "\"\\\"\" " + duration_ebnf + " \"\\\"\""; + } + if (format == "ipv4") { + // refer to RFC 2673, section 3.2 + std::string decbyte = "(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)"; + std::string ipv4_regex_pattern = + "^" + decbyte + "\\." + decbyte + "\\." + decbyte + "\\." + decbyte + "$"; + std::string ipv4_ebnf = RegexToEBNF(ipv4_regex_pattern, false); + return "\"\\\"\" " + ipv4_ebnf + " \"\\\"\""; + } + if (format == "ipv6") { + // refer to RFC 3986, section 3.3.2 + std::string decbyte = "(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)"; + std::string ipv4 = "(" + decbyte + "\\." + decbyte + "\\." + decbyte + "\\." + decbyte + ")"; + std::string h16 = "([\\dA-Fa-f][\\dA-Fa-f]?[\\dA-Fa-f]?[\\dA-Fa-f]?)"; + std::string ls32 = "(" + h16 + ":" + h16 + "|" + ipv4 + ")"; + auto f = [h16](int low, int high, std::string end) { + std::string out = ""; + for (int i = 0; i < low; ++i) { + out += h16 + ":"; + } + for (int i = low; i < high; ++i) { + out += "(" + h16 + ":)?"; + } + return out + end; + }; + std::string ipv6_regex_pattern = + "^(" + f(6, 6, ls32) + "|::" + f(5, 5, ls32) + "|" + h16 + "?::" + f(4, 4, ls32) + "|(" + + f(0, 1, h16) + ")?::" + f(3, 3, ls32) + "|(" + f(0, 2, h16) + ")?::" + f(2, 2, ls32) + + "|(" + f(0, 3, h16) + ")?::" + f(1, 1, ls32) + "|(" + f(0, 4, h16) + ")?::" + ls32 + + "|(" + f(0, 5, h16) + ")?::" + h16 + "|(" + f(0, 6, h16) + ")?::)$"; + std::string ipv6_ebnf = RegexToEBNF(ipv6_regex_pattern, false); + return "\"\\\"\" " + ipv6_ebnf + " \"\\\"\""; + } + if (format == "hostname") { + // refer to RFC 1123, section 2.1 + std::string hostname_regex_pattern = + "^([a-z0-9]([a-z0-9-]*[a-z0-9])?)(\\.[a-z0-9]([a-z0-9-]*[a-z0-9])?)*$"; + std::string hostname_ebnf = RegexToEBNF(hostname_regex_pattern, false); + return "\"\\\"\" " + hostname_ebnf + " \"\\\"\""; + } + if (format == "uuid") { + // refer to RFC 4122, section 3 + std::string uuid_regex_pattern = ""; + std::string hex_digit = "[0-9A-Fa-f]"; + for (int i = 0; i < 8; ++i) uuid_regex_pattern += hex_digit; + uuid_regex_pattern += "-"; + for (int i = 0; i < 4; ++i) uuid_regex_pattern += hex_digit; + uuid_regex_pattern += "-"; + for (int i = 0; i < 4; ++i) uuid_regex_pattern += hex_digit; + uuid_regex_pattern += "-"; + for (int i = 0; i < 4; ++i) uuid_regex_pattern += hex_digit; + uuid_regex_pattern += "-"; + for (int i = 0; i < 12; ++i) uuid_regex_pattern += hex_digit; + uuid_regex_pattern = "^" + uuid_regex_pattern + "$"; + std::string uuid_ebnf = RegexToEBNF(uuid_regex_pattern, false); + return "\"\\\"\" " + uuid_ebnf + " \"\\\"\""; + } + if (format == "uri") { + // refer to RFC 3986, Appendix A, but skipping IP-literal and IPv4address currently + std::string schema = "[a-zA-Z][a-zA-Z+\\.-]*"; + std::string pchar = "([\\w\\.~!$&'()*+,;=:@-]|%[0-9A-Fa-f][0-9A-Fa-f])"; + std::string query_fragment_char = "([\\w\\.~!$&'()*+,;=:@/\\?-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string query = "(\\?" + query_fragment_char + ")?"; + std::string fragment = "(#" + query_fragment_char + ")?"; + std::string path_abempty = "(/" + pchar + "*)*"; + std::string path_absolute_rootless_empty = "/?(" + pchar + "+(/" + pchar + "*)*)?"; + std::string userinfo = "([\\w\\.~!$&'()*+,;=:-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string host = "([\\w\\.~!$&'()*+,;=-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string authority = "(" + userinfo + "@)?" + host + "(:\\d*)?"; + std::string hier_part = + "(//" + authority + path_abempty + "|" + path_absolute_rootless_empty + ")"; + std::string uri_regex_pattern = "^" + schema + ":" + hier_part + query + fragment + "$"; + std::string uri_ebnf = RegexToEBNF(uri_regex_pattern, false); + return "\"\\\"\" " + uri_ebnf + " \"\\\"\""; + } + + if (format == "uri-reference") { + // refer to RFC 3986, Appendix A, but skipping IP-literal and IPv4address currently + std::string pchar = "([\\w\\.~!$&'()*+,;=:@-]|%[0-9A-Fa-f][0-9A-Fa-f])"; + std::string query_fragment_char = "([\\w\\.~!$&'()*+,;=:@/\\?-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string query = "(\\?" + query_fragment_char + ")?"; + std::string fragment = "(#" + query_fragment_char + ")?"; + std::string path_abempty = "(/" + pchar + "*)*"; + std::string path_absolute = "/(" + pchar + "+(/" + pchar + "*)*)?"; + std::string segment_nz_nc = "([\\w\\.~!$&'()*+,;=@-]|%[0-9A-Fa-f][0-9A-Fa-f])+"; + std::string path_noscheme = segment_nz_nc + "(/" + pchar + "*)*"; + std::string userinfo = "([\\w\\.~!$&'()*+,;=:-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string host = "([\\w\\.~!$&'()*+,;=-]|%[0-9A-Fa-f][0-9A-Fa-f])*"; + std::string authority = "(" + userinfo + "@)?" + host + "(:\\d*)?"; + std::string relative_part = + "(//" + authority + path_abempty + "|" + path_absolute + "|" + path_noscheme + ")?"; + std::string uri_reference_regex_pattern = "^" + relative_part + query + fragment + "$"; + std::string uri_reference_ebnf = RegexToEBNF(uri_reference_regex_pattern, false); + return "\"\\\"\" " + uri_reference_ebnf + " \"\\\"\""; + } + if (format == "uri-template") { + // refer to RFC 6570, section 2 + std::string literals = + "([\\x21\\x23-\\x24\\x26\\x28-\\x3B\\x3D\\x3F-\\x5B\\x5D\\x5F\\x61-\\x7A\\x7E]" + "|%[0-9A-Fa-f][0-9A-Fa-f])"; + std::string op = "[+#\\./;\\?&=,!@|]"; + std::string varchar = "(\\w|%[0-9A-Fa-f][0-9A-Fa-f])"; + std::string varname = varchar + "(\\.?" + varchar + ")*"; + std::string varspec = varname + "(:[1-9]\\d?\\d?\\d?|\\*)?"; + std::string variable_list = varspec + "(," + varspec + ")*"; + std::string expression = "\\{(" + op + ")?" + variable_list + "\\}"; + std::string uri_template_regex_pattern = "^(" + literals + "|" + expression + ")*$"; + std::string uri_template_ebnf = RegexToEBNF(uri_template_regex_pattern, false); + return "\"\\\"\" " + uri_template_ebnf + " \"\\\"\""; + } + if (format == "json-pointer") { + // refer to RFC 6901, section 3 + std::string json_pointer_regex_pattern = + "^(/([\\x00-\\x2E]|[\\x30-\\x7D]|[\\x7F-\\U0010FFFF]|~[01])*)*$"; + std::string json_pointer_ebnf = RegexToEBNF(json_pointer_regex_pattern, false); + return "\"\\\"\" " + json_pointer_ebnf + " \"\\\"\""; + } + if (format == "relative-json-pointer") { + // refer to draft-handrews-relative-json-pointer-01, section 3 + std::string relative_json_pointer_regex_pattern = + "^(0|[1-9][0-9]*)(#|(/([\\x00-\\x2E]|[\\x30-\\x7D]|[\\x7F-\\U0010FFFF]|~[01])*)*)$"; + std::string relative_json_pointer_ebnf = + RegexToEBNF(relative_json_pointer_regex_pattern, false); + return "\"\\\"\" " + relative_json_pointer_ebnf + " \"\\\"\""; + } + } WarnUnsupportedKeywords( schema, { diff --git a/tests/python/test_json_schema_converter.py b/tests/python/test_json_schema_converter.py index 7caa1ea6..7ef64d1d 100644 --- a/tests/python/test_json_schema_converter.py +++ b/tests/python/test_json_schema_converter.py @@ -1,10 +1,12 @@ import json import sys +import time from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union import pytest from pydantic import BaseModel, Field, TypeAdapter, WithJsonSchema, create_model +from transformers import AutoTokenizer import xgrammar as xgr from xgrammar.testing import _generate_range_regex, _is_grammar_accept_string, _json_schema_to_ebnf @@ -980,5 +982,507 @@ def test_generate_range_regex(): assert _generate_range_regex(0, 10) == r"^(0|([1-9]|10))$" +email_instances_accepted = [ + (r"simple@example.com", True), + (r"very.common@example.com", True), + (r"FirstName.LastName@EasierReading.org", True), + (r"x@example.com", True), + (r"long.email-address-with-hyphens@and.subdomains.example.com", True), + (r"user.name+tag+sorting@example.com", True), + (r"name/surname@example.com", True), + (r"admin@example", True), + (r"example@s.example", True), + (r'" "@example.org', True), + (r'"john..doe"@example.org', True), + (r"mailhost!username@example.org", True), + (r'"very.(),:;<>[]\".VERY.\"very@\\ \"very\".unusual"@strange.example.com', True), + (r"user%example.com@example.org", True), + (r"user-@example.org", True), + (r"abc.example.com", False), + (r"a@b@c@example.com", False), + (r'a"b(c)d,e:f;gi[j\k]l@example.com', False), + (r'just"not"right@example.com', False), + (r'this is"not\allowed@example.com', False), + (r"this\ still\"not\\allowed@example.com", False), + (r"i.like.underscores@but_they_are_not_allowed_in_this_part", False), +] + + +@pytest.mark.parametrize("instance, accepted", email_instances_accepted) +def test_email_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "email"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +date_instances_accepted = [ + (r"0000-01-01", True), + (r"9999-12-31", True), + (r"10-01-01", False), + (r"2025-00-01", False), + (r"2025-13-01", False), + (r"2025-01-00", False), + (r"2025-01-32", False), +] + + +@pytest.mark.parametrize("instance, accepted", date_instances_accepted) +def test_date_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "date"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +time_instances_accepted = [ + (r"00:00:00Z", True), + (r"23:59:60Z", True), + (r"12:34:56Z", True), + (r"12:34:56+07:08", True), + (r"12:34:56-07:08", True), + (r"12:34:56.7Z", True), + (r"12:34:56.7+08:09", True), + (r"12:34:56.7-08:09", True), + (r"00:00:00", False), + (r"23:59:60", False), + (r"12:34:56.7", False), + (r"12:34:56.7890", False), + (r"24:00:00", False), + (r"00:60:00", False), + (r"00:00:61", False), + (r"00:00:00.", False), + (r"12:34:56+07:", False), + (r"12:34:56-07:", False), + (r"12:34:56.7+-08:09", False), +] + + +@pytest.mark.parametrize("instance, accepted", time_instances_accepted) +def test_time_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "time"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +duration_instances_accepted = [ + (r"P0Y", True), + (r"P12M", True), + (r"P345D", True), + (r"P6789W", True), + (r"P01234D", True), + (r"PT9H", True), + (r"PT87M", True), + (r"PT654S", True), + (r"P1Y23M456D", True), + (r"P23M456D", True), + (r"P1Y0M456D", True), + (r"P1Y23M", True), + (r"PT9H87M654S", True), + (r"PT87M654S", True), + (r"PT9H0M654S", True), + (r"PT9H87M", True), + (r"P1Y23M456DT9H87M654S", True), + (r"P", False), + (r"PD", False), + (r"P1", False), + (r"PT", False), + (r"P1Y456D", False), + (r"PT9H654S", False), +] + + +@pytest.mark.parametrize("instance, accepted", duration_instances_accepted) +def test_duration_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "duration"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +ipv6_instances_accepted = [ + (r"0123:4567:890a:bced:fABC:DEF0:1234:5678", True), + (r"::6666:6666:6666:6666:6666:6666", True), + (r"::6666:6666:6666:6666:6666", True), + (r"::6666:6666:6666:6666", True), + (r"::6666:6666:6666", True), + (r"::6666:6666", True), + (r"::6666", True), + (r"::", True), + (r"8888:8888:8888:8888:8888:8888::", True), + (r"8888:8888:8888:8888:8888::", True), + (r"8888:8888:8888:8888::", True), + (r"8888:8888:8888::", True), + (r"8888:8888::", True), + (r"8888::", True), + (r"1111::2222", True), + (r"1111:1111::2222", True), + (r"1111::2222:2222", True), + (r"1111:1111:1111::2222", True), + (r"1111:1111::2222:2222", True), + (r"1111::2222:2222:2222", True), + (r"1111:1111:1111:1111::2222", True), + (r"1111:1111:1111::2222:2222", True), + (r"1111:1111::2222:2222:2222", True), + (r"1111::2222:2222:2222:2222", True), + (r"1111:1111:1111:1111:1111::2222", True), + (r"1111:1111:1111:1111::2222:2222", True), + (r"1111:1111:1111::2222:2222:2222", True), + (r"1111:1111::2222:2222:2222:2222", True), + (r"1111::2222:2222:2222:2222:2222", True), + (r"1111:1111:1111:1111:1111:1111::2222", True), + (r"1111:1111:1111:1111:1111::2222:2222", True), + (r"1111:1111:1111:1111::2222:2222:2222", True), + (r"1111:1111:1111::2222:2222:2222:2222", True), + (r"1111:1111::2222:2222:2222:2222:2222", True), + (r"1111::2222:2222:2222:2222:2222:2222", True), + (r"0123:4567:890a:bced:fABC:DEF0:012.034.056.078", True), + (r"::111.111.222.222", True), + (r":", False), + (r":::", False), + (r"::5555:5555:5555:5555:5555:5555:5555:5555", False), + (r"5555::5555:5555:5555:5555:5555:5555:5555", False), + (r"5555:5555::5555:5555:5555:5555:5555:5555", False), + (r"5555:5555:5555::5555:5555:5555:5555:5555", False), + (r"5555:5555:5555:5555::5555:5555:5555:5555", False), + (r"5555:5555:5555:5555:5555::5555:5555:5555", False), + (r"5555:5555:5555:5555:5555:5555::5555:5555", False), + (r"5555:5555:5555:5555:5555:5555:5555::5555", False), + (r"5555:5555:5555:5555:5555:5555:5555:5555::", False), +] + + +@pytest.mark.parametrize("instance, accepted", ipv6_instances_accepted) +def test_ipv6_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "ipv6"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +ipv4_instances_accepted = [ + # (r"0.0.0.0", True), + (r"00.00.00.00", True), + (r"000.000.000.000", True), + (r"255.255.255.255", True), + (r"1", False), + (r"1.", False), + (r"1.1", False), + (r"1.1.", False), + (r"1.1.1", False), + (r"1.1.1.", False), + (r"0001.0001.0001.0001", False), + (r"256.256.256.256", False), +] + + +@pytest.mark.parametrize("instance, accepted", ipv4_instances_accepted) +def test_ipv4_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "ipv4"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +hostname_instances_accepted = [ + (r"0", True), + (r"9", True), + (r"a", True), + (r"z", True), + (r"www.github.com", True), + (r"w-w-w.g-i-t-h-u-b.c-o-m", True), + (r"ww-w.gi-th-ub.co-m", True), + (r"w--ww.git---hub.co----m", True), + (r".", False), + (r"-", False), + (r"-.", False), + (r".-", False), + (r"_", False), + (r"a.", False), + (r"-b", False), + (r"c-", False), + (r"d.-", False), + (r"e-.", False), + (r"-f.", False), + (r"g-.h", False), + (r"-i.j", False), +] + + +@pytest.mark.parametrize("instance, accepted", hostname_instances_accepted) +def test_hostname_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "hostname"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +uuid_instances_accepted = [ + (r"00000000-0000-0000-0000-000000000000", True), + (r"FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF", True), + (r"01234567-89AB-CDEF-abcd-ef0123456789", True), + (r"-", False), + (r"----", False), + (r"AAAAAAA-AAAA-AAAA-AAAA-AAAAAAAAAAAA", False), + (r"BBBBBBBB-BBB-BBBB-BBBB-BBBBBBBBBBBB", False), + (r"CCCCCCCC-CCCC-CCC-CCCC-CCCCCCCCCCCC", False), + (r"DDDDDDDD-DDDD-DDDD-DDD-DDDDDDDDDDDD", False), + (r"EEEEEEEE-EEEE-EEEE-EEEE-EEEEEEEEEEE", False), + (r"AAAAAAAAA-AAAA-AAAA-AAAA-AAAAAAAAAAAA", False), + (r"BBBBBBBB-BBBBB-BBBB-BBBB-BBBBBBBBBBBB", False), + (r"CCCCCCCC-CCCC-CCCCC-CCCC-CCCCCCCCCCCC", False), + (r"DDDDDDDD-DDDD-DDDD-DDDDD-DDDDDDDDDDDD", False), + (r"EEEEEEEE-EEEE-EEEE-EEEE-EEEEEEEEEEEEE", False), +] + + +@pytest.mark.parametrize("instance, accepted", uuid_instances_accepted) +def test_uuid_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "uuid"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +uri_instances_accepted = [ + (r"aaa:?azAZ09-._~%Ff!$&'()*+,;=:@#azAZ09-._~%Aa!$&'()*+,;=:@", True), + (r"z+.-:", True), + (r"abc:", True), + (r"abc:a", True), + (r"abc:/", True), + (r"abc:/a", True), + (r"abc://", True), + (r"abc://///////", True), + (r"abc://azAZ09-._~%Ff!$&'()*+,;=:@", True), + (r"abc://:", True), + (r"abc://:0123", True), + (r"abc://azAZ09-._~%Ff!$&'()*+,;=", True), + (r"xyz:/a", True), + (r"xyz:/azAZ09-._~%Ff!$&'()*+,;=:@", True), + (r"aaa:?[#]", False), + (r"abc://@@", False), + (r"abc://::", False), + (r"abc:/[]", False), +] + + +@pytest.mark.parametrize("instance, accepted", uri_instances_accepted) +def test_uri_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "uri"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +uri_reference_instances_accepted = [ + (r"?azAZ09-._~%Ff!$&'()*+,;=:@#azAZ09-._~%Aa!$&'()*+,;=:@", True), + (r"", True), + (r"a", True), + (r"/", True), + (r"/a", True), + (r"//", True), + (r"/////////", True), + (r"//azAZ09-._~%Ff!$&'()*+,;=:@", True), + (r"//:", True), + (r"//:0123", True), + (r"//azAZ09-._~%Ff!$&'()*+,;=", True), + (r"/a", True), + (r"/azAZ09-._~%Ff!$&'()*+,;=:@", True), + (r"?[#]", False), + (r"//@@", False), + (r"//::", False), + (r"/[]", False), + (r":", False), +] + + +@pytest.mark.parametrize("instance, accepted", uri_reference_instances_accepted) +def test_uri_reference_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "uri-reference"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +uri_template_instances_accepted = [ + (r"", True), + (r"!#$&()*+,-./09:;=?@AZ[]_az~%Ff", True), + (r"{+a}{#a}{.a}{/a}{;a}{?a}{&a}{=a}{,a}{!a}{@a}{|a}", True), + (r"{%Ff}", True), + (r"{i.j.k}", True), + (r"{a_b_c:1234}", True), + (r"{x_y_z*}", True), + (r'"', False), + (r"'", False), + (r"%", False), + (r"<", False), + (r">", False), + (r"\\", False), + (r"^", False), + (r"`", False), + (r"{", False), + (r"|", False), + (r"}", False), + (r"{n.}", False), + (r"{m:100001}", False), + (r"%1", False), + (r"%Gg", False), +] + + +@pytest.mark.parametrize("instance, accepted", uri_template_instances_accepted) +def test_uri_template_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "uri-template"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +json_pointer_instances_accepted = [ + (r"/", True), + (r"//", True), + (r"/a/bc/def/ghij", True), + (r"/~0/~1/", True), + (r"abc", False), + (r"/~", False), + (r"/~2", False), +] + + +@pytest.mark.parametrize("instance, accepted", json_pointer_instances_accepted) +def test_json_pointer_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "json-pointer"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +relative_json_pointer_instances_accepted = [ + (r"0/", True), + (r"123/a/bc/def/ghij", True), + (r"45/~0/~1/", True), + (r"6789#", True), + (r"#", False), + (r"abc", False), + (r"/", False), + (r"9/~2", False), +] + + +@pytest.mark.parametrize("instance, accepted", relative_json_pointer_instances_accepted) +def test_relative_json_pointer_format(instance: str, accepted: bool): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": "relative-json-pointer"}) + + schema = MainModel.model_json_schema() + check_schema_with_instance( + schema, MainModel(name=instance), is_accepted=accepted, any_whitespace=False + ) + + +string_format_instances = [ + (r"long.email-address-with-hyphens@and.subdomains.example.com", "email"), + (r'"very.(),:;<>[]\".VERY.\"very@\\ \"very\".unusual"@strange.example.com', "email"), + (r"128.255.000.222", "ipv4"), + (r"abcd:ABCD::0123:5678:000.111.222.123", "ipv6"), + (r"P1Y23M456DT9H87M654S", "duration"), + (r"2025-01-01T12:34:56.7+08:09", "date-time"), + (r"123--abc.efgh---789-xyz.rst-uvw", "hostname"), + (r"01234567-89AB-CDEF-abcd-ef0123456789", "uuid"), + ( + r"http://azAZ09-._~%Ff!$&'()*+,;=:@xyz:987/-/./+/*?aA0-._~%Ff!$&'()@#zZ9-._~%Aa!$&,;=:", + "uri", + ), + ( + r"//azAZ09-._~%Ff!$&'()*+,;=:@xyz:987/-/./+/*?aA0-._~%Ff!$&'()@#zZ9-._~%Aa!$&,;=:", + "uri-reference", + ), + (r"!#$&()*+,-./{+abc}{#def}{.ghi}{/jkl}{;mno:2468}", "uri-template"), + (r"/a/bc/def/ghij/~0~1//", "json-pointer"), + (r"1234/a/bc/def/ghij/~0~1//", "relative-json-pointer"), +] + + +@pytest.mark.parametrize("value, format", string_format_instances) +def test_mask_generation_format(value: str, format: str): + class MainModel(BaseModel): + name: str = Field(json_schema_extra={"format": format}) + + instance = json.dumps(MainModel(name=value).model_dump(mode="json")) + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) + grammar_compiler = xgr.GrammarCompiler(tokenizer_info, cache_enabled=False) + + time_start = time.monotonic_ns() + ebnf = _json_schema_to_ebnf( + json.dumps(MainModel.model_json_schema()), + any_whitespace=None, + indent=None, + separators=None, + strict_mode=True, + ) + matcher_compiled_grammar = grammar_compiler.compile_grammar(ebnf) + time_end = time.monotonic_ns() + print(f"Time for preprocessing: {(time_end - time_start) / 1e3} us") + matcher = xgr.GrammarMatcher(matcher_compiled_grammar) + token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) + + for c in instance.encode("utf-8"): + time_start = time.monotonic_ns() + matcher.fill_next_token_bitmask(token_bitmask) + time_end = time.monotonic_ns() + delta = (time_end - time_start) / 1e3 + if delta > 1000: + print(f"Time for fill_next_token_bitmask: {delta} us on char {bytes([c])}") + accepted = matcher._debug_accept_string(bytes([c])) + assert accepted + + time_start = time.monotonic_ns() + matcher.fill_next_token_bitmask(token_bitmask) + time_end = time.monotonic_ns() + print(f"Time for fill_next_token_bitmask: {(time_end - time_start) / 1e3} us") + + assert matcher.accept_token(tokenizer.eos_token_id) + assert matcher.is_terminated() + + if __name__ == "__main__": pytest.main(sys.argv)