Skip to content

Commit

Permalink
Implement tests for dynamic types (#1343)
Browse files Browse the repository at this point in the history
Docs #1307
Discussion #1310 
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Implement comprehensive testing and support for dynamic types in the
BAML engine, including parsing, validation, and runtime handling.
> 
>   - **Behavior**:
> - Add support for dynamic types in `to_baml_arg.rs` by checking for
`dynamic_type` attribute.
> - Implement `type_builder` blocks in `repr.rs` to handle dynamic type
definitions within tests.
>     - Add validation for `type_builder` blocks in `lib.rs`.
>   - **Testing**:
> - Add tests for dynamic types in `dynamic_types.baml`,
`dynamic_types_external_cycle_errors.baml`,
`dynamic_types_internal_cycle_errors.baml`,
`dynamic_types_parser_errors.baml`, and
`dynamic_types_validation_errors.baml`.
> - Implement `run_type_builder_block_test` in `test_runtime.rs` to test
dynamic type handling.
>   - **Parsing**:
> - Update `parse_type_expression_block.rs` and
`parse_value_expression_block.rs` to handle `dynamic` and `type_builder`
blocks.
> - Add `parse_type_builder_block.rs` for parsing `type_builder` blocks.
>   - **Runtime**:
> - Extend `RuntimeContextManager` and `RuntimeContext` to support
dynamic type overrides.
> - Implement `get_test_type_builder` in `runtime_interface.rs` to
retrieve type builders for tests.
>   - **WASM**:
> - Update `runtime_wasm/mod.rs` to handle dynamic types in WASM
environment.
>   - **Syntax Highlighting**:
> - Update `codemirror-lang-baml` and `vscode-ext` to support
`type_builder` and `dynamic` syntax highlighting.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 483363d. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Feb 4, 2025
1 parent 026bc21 commit 7f852d0
Show file tree
Hide file tree
Showing 41 changed files with 1,700 additions and 186 deletions.
3 changes: 3 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ RUN curl https://mise.run | sh \
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

# Install WASM tools
RUN cargo install [email protected] wasm-pack

# Install Infisical
RUN curl -1sLf 'https://dl.cloudsmith.io/public/infisical/infisical-cli/setup.deb.sh' | sudo -E bash \
&& sudo apt update && sudo apt install -y infisical
4 changes: 3 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ impl ArgCoercer {
(FieldType::Enum(name), _) => match value {
BamlValue::String(s) => {
if let Ok(e) = ir.find_enum(name) {
if e.walk_values().any(|v| v.item.elem.0 == *s) {
if e.walk_values().any(|v| v.item.elem.0 == *s)
|| e.item.attributes.get("dynamic_type").is_some()
{
Ok(BamlValue::Enum(name.to_string(), s.to_string()))
} else {
scope.push_error(format!(
Expand Down
151 changes: 115 additions & 36 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::HashSet;

use anyhow::{anyhow, Result};
use baml_types::{
Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, StringOr,
UnresolvedValue,
Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior,
StringOr, UnresolvedValue,
};
use either::Either;
use indexmap::{IndexMap, IndexSet};
Expand All @@ -15,7 +15,9 @@ use internal_baml_parser_database::{
Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker,
};

use internal_baml_schema_ast::ast::{self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan};
use internal_baml_schema_ast::ast::{
self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan,
};
use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty};
use serde::Serialize;

Expand Down Expand Up @@ -179,6 +181,9 @@ impl IntermediateRepr {
db: &ParserDatabase,
configuration: Configuration,
) -> Result<IntermediateRepr> {
// TODO: We're iterating over the AST tops once for every property in
// the IR. Easy performance optimization here by iterating only one time
// and distributing the tops to the appropriate IR properties.
let mut repr = IntermediateRepr {
enums: db
.walk_enums()
Expand Down Expand Up @@ -347,10 +352,7 @@ fn to_ir_attributes(
});
let streaming_done = streaming_done.as_ref().and_then(|v| {
if *v {
Some((
"stream.done".to_string(),
UnresolvedValue::Bool(true, ()),
))
Some(("stream.done".to_string(), UnresolvedValue::Bool(true, ())))
} else {
None
}
Expand Down Expand Up @@ -594,7 +596,6 @@ impl WithRepr<FieldType> for ast::FieldType {
),
};


let use_metadata = has_constraints || has_special_streaming_behavior;
let with_constraints = if use_metadata {
FieldType::WithMetadata {
Expand All @@ -609,30 +610,6 @@ impl WithRepr<FieldType> for ast::FieldType {
}
}

// #[derive(serde::Serialize, Debug)]
// pub enum Identifier {
// /// Starts with env.*
// ENV(String),
// /// The path to a Local Identifer + the local identifer. Separated by '.'
// #[allow(dead_code)]
// Ref(Vec<String>),
// /// A string without spaces or '.' Always starts with a letter. May contain numbers
// Local(String),
// /// Special types (always lowercase).
// Primitive(baml_types::TypeValue),
// }

// impl Identifier {
// pub fn name(&self) -> String {
// match self {
// Identifier::ENV(k) => k.clone(),
// Identifier::Ref(r) => r.join("."),
// Identifier::Local(l) => l.clone(),
// Identifier::Primitive(p) => p.to_string(),
// }
// }
// }

type TemplateStringId = String;

#[derive(Debug)]
Expand Down Expand Up @@ -717,7 +694,15 @@ impl WithRepr<Enum> for EnumWalker<'_> {

fn repr(&self, db: &ParserDatabase) -> Result<Enum> {
Ok(Enum {
name: self.name().to_string(),
// TODO: #1343 Temporary solution until we implement scoping in the AST.
name: if self.ast_type_block().is_dynamic_type_def {
self.name()
.strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX)
.unwrap()
.to_string()
} else {
self.name().to_string()
},
values: self
.values()
.map(|w| {
Expand Down Expand Up @@ -803,7 +788,15 @@ impl WithRepr<Class> for ClassWalker<'_> {

fn repr(&self, db: &ParserDatabase) -> Result<Class> {
Ok(Class {
name: self.name().to_string(),
// TODO: #1343 Temporary solution until we implement scoping in the AST.
name: if self.ast_type_block().is_dynamic_type_def {
self.name()
.strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX)
.unwrap()
.to_string()
} else {
self.name().to_string()
},
static_fields: self
.static_fields()
.map(|e| e.node(db))
Expand Down Expand Up @@ -1118,6 +1111,21 @@ impl WithRepr<RetryPolicy> for ConfigurationWalker<'_> {
}
}

// TODO: #1343 Temporary solution until we implement scoping in the AST.
#[derive(Debug)]
pub enum TypeBuilderEntry {
Enum(Node<Enum>),
Class(Node<Class>),
TypeAlias(Node<TypeAlias>),
}

// TODO: #1343 Temporary solution until we implement scoping in the AST.
#[derive(Debug)]
pub struct TestTypeBuilder {
pub entries: Vec<TypeBuilderEntry>,
pub structural_recursive_alias_cycles: Vec<IndexMap<String, FieldType>>,
}

#[derive(serde::Serialize, Debug)]
pub struct TestCaseFunction(String);

Expand All @@ -1133,6 +1141,7 @@ pub struct TestCase {
pub functions: Vec<Node<TestCaseFunction>>,
pub args: IndexMap<String, UnresolvedValue<()>>,
pub constraints: Vec<Constraint>,
pub type_builder: TestTypeBuilder,
}

impl WithRepr<TestCaseFunction> for (&ConfigurationWalker<'_>, usize) {
Expand Down Expand Up @@ -1180,6 +1189,69 @@ impl WithRepr<TestCase> for ConfigurationWalker<'_> {
let functions = (0..self.test_case().functions.len())
.map(|i| (self, i).node(db))
.collect::<Result<Vec<_>>>()?;

// TODO: #1343 Temporary solution until we implement scoping in the AST.
let enums = self
.test_case()
.type_builder_scoped_db
.walk_enums()
.filter(|e| {
self.test_case().type_builder_scoped_db.ast()[e.id].is_dynamic_type_def
|| db.find_type_by_str(e.name()).is_none()
})
.map(|e| e.node(&self.test_case().type_builder_scoped_db))
.collect::<Result<Vec<Node<Enum>>>>()?;
let classes = self
.test_case()
.type_builder_scoped_db
.walk_classes()
.filter(|c| {
self.test_case().type_builder_scoped_db.ast()[c.id].is_dynamic_type_def
|| db.find_type_by_str(c.name()).is_none()
})
.map(|c| c.node(&self.test_case().type_builder_scoped_db))
.collect::<Result<Vec<Node<Class>>>>()?;
let type_aliases = self
.test_case()
.type_builder_scoped_db
.walk_type_aliases()
.filter(|a| db.find_type_by_str(a.name()).is_none())
.map(|a| a.node(&self.test_case().type_builder_scoped_db))
.collect::<Result<Vec<Node<TypeAlias>>>>()?;
let mut type_builder_entries = Vec::new();

for e in enums {
type_builder_entries.push(TypeBuilderEntry::Enum(e));
}
for c in classes {
type_builder_entries.push(TypeBuilderEntry::Class(c));
}
for a in type_aliases {
type_builder_entries.push(TypeBuilderEntry::TypeAlias(a));
}

let mut recursive_aliases = vec![];
for cycle in self
.test_case()
.type_builder_scoped_db
.recursive_alias_cycles()
{
let mut component = IndexMap::new();
for id in cycle {
let alias = &self.test_case().type_builder_scoped_db.ast()[*id];
// Those are global cycles, skip.
if db.find_type_by_str(alias.name()).is_some() {
continue;
}
// Cycles defined in the scoped test type builder block.
component.insert(
alias.name().to_string(),
alias.value.repr(&self.test_case().type_builder_scoped_db)?,
);
}
recursive_aliases.push(component);
}

Ok(TestCase {
name: self.name().to_string(),
args: self
Expand All @@ -1195,9 +1267,14 @@ impl WithRepr<TestCase> for ConfigurationWalker<'_> {
.constraints
.into_iter()
.collect::<Vec<_>>(),
type_builder: TestTypeBuilder {
entries: type_builder_entries,
structural_recursive_alias_cycles: recursive_aliases,
},
})
}
}

#[derive(Debug, Clone, Serialize)]
pub enum Prompt {
// The prompt stirng, and a list of input replacer keys (raw key w/ magic string, and key to replace with)
Expand Down Expand Up @@ -1440,7 +1517,6 @@ mod tests {
let alias = class.find_field("field").unwrap();

assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int));

}

#[test]
Expand All @@ -1461,7 +1537,10 @@ mod tests {
let class = ir.find_class("Test").unwrap();
let alias = class.find_field("field").unwrap();

let FieldType::WithMetadata { base, constraints, .. } = alias.r#type() else {
let FieldType::WithMetadata {
base, constraints, ..
} = alias.r#type()
else {
panic!(
"expected resolved constrained type, found {:?}",
alias.r#type()
Expand Down
19 changes: 17 additions & 2 deletions engine/baml-lib/baml-core/src/ir/walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use internal_llm_client::ClientSpec;
use std::collections::{HashMap, HashSet};

use super::{
repr::{self, FunctionConfig, WithRepr},
Class, Client, Enum, EnumValue, Field, FunctionNode, IRHelper, Impl, RetryPolicy,
repr::{self, FunctionConfig, TypeBuilderEntry, WithRepr},
Class, Client, Enum, EnumValue, Field, FieldType, FunctionNode, IRHelper, Impl, RetryPolicy,
TemplateString, TestCase, TypeAlias, Walker,
};
use crate::ir::jinja_helpers::render_expression;
Expand Down Expand Up @@ -224,6 +224,21 @@ impl<'a> Walker<'a, (&'a FunctionNode, &'a TestCase)> {
.collect()
}

// TODO: #1343 Temporary solution until we implement scoping in the AST.
pub fn type_builder_contents(&self) -> &[TypeBuilderEntry] {
&self.item.1.elem.type_builder.entries
}

// TODO: #1343 Temporary solution until we implement scoping in the AST.
pub fn type_builder_recursive_aliases(&self) -> &[IndexMap<String, FieldType>] {
&self
.item
.1
.elem
.type_builder
.structural_recursive_alias_cycles
}

pub fn function(&'a self) -> Walker<'a, &'a FunctionNode> {
Walker {
db: self.db,
Expand Down
Loading

0 comments on commit 7f852d0

Please sign in to comment.