From 5753984c934414192bb44b8862350971e015647b Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Mon, 8 Jul 2024 12:13:12 +0800 Subject: [PATCH 1/3] Fixed up a new property called is_optional to represent Option types --- instruct-macros-types/src/lib.rs | 50 ++++++++++ instruct-macros/src/helpers/parameters.rs | 76 +++++++++++++++- instruct-macros/src/lib.rs | 2 + instruct-macros/tests/integration_test.rs | 13 +++ instruct-macros/tests/test_option.rs | 106 ++++++++++++++++++++++ 5 files changed, 242 insertions(+), 5 deletions(-) create mode 100644 instruct-macros/tests/test_option.rs diff --git a/instruct-macros-types/src/lib.rs b/instruct-macros-types/src/lib.rs index 79c7e40..68a365d 100644 --- a/instruct-macros-types/src/lib.rs +++ b/instruct-macros-types/src/lib.rs @@ -18,6 +18,28 @@ impl InstructMacroResult { InstructMacroResult::Enum(enum_info) => enum_info.wrap_info(new_name), } } + + pub fn override_description(self, new_description: String) -> InstructMacroResult { + match self { + InstructMacroResult::Struct(struct_info) => { + InstructMacroResult::Struct(struct_info.override_description(new_description)) + } + InstructMacroResult::Enum(enum_info) => { + InstructMacroResult::Enum(enum_info.override_description(new_description)) + } + } + } + + pub fn set_optional(self, is_optional: bool) -> InstructMacroResult { + match self { + InstructMacroResult::Struct(struct_info) => { + InstructMacroResult::Struct(struct_info.set_optional(is_optional)) + } + InstructMacroResult::Enum(enum_info) => { + InstructMacroResult::Enum(enum_info.set_optional(is_optional)) + } + } + } } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -25,6 +47,7 @@ pub struct StructInfo { pub name: String, pub description: String, pub parameters: Vec, + pub is_optional: bool, } impl StructInfo { @@ -32,6 +55,18 @@ impl StructInfo { self.name = new_name; Parameter::Struct(self) } + + pub fn override_description(mut self, new_description: String) -> StructInfo { + if new_description.len() > 0 { + self.description = new_description; + } + self + } + + pub fn set_optional(mut self, is_optional: bool) -> StructInfo { + self.is_optional = is_optional; + self + } } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -46,6 +81,7 @@ pub struct ParameterInfo { pub name: String, pub r#type: String, pub comment: String, + pub is_optional: bool, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -54,6 +90,7 @@ pub struct EnumInfo { pub r#enum: Vec, pub r#type: String, pub description: String, + pub is_optional: bool, } impl EnumInfo { @@ -61,6 +98,18 @@ impl EnumInfo { self.title = new_name; Parameter::Enum(self) } + + pub fn override_description(mut self, new_description: String) -> EnumInfo { + if new_description.len() > 0 { + self.description = new_description; + } + self + } + + pub fn set_optional(mut self, is_optional: bool) -> EnumInfo { + self.is_optional = is_optional; + self + } } pub struct FieldInfo { @@ -68,4 +117,5 @@ pub struct FieldInfo { pub description: String, pub r#type: String, pub is_complex: bool, + pub is_optional: bool, } diff --git a/instruct-macros/src/helpers/parameters.rs b/instruct-macros/src/helpers/parameters.rs index 69302a3..3737566 100644 --- a/instruct-macros/src/helpers/parameters.rs +++ b/instruct-macros/src/helpers/parameters.rs @@ -32,6 +32,12 @@ pub fn extract_parameter_information(fields: &syn::FieldsNamed) -> Vec Vec bool { "u8", "u16", "u32", "u64", "u128", "usize", ]; + let option_types: Vec = simple_types + .iter() + .map(|&t| format!("Option<{}>", t)) + .collect(); + if simple_types.contains(&field_type.as_str()) { return false; } - if field_type.starts_with("Vec<") { - let inner_type = &field_type[4..field_type.len() - 1]; - return simple_types.contains(&inner_type); + if option_types.contains(&field_type) { + return false; } true } +fn is_option_type(field_type: &str) -> bool { + field_type.starts_with("Option<") && field_type.ends_with(">") +} + +fn extract_nested_type(field_type: &str) -> String { + if is_option_type(field_type) { + field_type[7..field_type.len() - 1].to_string() + } else { + field_type.to_string() + } +} + pub fn extract_parameters(fields: &syn::FieldsNamed) -> Vec { extract_parameter_information(fields) .iter() @@ -69,6 +92,7 @@ pub fn extract_parameters(fields: &syn::FieldsNamed) -> Vec Vec", simple_type)), false); + } + + // Complex types + + assert_eq!(is_complex_type("Option".to_string()), true); + } + + #[test] + fn test_extract_nested_type() { + // Test cases for extract_nested_type function + let test_cases = vec![("Option", "User")]; + + for (input, expected) in test_cases { + assert_eq!(extract_nested_type(input), expected); + } + } +} diff --git a/instruct-macros/src/lib.rs b/instruct-macros/src/lib.rs index 7af5302..2e5eea4 100644 --- a/instruct-macros/src/lib.rs +++ b/instruct-macros/src/lib.rs @@ -40,6 +40,7 @@ fn generate_instruct_macro_enum(input: &DeriveInput) -> proc_macro2::TokenStream r#enum: vec![#(#enum_variants.to_string()),*], r#type: stringify!(#name).to_string(), description: #description.to_string(), + is_optional:false }) }; @@ -121,6 +122,7 @@ fn generate_instruct_macro_struct(input: &DeriveInput) -> proc_macro2::TokenStre name: stringify!(#name).to_string(), description: #description.to_string(), parameters, + is_optional:false }) } diff --git a/instruct-macros/tests/integration_test.rs b/instruct-macros/tests/integration_test.rs index 84ae34b..cd14f82 100644 --- a/instruct-macros/tests/integration_test.rs +++ b/instruct-macros/tests/integration_test.rs @@ -36,13 +36,16 @@ mod tests { name: "field1".to_string(), r#type: "String".to_string(), comment: "This is a sample example that spans across three lines".to_string(), + is_optional: false, }), Parameter::Field(ParameterInfo { name: "field2".to_string(), r#type: "str".to_string(), comment: "This is a test field".to_string(), + is_optional: false, }), ], + is_optional: false, }; let info_struct = match info { @@ -117,11 +120,13 @@ mod tests { name: "name".to_string(), r#type: "String".to_string(), comment: "".to_string(), + is_optional: false, }), Parameter::Field(ParameterInfo { name: "age".to_string(), r#type: "u8".to_string(), comment: "".to_string(), + is_optional: false, }), Parameter::Struct(StructInfo { name: "address".to_string(), @@ -131,15 +136,19 @@ mod tests { name: "street".to_string(), r#type: "String".to_string(), comment: "".to_string(), + is_optional: false, }), Parameter::Field(ParameterInfo { name: "city".to_string(), r#type: "String".to_string(), comment: "".to_string(), + is_optional: false, }), ], + is_optional: false, }), ], + is_optional: false, }; let info_struct = match info { @@ -171,6 +180,7 @@ mod tests { ], r#type: "Status".to_string(), description: "".to_string(), + is_optional: false, }; let info_enum = match info { @@ -209,6 +219,7 @@ mod tests { name: "name".to_string(), r#type: "String".to_string(), comment: "".to_string(), + is_optional: false, }), Parameter::Enum(EnumInfo { title: "status".to_string(), @@ -219,8 +230,10 @@ mod tests { ], r#type: "Status".to_string(), description: "This is an enum representing the status of a person".to_string(), + is_optional: false, }), ], + is_optional: false, }; let info_struct = match info { diff --git a/instruct-macros/tests/test_option.rs b/instruct-macros/tests/test_option.rs new file mode 100644 index 0000000..2e1a034 --- /dev/null +++ b/instruct-macros/tests/test_option.rs @@ -0,0 +1,106 @@ +extern crate instruct_macros_types; + +use instruct_macros::InstructMacro; +use instruct_macros_types::{ + InstructMacro, InstructMacroResult, Parameter, ParameterInfo, StructInfo, +}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_option_type_support() { + #[derive(InstructMacro, Debug)] + #[allow(dead_code)] + #[description("This is a struct with Option types")] + struct TestOptionStruct { + #[description("This is an optional string field")] + pub field1: Option, + #[description("This is an optional integer field")] + pub field2: Option, + } + + let info = TestOptionStruct::get_info(); + let desired_struct = StructInfo { + name: "TestOptionStruct".to_string(), + description: "This is a struct with Option types".to_string(), + is_optional: false, + parameters: vec![ + Parameter::Field(ParameterInfo { + name: "field1".to_string(), + r#type: "Option".to_string(), + comment: "This is an optional string field".to_string(), + is_optional: true, + }), + Parameter::Field(ParameterInfo { + name: "field2".to_string(), + r#type: "Option".to_string(), + comment: "This is an optional integer field".to_string(), + is_optional: true, + }), + ], + }; + + let info_struct = match info { + InstructMacroResult::Struct(s) => s, + _ => panic!("Expected StructInfo"), + }; + + assert_eq!(info_struct, desired_struct); + } + + #[test] + fn test_option_maybe_struct() { + #[derive(InstructMacro, Debug)] + #[allow(dead_code)] + #[description("This is a user struct")] + struct User { + #[description("This is the user's name")] + pub name: String, + #[description("This is the user's age")] + pub age: i32, + } + + #[derive(InstructMacro, Debug)] + #[allow(dead_code)] + #[description("This is a struct with Option type")] + struct MaybeUser { + #[description("This is an optional user field")] + pub user: Option, + } + + let info = MaybeUser::get_info(); + let desired_struct = StructInfo { + name: "MaybeUser".to_string(), + description: "This is a struct with Option type".to_string(), + parameters: vec![Parameter::Struct(StructInfo { + name: "user".to_string(), + description: "This is an optional user field".to_string(), + parameters: vec![ + Parameter::Field(ParameterInfo { + name: "name".to_string(), + r#type: "String".to_string(), + comment: "This is the user's name".to_string(), + is_optional: false, + }), + Parameter::Field(ParameterInfo { + name: "age".to_string(), + r#type: "i32".to_string(), + comment: "This is the user's age".to_string(), + is_optional: false, + }), + ], + is_optional: true, + })], + is_optional: false, + }; + + let info_struct = match info { + InstructMacroResult::Struct(s) => s, + _ => panic!("Expected StructInfo"), + }; + + assert_eq!(info_struct, desired_struct); + } +} From 55ca8bd8509bdfed347765953097153d519f9721 Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Mon, 8 Jul 2024 12:42:39 +0800 Subject: [PATCH 2/3] Adding a test and support for option type parsing in instructor.rs' --- instructor/src/helpers/response_model.rs | 142 +++++++++++++++++++++-- instructor/src/lib.rs | 19 ++- instructor/tests/test_option.rs | 54 +++++++++ 3 files changed, 198 insertions(+), 17 deletions(-) create mode 100644 instructor/tests/test_option.rs diff --git a/instructor/src/helpers/response_model.rs b/instructor/src/helpers/response_model.rs index 8359149..4ae42f5 100644 --- a/instructor/src/helpers/response_model.rs +++ b/instructor/src/helpers/response_model.rs @@ -1,20 +1,27 @@ use std::collections::HashMap; -use instruct_macros_types::{Parameter, StructInfo}; +use instruct_macros_types::{Parameter, ParameterInfo, StructInfo}; use openai_api_rs::v1::chat_completion::{self, JSONSchemaDefine}; fn get_required_properties(info: &StructInfo) -> Vec { let mut required = Vec::new(); + for param in info.parameters.iter() { match param { Parameter::Field(field_info) => { - required.push(field_info.name.clone()); + if !field_info.is_optional { + required.push(field_info.name.clone()); + } } Parameter::Struct(struct_info) => { - required.push(struct_info.name.clone()); + if !struct_info.is_optional { + required.push(struct_info.name.clone()); + } } Parameter::Enum(enum_info) => { - required.push(enum_info.title.clone()); + if !enum_info.is_optional { + required.push(enum_info.title.clone()); + } } } } @@ -27,10 +34,19 @@ fn convert_parameter_type(info: &str) -> chat_completion::JSONSchemaType { "u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "u64" | "i64" | "u128" | "i128" | "usize" | "isize" => chat_completion::JSONSchemaType::Number, "bool" => chat_completion::JSONSchemaType::Boolean, + _ => panic!("Unsupported type: {}", info), } } +fn get_base_type(field_info: &ParameterInfo) -> &str { + if field_info.r#type.starts_with("Option<") && field_info.r#type.ends_with('>') { + &field_info.r#type[7..field_info.r#type.len() - 1] + } else { + &field_info.r#type + } +} + fn get_response_model_parameters(t: &StructInfo) -> HashMap> { let mut properties = HashMap::new(); @@ -40,7 +56,8 @@ fn get_response_model_parameters(t: &StructInfo) -> HashMap, + } + + let struct_info = StructWithOptionalField::get_info(); + let parsed_model: StructInfo = match struct_info { + InstructMacroResult::Struct(info) => info, + _ => { + panic!("Expected StructInfo but got a different InstructMacroResult variant"); + } + }; + let parameters = get_response_model(parsed_model); + + let expected_parameters = chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some({ + let mut props = std::collections::HashMap::new(); + props.insert( + "name".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("The name of the user".to_string()), + ..Default::default() + }), + ); + props.insert( + "age".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Number), + description: Some("The age of the user".to_string()), + ..Default::default() + }), + ); + props + }), + required: Some(vec!["name".to_string()]), // Only "name" should be required + }; + + assert_eq!(expected_parameters, parameters); + } + + #[test] + fn test_struct_with_nested_optional_field() { + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct User { + name: String, + age: u8, + } + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct MaybeUser { + user: Option, } + + let struct_info = MaybeUser::get_info(); + let parsed_model: StructInfo = match struct_info { + InstructMacroResult::Struct(info) => info, + _ => { + panic!("Expected StructInfo but got a different InstructMacroResult variant"); + } + }; + let parameters = get_response_model(parsed_model); + + let expected_parameters = chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some({ + let mut props = std::collections::HashMap::new(); + props.insert( + "user".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Object), + description: Some("".to_string()), + properties: Some({ + let mut user_props = std::collections::HashMap::new(); + user_props.insert( + "name".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("".to_string()), + ..Default::default() + }), + ); + user_props.insert( + "age".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Number), + description: Some("".to_string()), + ..Default::default() + }), + ); + user_props + }), + ..Default::default() + }), + ); + props + }), + required: Some(vec![]), // No required fields + }; + assert_eq!(expected_parameters, parameters); } } diff --git a/instructor/src/lib.rs b/instructor/src/lib.rs index a1f3e71..31c09a8 100644 --- a/instructor/src/lib.rs +++ b/instructor/src/lib.rs @@ -6,7 +6,7 @@ use openai_api_rs::v1::{ error::APIError, }; -use instruct_macros_types::{InstructMacro, InstructMacroResult, StructInfo}; +use instruct_macros_types::{InstructMacro, InstructMacroResult, Parameter, StructInfo}; pub struct InstructorClient { client: Client, @@ -43,9 +43,12 @@ impl InstructorClient { name: None, }; req.messages.push(new_message); + + println!("Error encountered: {}", error); } let result = self._retry_sync::(req.clone(), parsed_model.clone()); + match result { Ok(value) => { match T::validate(&value) { @@ -85,10 +88,12 @@ impl InstructorClient { function: chat_completion::Function { name: parsed_model.name.clone(), description: Some(parsed_model.description.clone()), - parameters: helpers::get_response_model(parsed_model), + parameters: helpers::get_response_model(parsed_model.clone()), }, }; + let parameters_json = serde_json::to_string(&func_call.function).unwrap(); + let req = req .tools(vec![func_call]) .tool_choice(chat_completion::ToolChoiceType::Auto); @@ -104,18 +109,22 @@ impl InstructorClient { 1 => { let tool_call = &tool_calls[0]; let arguments = tool_call.function.arguments.clone().unwrap(); - return serde_json::from_str(&arguments); } _ => { // TODO: Support multiple tool calls at some point let error_message = - format!("Unexpected number of tool calls: {:?}", tool_calls); + format!("Unexpected number of tool calls: {:?}. PLease only generate a single tool call.", tool_calls); return Err(serde::de::Error::custom(error_message)); } } } - _ => panic!("Unexpected finish reason"), + _ => { + let error_message = + "You must call a tool. Make sure to adhere to the provided response format." + .to_string(); + return Err(serde::de::Error::custom(error_message)); + } } } } diff --git a/instructor/tests/test_option.rs b/instructor/tests/test_option.rs new file mode 100644 index 0000000..e7ae602 --- /dev/null +++ b/instructor/tests/test_option.rs @@ -0,0 +1,54 @@ +extern crate instruct_macros; +extern crate instruct_macros_types; + +use instruct_macros::InstructMacro; +use instruct_macros_types::{Parameter, ParameterInfo, StructInfo}; +use instructor_ai::from_openai; +use openai_api_rs::v1::api::Client; + +#[cfg(test)] +mod tests { + use std::env; + + use openai_api_rs::v1::{ + chat_completion::{self, ChatCompletionRequest}, + common::GPT4_O, + }; + use serde::{Deserialize, Serialize}; + + use super::*; + + #[test] + fn test_from_openai() { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let instructor_client = from_openai(client); + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct UserInfo { + name: String, + age: u8, + } + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct MaybeUser { + #[description("This is an optional user field. If the user is not present, the field will be null")] + user: Option, + } + + let req = ChatCompletionRequest::new( + GPT4_O.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("It's a beautiful day out")), + name: None, + }], + ); + + let result = instructor_client + .chat_completion::(req, 3) + .unwrap(); + + println!("{:?}", result); + // assert!(result.user.is_none()); + } +} From 6f48976671bf77ccdf84c51d0f5d64ebecb88729 Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Mon, 8 Jul 2024 12:44:03 +0800 Subject: [PATCH 3/3] Added assertion for option test --- instructor/tests/test_option.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/instructor/tests/test_option.rs b/instructor/tests/test_option.rs index e7ae602..f904336 100644 --- a/instructor/tests/test_option.rs +++ b/instructor/tests/test_option.rs @@ -48,7 +48,6 @@ mod tests { .chat_completion::(req, 3) .unwrap(); - println!("{:?}", result); - // assert!(result.user.is_none()); + assert!(result.user.is_none()); } }