Skip to content

Commit

Permalink
Recurse into recursive type alias unions (#1416)
Browse files Browse the repository at this point in the history
Fixes #1414
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
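> Fix the handling of unions that involve recursive type aliases in BAML: when output-format rendering records a recursive alias, it now also traverses the alias's target type. Generated client files and tests are updated accordingly.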
> 
>   - **Behavior**:
> - In `relevant_data_models()` in `render_output_format.rs`, rename the work list `start` to `stack`. It is still a `Vec` consumed with `pop()`, so the rename by itself does not change the traversal.
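> - When a recursive type alias is inserted into `structural_recursive_aliases` for the first time, push its target type onto `stack` so the target is traversed as well; aliases that are already recorded are not pushed again.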
>   - **Tests**:
> - Add the `JsonEntry` and `JsonTemplate` type aliases, the `SimpleTag` class, and the `ReturnJsonEntry` function to `recursive-type-aliases.baml` to test recursive type alias unions.
> - Add `test_union_of_recursive_alias_or_class` in `test_functions.py` to verify the new behavior.
>   - **Client Updates**:
> - Add the `ReturnJsonEntry` function (both the regular call and the streaming variant) to `async_client.py` and `sync_client.py`, and add it to `client.rb`, to handle the new recursive type alias unions.
> - Update the TypeScript client files (`async_client.ts`, `sync_client.ts`, `type_builder.ts`, `types.ts`) to include the `SimpleTag` and `JsonTemplate` types.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for c2e0830. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Feb 5, 2025
1 parent 9c068e0 commit a648559
Show file tree
Hide file tree
Showing 20 changed files with 366 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ fn relevant_data_models<'a>(
let mut classes = Vec::new();
let mut recursive_classes = IndexSet::new();
let mut structural_recursive_aliases = IndexMap::new();
let mut start: Vec<baml_types::FieldType> = vec![output.clone()];
let mut stack: Vec<baml_types::FieldType> = vec![output.clone()];

// start.extend(ctx.type_alias_overrides.values().cloned());

let eval_ctx = ctx.eval_ctx(false);

while let Some(output) = start.pop() {
while let Some(output) = stack.pop() {
match ir.distribute_metadata(&output) {
(FieldType::Enum(enm), (constraints, streaming_behavior)) => {
if checked_types.insert(output.to_string()) {
Expand Down Expand Up @@ -280,24 +280,24 @@ fn relevant_data_models<'a>(
}
(FieldType::List(inner), _) | (FieldType::Optional(inner), _) => {
if !checked_types.contains(&inner.to_string()) {
start.push(inner.as_ref().clone());
stack.push(inner.as_ref().clone());
}
}
(FieldType::Map(k, v), _) => {
if checked_types.insert(output.to_string()) {
if !checked_types.contains(&k.to_string()) {
start.push(k.as_ref().clone());
stack.push(k.as_ref().clone());
}
if !checked_types.contains(&v.to_string()) {
start.push(v.as_ref().clone());
stack.push(v.as_ref().clone());
}
}
}
(FieldType::Tuple(options), _) | (FieldType::Union(options), _) => {
if checked_types.insert(output.to_string()) {
for inner in options {
if !checked_types.contains(&inner.to_string()) {
start.push(inner.clone());
stack.push(inner.clone());
}
}
}
Expand Down Expand Up @@ -352,7 +352,7 @@ fn relevant_data_models<'a>(

for (_, t, _, _) in fields.iter().as_ref() {
if !checked_types.contains(&t.to_string()) {
start.push(t.clone());
stack.push(t.clone());
}
}

Expand Down Expand Up @@ -395,7 +395,12 @@ fn relevant_data_models<'a>(
for cycle in ir.structural_recursive_alias_cycles() {
if cycle.contains_key(name) {
for (alias, target) in cycle.iter() {
structural_recursive_aliases.insert(alias.to_owned(), target.clone());
if structural_recursive_aliases
.insert(alias.to_owned(), target.clone())
.is_none()
{
stack.push(target.clone());
}
}
}
}
Expand All @@ -404,7 +409,12 @@ fn relevant_data_models<'a>(
for cycle in &ctx.recursive_type_alias_overrides {
if cycle.contains_key(name) {
for (alias, target) in cycle.iter() {
structural_recursive_aliases.insert(alias.to_owned(), target.clone());
if structural_recursive_aliases
.insert(alias.to_owned(), target.clone())
.is_none()
{
stack.push(target.clone());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ function AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlia
client "openai/gpt-4o"
prompt r#"
Return the given linked list back:

{{ list }}

{{ ctx.output_format }}
"#
}
Expand All @@ -26,9 +26,9 @@ function ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> Cl
client "openai/gpt-4o"
prompt r#"
Return the given object back:

{{ cls }}

{{ ctx.output_format }}
"#
}
Expand All @@ -46,9 +46,29 @@ function RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> No
client "openai/gpt-4o"
prompt r#"
Return the given object back:

{{ cls }}

{{ ctx.output_format }}
"#
}

type JsonEntry = SimpleTag | JsonTemplate

type JsonTemplate = map<string, JsonEntry>

class SimpleTag {
field string
}

function ReturnJsonEntry(s: string) -> JsonTemplate {
client GPT4o
prompt #"
{{ _.role("user") }}

Extract info from this string:
{{ s }}

{{ ctx.output_format }}
"#
}
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,29 @@ async def ReturnFailingAssert(
)
return cast(int, raw.cast_to(types, types, partial_types, False))

async def ReturnJsonEntry(
self,
s: str,
baml_options: BamlCallOptions = {},
) -> types.JsonTemplate:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"ReturnJsonEntry",
{
"s": s,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.JsonTemplate, raw.cast_to(types, types, partial_types, False))

async def ReturnMalformedConstraints(
self,
a: int,
Expand Down Expand Up @@ -6270,6 +6293,36 @@ def ReturnFailingAssert(
self.__ctx_manager.get(),
)

def ReturnJsonEntry(
self,
s: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[types.JsonTemplate, types.JsonTemplate]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"ReturnJsonEntry",
{
"s": s,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[types.JsonTemplate, types.JsonTemplate](
raw,
lambda x: cast(types.JsonTemplate, x.cast_to(types, types, partial_types, True)),
lambda x: cast(types.JsonTemplate, x.cast_to(types, types, partial_types, False)),
self.__ctx_manager.get(),
)

def ReturnMalformedConstraints(
self,
a: int,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"test-files/functions/output/optional-class.baml": "class ClassOptionalOutput {\n prop1 string\n prop2 string\n}\n\nfunction FnClassOptionalOutput(input: string) -> ClassOptionalOutput? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\n\nclass Blah {\n prop4 string?\n}\n\nclass ClassOptionalOutput2 {\n prop1 string?\n prop2 string?\n prop3 Blah?\n}\n\nfunction FnClassOptionalOutput2(input: string) -> ClassOptionalOutput2? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnClassOptionalOutput2 {\n functions [FnClassOptionalOutput2, FnClassOptionalOutput]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/optional.baml": "class OptionalTest_Prop1 {\n omega_a string\n omega_b int\n}\n\nenum OptionalTest_CategoryType {\n Aleph\n Beta\n Gamma\n}\n \nclass OptionalTest_ReturnType {\n omega_1 OptionalTest_Prop1?\n omega_2 string?\n omega_3 (OptionalTest_CategoryType?)[]\n} \n \nfunction OptionalTest_Function(input: string) -> (OptionalTest_ReturnType?)[]\n{ \n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest OptionalTest_Function {\n functions [OptionalTest_Function]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/recursive-class.baml": "class Node {\n data int\n next Node?\n}\n\nclass LinkedList {\n head Node?\n len int\n}\n\nclient<llm> O1 {\n provider \"openai\"\n options {\n model \"o1-mini\"\n default_role \"user\"\n }\n}\n\nfunction BuildLinkedList(input: int[]) -> LinkedList {\n client O1\n prompt #\"\n Build a linked list from the input array of integers.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestLinkedList {\n functions [BuildLinkedList]\n args {\n input [1, 2, 3, 4, 5]\n }\n}\n",
"test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n",
"test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n\n {{ list }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n\n {{ cls }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n\n {{ cls }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonEntry = SimpleTag | JsonTemplate\n\ntype JsonTemplate = map<string, JsonEntry>\n\nclass SimpleTag {\n field string\n}\n\nfunction ReturnJsonEntry(s: string) -> JsonTemplate {\n client GPT4o\n prompt #\"\n {{ _.role(\"user\") }}\n\n Extract info from this string:\n {{ s }}\n\n {{ ctx.output_format }}\n \"#\n}\n",
"test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}",
"test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map<string, string[]>\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map<string, RecursiveMapAlias>\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction 
SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map<string, JsonValue>\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass RecursiveAliasDependency {\n value JsonValue\n}\n\nfunction TakeRecAliasDep(input: RecursiveAliasDependency) -> RecursiveAliasDependency {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n",
Expand Down
3 changes: 3 additions & 0 deletions integ-tests/python/baml_client/partial_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ class SemanticContainer(BaseModel):
three_small_things: List["SmallThing"]
final_string: Optional[str] = None

class SimpleTag(BaseModel):
field: Optional[str] = None

class SmallThing(BaseModel):
i_16_digits: int
i_8_digits: Optional[int] = None
Expand Down
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,29 @@ def ReturnFailingAssert(
)
return cast(int, raw.cast_to(types, types, partial_types, False))

def ReturnJsonEntry(
self,
s: str,
baml_options: BamlCallOptions = {},
) -> types.JsonTemplate:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.call_function_sync(
"ReturnJsonEntry",
{
"s": s,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.JsonTemplate, raw.cast_to(types, types, partial_types, False))

def ReturnMalformedConstraints(
self,
a: int,
Expand Down Expand Up @@ -6268,6 +6291,36 @@ def ReturnFailingAssert(
self.__ctx_manager.get(),
)

def ReturnJsonEntry(
self,
s: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[types.JsonTemplate, types.JsonTemplate]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function_sync(
"ReturnJsonEntry",
{
"s": s,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlSyncStream[types.JsonTemplate, types.JsonTemplate](
raw,
lambda x: cast(types.JsonTemplate, x.cast_to(types, types, partial_types, True)),
lambda x: cast(types.JsonTemplate, x.cast_to(types, types, partial_types, False)),
self.__ctx_manager.get(),
)

def ReturnMalformedConstraints(
self,
a: int,
Expand Down
Loading

0 comments on commit a648559

Please sign in to comment.