From 5fb25426832195a81b667f352d6f41aa49f5cfd0 Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 07:18:00 -0500 Subject: [PATCH 1/8] function calling hello world --- Mistral.SDK.Tests/FunctionCalling.cs | 45 ++ Mistral.SDK/Common/Function.cs | 459 ++++++++++++++++++ Mistral.SDK/Common/FunctionAttribute.cs | 17 + .../Common/FunctionParameterAttribute.cs | 23 + .../Common/FunctionPropertyAttribute.cs | 53 ++ Mistral.SDK/Common/Tool.cs | 430 ++++++++++++++++ .../Completions/CompletionsEndpoint.cs | 22 + ...kJsonOption.cs => MistralSdkJsonOption.cs} | 13 +- .../Converters/ToolChoiceTypeConverter.cs | 36 ++ Mistral.SDK/DTOs/ChatCompletionRequest.cs | 17 +- Mistral.SDK/DTOs/ChatCompletionResponse.cs | 3 + Mistral.SDK/DTOs/ChatMessage.cs | 28 +- Mistral.SDK/DTOs/Choice.cs | 5 +- Mistral.SDK/DTOs/Tool.cs | 85 ++++ Mistral.SDK/DTOs/ToolCall.cs | 29 ++ Mistral.SDK/DTOs/ToolChoiceType.cs | 20 + Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs | 2 +- Mistral.SDK/EndpointBase.cs | 8 +- .../ChatCompletionRequestConverter.cs | 43 ++ Mistral.SDK/Extensions/TypeExtensions.cs | 276 +++++++++++ Mistral.SDK/Models/ModelsEndpoint.cs | 2 +- 21 files changed, 1603 insertions(+), 13 deletions(-) create mode 100644 Mistral.SDK.Tests/FunctionCalling.cs create mode 100644 Mistral.SDK/Common/Function.cs create mode 100644 Mistral.SDK/Common/FunctionAttribute.cs create mode 100644 Mistral.SDK/Common/FunctionParameterAttribute.cs create mode 100644 Mistral.SDK/Common/FunctionPropertyAttribute.cs create mode 100644 Mistral.SDK/Common/Tool.cs rename Mistral.SDK/Converters/{MistalSdkJsonOption.cs => MistralSdkJsonOption.cs} (68%) create mode 100644 Mistral.SDK/Converters/ToolChoiceTypeConverter.cs create mode 100644 Mistral.SDK/DTOs/Tool.cs create mode 100644 Mistral.SDK/DTOs/ToolCall.cs create mode 100644 Mistral.SDK/DTOs/ToolChoiceType.cs create mode 100644 Mistral.SDK/Extensions/ChatCompletionRequestConverter.cs create mode 100644 Mistral.SDK/Extensions/TypeExtensions.cs diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs new file mode 100644 index 0000000..af3c666 --- /dev/null +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -0,0 +1,45 @@ +using Mistral.SDK.Common; +using Mistral.SDK.DTOs; + +namespace Mistral.SDK.Tests +{ + [TestClass] + public class FunctionCalling + { + [TestMethod] + public async Task TestBasicFunction() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "What is the current weather in San Francisco?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + var tools = new List + { + Common.Tool.FromFunc("Get_Weather", + ([FunctionParameter("Location of the weather", true)]string location)=> "72 degrees and sunny") + }; + + + request.Tools = tools; + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); + } + } +} diff --git a/Mistral.SDK/Common/Function.cs b/Mistral.SDK/Common/Function.cs new file mode 100644 index 0000000..3a7cf2f --- /dev/null +++ b/Mistral.SDK/Common/Function.cs @@ -0,0 +1,459 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Text.Json; +using System.Threading.Tasks; +using System.Threading; +using Mistral.SDK.Converters; +using Mistral.SDK.Extensions; + +namespace Mistral.SDK.Common +{ + public sealed class Function + { + public Function() { } + + private const string NameRegex = "^[a-zA-Z0-9_-]{1,64}$"; + + + + /// + /// Creates a new function description to insert into a chat conversation. + /// + /// + /// Required. The name of the function to generate arguments for based on the context in a message.
+ /// May contain a-z, A-Z, 0-9, underscores and dashes, with a maximum length of 64 characters. Recommended to not begin with a number or a dash. + /// + /// + /// An optional description of the function, used by the API to determine if it is useful to include in the response. + /// + /// + /// An optional JSON object describing the parameters of the function that the model can generate. + /// + public Function(string name, string description = null, JsonNode parameters = null) + { + if (!System.Text.RegularExpressions.Regex.IsMatch(name, NameRegex)) + { + throw new ArgumentException($"The name of the function does not conform to naming standards: {NameRegex}"); + } + + Name = name; + Description = description; + Parameters = parameters; + functionCache[Name] = this; + } + + public Function(string name, string type, Dictionary additionalData) + { + if (!System.Text.RegularExpressions.Regex.IsMatch(name, NameRegex)) + { + throw new ArgumentException($"The name of the function does not conform to naming standards: {NameRegex}"); + } + + Name = name; + Type = type; + Description = null; + Parameters = null; + AdditionalData = additionalData.Select(p => + new KeyValuePair(p.Key, JsonSerializer.SerializeToElement(p.Value))) + .ToDictionary(pair => pair.Key, pair => pair.Value); + functionCache[Name] = this; + } + + /// + /// Creates a new function description to insert into a chat conversation. + /// + /// + /// Required. The name of the function to generate arguments for based on the context in a message.
+ /// May contain a-z, A-Z, 0-9, underscores and dashes, with a maximum length of 64 characters. Recommended to not begin with a number or a dash. + /// + /// + /// An optional description of the function, used by the API to determine if it is useful to include in the response. + /// + /// + /// An optional JSON describing the parameters of the function that the model can generate. + /// + public Function(string name, string description, string parameters) + { + if (!System.Text.RegularExpressions.Regex.IsMatch(name, NameRegex)) + { + throw new ArgumentException($"The name of the function does not conform to naming standards: {NameRegex}"); + } + + Name = name; + Description = description; + Parameters = JsonNode.Parse(parameters); + functionCache[Name] = this; + } + + + + + internal Function(string name, string description, MethodInfo method, object instance = null) + { + if (!System.Text.RegularExpressions.Regex.IsMatch(name, NameRegex)) + { + throw new ArgumentException($"The name of the function does not conform to naming standards: {NameRegex}"); + } + + Name = name; + Description = description; + MethodInfo = method; + Parameters = method.GenerateJsonSchema(); + Instance = instance; + functionCache[Name] = this; + } + + #region Func<,> Overloads + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + public static Function FromFunc(string name, Func function, string description = null) + => new(name, description, function.Method, function.Target); + + #endregion Func<,> Overloads + + internal Function(Function other) => CopyFrom(other); + + /// + /// The name of the function to generate arguments for.
+ /// May contain a-z, A-Z, 0-9, and underscores and dashes, with a maximum length of 64 characters. + /// Recommended to not begin with a number or a dash. + ///
+ [JsonInclude] + [JsonPropertyName("name")] + public string Name { get; private set; } + + [JsonInclude] + [JsonPropertyName("type")] + public string Type { get; private set; } + + [JsonExtensionData] + public Dictionary AdditionalData { get; set; } = new Dictionary(); + + + /// + /// Id to Send to the API. + /// + [JsonIgnore] + [JsonPropertyName("id")] + public string Id { get; set; } + + /// + /// The optional description of the function. + /// + [JsonInclude] + [JsonPropertyName("description")] + public string Description { get; private set; } + + private string parametersString; + + private JsonNode parameters; + + /// + /// The optional parameters of the function. + /// Describe the parameters that the model should generate in JSON schema format (json-schema.org). + /// + [JsonInclude] + [JsonPropertyName("parameters")] + public JsonNode Parameters + { + get + { + if (parameters == null && + !string.IsNullOrWhiteSpace(parametersString)) + { + parameters = JsonNode.Parse(parametersString); + } + + return parameters; + } + private set => parameters = value; + } + + private string argumentsString; + + private JsonNode arguments; + + /// + /// The arguments to use when calling the function. + /// + [JsonIgnore] + [JsonPropertyName("arguments")] + public JsonNode Arguments + { + get + { + if (arguments == null && + !string.IsNullOrWhiteSpace(argumentsString)) + { + arguments = JsonValue.Create(argumentsString); + } + + return arguments; + } + internal set => arguments = value; + } + + /// + /// The instance of the object to invoke the method on. + /// + [JsonIgnore] + internal object Instance { get; } + + + + /// + /// The method to invoke. + /// + [JsonIgnore] + public MethodInfo MethodInfo { get; } + + internal void CopyFrom(Function other) + { + if (!string.IsNullOrWhiteSpace(other.Name)) + { + Name = other.Name; + } + + if (!string.IsNullOrWhiteSpace(other.Description)) + { + Description = other.Description; + } + + if (other.Arguments != null) + { + argumentsString += other.Arguments.ToString(); + } + + if (other.Parameters != null) + { + parametersString += other.Parameters.ToString(); + } + + if (other.AdditionalData != null) + { + AdditionalData = other.AdditionalData; + } + + if (other.Type != null) + { + Type = other.Type; + } + } + + #region Function Invoking Utilities + + private static readonly ConcurrentDictionary functionCache = new(); + + /// + /// Invokes the function and returns the result as json. + /// + /// The result of the function as json. + public string Invoke() + { + try + { + var (function, invokeArgs) = ValidateFunctionArguments(); + + if (function.MethodInfo.ReturnType == typeof(void)) + { + function.MethodInfo.Invoke(function.Instance, invokeArgs); + return "{\"result\": \"success\"}"; + } + + var result = Invoke(); + return JsonSerializer.Serialize(new { result }); + } + catch (Exception e) + { + Console.WriteLine(e); + return JsonSerializer.Serialize(new { error = e.Message }); + } + } + + /// + /// Invokes the function and returns the result. + /// + /// The expected return type. + /// The result of the function. + public T Invoke() + { + try + { + var (function, invokeArgs) = ValidateFunctionArguments(); + var result = function.MethodInfo.Invoke(function.Instance, invokeArgs); + return result == null ? default : (T)result; + } + catch (Exception e) + { + Console.WriteLine(e); + throw; + } + } + + /// + /// Invokes the function and returns the result as json. + /// + /// Optional, . + /// The result of the function as json. + public async Task InvokeAsync(CancellationToken cancellationToken = default) + { + try + { + var (function, invokeArgs) = ValidateFunctionArguments(cancellationToken); + + if (function.MethodInfo.ReturnType == typeof(Task)) + { + if (function.MethodInfo.Invoke(function.Instance, invokeArgs) is not Task task) + { + throw new InvalidOperationException($"The function {Name} did not return a valid Task."); + } + + await task; + return "{\"result\": \"success\"}"; + } + + var result = await InvokeAsync(cancellationToken); + return JsonSerializer.Serialize(new { result }); + } + catch (Exception e) + { + Console.WriteLine(e); + return JsonSerializer.Serialize(new { error = e.Message }); + } + } + + /// + /// Invokes the function and returns the result. + /// + /// Expected return type. + /// Optional, . + /// The result of the function. + public async Task InvokeAsync(CancellationToken cancellationToken = default) + { + try + { + var (function, invokeArgs) = ValidateFunctionArguments(cancellationToken); + + if (function.MethodInfo.Invoke(function.Instance, invokeArgs) is not Task task) + { + throw new InvalidOperationException($"The function {Name} did not return a valid Task."); + } + + await task; + // ReSharper disable once InconsistentNaming + const string Result = nameof(Result); + var resultProperty = task.GetType().GetProperty(Result); + return (T)resultProperty?.GetValue(task); + } + catch (Exception e) + { + Console.WriteLine(e); + throw; + } + } + + private (Function function, object[] invokeArgs) ValidateFunctionArguments(CancellationToken cancellationToken = default) + { + if (Parameters != null && Parameters.AsObject().Count > 0 && Arguments == null) + { + throw new ArgumentException($"Function {Name} has parameters but no arguments are set."); + } + + if (!functionCache.TryGetValue(Name, out var function)) + { + throw new InvalidOperationException($"Failed to find a valid function for {Name}"); + } + + if (function.MethodInfo == null) + { + throw new InvalidOperationException($"Failed to find a valid method for {Name}"); + } + + var requestedArgs = arguments != null + ? JsonSerializer.Deserialize>(Arguments.ToString()) + : new(); + var methodParams = function.MethodInfo.GetParameters(); + var invokeArgs = new object[methodParams.Length]; + + for (var i = 0; i < methodParams.Length; i++) + { + var parameter = methodParams[i]; + + if (parameter.Name == null) + { + throw new InvalidOperationException($"Failed to find a valid parameter name for {function.MethodInfo.DeclaringType}.{function.MethodInfo.Name}()"); + } + + if (requestedArgs.TryGetValue(parameter.Name, out var value)) + { + if (parameter.ParameterType == typeof(CancellationToken)) + { + invokeArgs[i] = cancellationToken; + } + else if (value is string @enum && parameter.ParameterType.IsEnum) + { + invokeArgs[i] = Enum.Parse(parameter.ParameterType, @enum, true); + } + else if (value is JsonElement element) + { + invokeArgs[i] = JsonSerializer.Deserialize(element.GetRawText(), parameter.ParameterType, MistralSdkJsonOption.Options); + } + else + { + invokeArgs[i] = value; + } + } + else if (parameter.HasDefaultValue) + { + invokeArgs[i] = parameter.DefaultValue; + } + else + { + throw new ArgumentException($"Missing argument for parameter '{parameter.Name}'"); + } + } + + return (function, invokeArgs); + } + + #endregion Function Invoking Utilities + } +} diff --git a/Mistral.SDK/Common/FunctionAttribute.cs b/Mistral.SDK/Common/FunctionAttribute.cs new file mode 100644 index 0000000..722828f --- /dev/null +++ b/Mistral.SDK/Common/FunctionAttribute.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Mistral.SDK.Common +{ + [AttributeUsage(AttributeTargets.Method, Inherited = false)] + public sealed class FunctionAttribute : Attribute + { + public FunctionAttribute(string description = null) + { + Description = description; + } + + public string Description { get; } + } +} diff --git a/Mistral.SDK/Common/FunctionParameterAttribute.cs b/Mistral.SDK/Common/FunctionParameterAttribute.cs new file mode 100644 index 0000000..34d2096 --- /dev/null +++ b/Mistral.SDK/Common/FunctionParameterAttribute.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Mistral.SDK.Common +{ + [AttributeUsage(AttributeTargets.Parameter)] + public sealed class FunctionParameterAttribute : Attribute + { + /// + /// Function parameter attribute to help describe the parameter for the function. + /// + /// The description of the parameter and its usage. + public FunctionParameterAttribute(string description, bool required) + { + Description = description; + Required = required; + } + + public string Description { get; } + public bool Required { get; set; } + } +} diff --git a/Mistral.SDK/Common/FunctionPropertyAttribute.cs b/Mistral.SDK/Common/FunctionPropertyAttribute.cs new file mode 100644 index 0000000..6004efe --- /dev/null +++ b/Mistral.SDK/Common/FunctionPropertyAttribute.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Mistral.SDK.Common +{ + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] + public sealed class FunctionPropertyAttribute : Attribute + { + /// + /// Property Attribute to help with function calling. + /// + /// + /// The description of the property + /// + /// + /// Is the property required? + /// + /// + /// The default value. + /// + /// + /// Enums or other possible values. + /// + public FunctionPropertyAttribute(string description = null, bool required = false, object defaultValue = null, params object[] possibleValues) + { + Description = description; + Required = required; + DefaultValue = defaultValue; + PossibleValues = possibleValues; + } + + /// + /// The description of the property + /// + public string Description { get; } + + /// + /// Is the property required? + /// + public bool Required { get; } + + /// + /// The default value. + /// + public object DefaultValue { get; } + + /// + /// Enums or other possible values. + /// + public object[] PossibleValues { get; } + } +} diff --git a/Mistral.SDK/Common/Tool.cs b/Mistral.SDK/Common/Tool.cs new file mode 100644 index 0000000..25f5bdb --- /dev/null +++ b/Mistral.SDK/Common/Tool.cs @@ -0,0 +1,430 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using System.Threading; + +namespace Mistral.SDK.Common +{ + public sealed class Tool + { + public Tool() { } + + public Tool(Tool other) => CopyFrom(other); + + public Tool(Function function) + { + Function = function; + Type = nameof(function); + } + + public static implicit operator Tool(Function function) => new(function); + + public static Tool Retrieval { get; } = new() { Type = "retrieval" }; + + public static Tool CodeInterpreter { get; } = new() { Type = "code_interpreter" }; + + [JsonInclude] + [JsonPropertyName("id")] + public string Id { get; private set; } + + [JsonInclude] + [JsonPropertyName("index")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? Index { get; private set; } + + [JsonInclude] + [JsonPropertyName("type")] + public string Type { get; private set; } = "function"; + + [JsonInclude] + [JsonPropertyName("function")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public Function Function { get; private set; } + + internal void CopyFrom(Tool other) + { + if (!string.IsNullOrWhiteSpace(other?.Id)) + { + Id = other.Id; + } + + if (other is { Index: not null }) + { + Index = other.Index.Value; + } + + if (!string.IsNullOrWhiteSpace(other?.Type)) + { + Type = other.Type; + } + + if (other?.Function != null) + { + if (Function == null) + { + Function = new Function(other.Function); + } + else + { + Function.CopyFrom(other.Function); + } + } + } + + /// + /// Invokes the function and returns the result as json. + /// + /// The result of the function as json. + public string InvokeFunction() => Function.Invoke(); + + /// + /// Invokes the function and returns the result. + /// + /// The type to deserialize the result to. + /// The result of the function. + public T InvokeFunction() => Function.Invoke(); + + /// + /// Invokes the function and returns the result as json. + /// + /// Optional, A token to cancel the request. + /// The result of the function as json. + public async Task InvokeFunctionAsync(CancellationToken cancellationToken = default) + => await Function.InvokeAsync(cancellationToken).ConfigureAwait(false); + + /// + /// Invokes the function and returns the result. + /// + /// The type to deserialize the result to. + /// Optional, A token to cancel the request. + /// The result of the function. + public async Task InvokeFunctionAsync(CancellationToken cancellationToken = default) + => await Function.InvokeAsync(cancellationToken).ConfigureAwait(false); + + private static readonly List toolCache = new() + { + Retrieval, + CodeInterpreter + }; + + /// + /// Clears the tool cache of all previously registered tools. + /// + public static void ClearRegisteredTools() + { + toolCache.Clear(); + toolCache.Add(CodeInterpreter); + toolCache.Add(Retrieval); + } + + /// + /// Checks if tool exists in cache. + /// + /// The tool to check. + /// True, if the tool is already registered in the tool cache. + public static bool IsToolRegistered(Tool tool) + => toolCache.Any(knownTool => + knownTool.Type == "function" && + knownTool.Function.Name == tool.Function.Name && + ReferenceEquals(knownTool.Function.Instance, tool.Function.Instance)); + + /// + /// Tries to register a tool into the Tool cache. + /// + /// The tool to register. + /// True, if the tool was added to the cache. + public static bool TryRegisterTool(Tool tool) + { + if (IsToolRegistered(tool)) + { + return false; + } + + if (tool.Type != "function") + { + throw new InvalidOperationException("Only function tools can be registered."); + } + + toolCache.Add(tool); + return true; + + } + + private static bool TryGetTool(string name, object instance, out Tool tool) + { + foreach (var knownTool in toolCache.Where(knownTool => + knownTool.Type == "function" && + knownTool.Function.Name == name && + ReferenceEquals(knownTool, instance))) + { + tool = knownTool; + return true; + } + + tool = null; + return false; + } + + /// + /// Gets a list of all available tools. + /// + /// + /// This method will scan all assemblies for static methods decorated with the . + /// + /// Optional, Whether to include the default tools (Retrieval and CodeInterpreter). + /// Optional, Whether to force an update of the tool cache. + /// Optional, whether to force the tool cache to be cleared before updating. + /// A list of all available tools. + public static IReadOnlyList GetAllAvailableTools(bool includeDefaults = true, bool forceUpdate = false, bool clearCache = false) + { + if (clearCache) + { + ClearRegisteredTools(); + } + + if (forceUpdate || toolCache.All(tool => tool.Type != "function")) + { + var tools = new List(); + tools.AddRange( + from assembly in AppDomain.CurrentDomain.GetAssemblies() + from type in assembly.GetTypes() + from method in type.GetMethods() + where method.IsStatic + let functionAttribute = method.GetCustomAttribute() + where functionAttribute != null + let name = $"{type.FullName}.{method.Name}".Replace('.', '_') + let description = functionAttribute.Description + select new Function(name, description, method) + into function + select new Tool(function)); + + foreach (var newTool in tools.Where(tool => + !toolCache.Any(knownTool => + knownTool.Type == "function" && knownTool.Function.Name == tool.Function.Name && knownTool.Function.Instance == null))) + { + toolCache.Add(newTool); + } + } + + return !includeDefaults + ? toolCache.Where(tool => tool.Type == "function").ToList() + : toolCache; + } + + /// + /// Get or create a tool from a static method. + /// + /// + /// If the tool already exists, it will be returned. Otherwise, a new tool will be created.
+ /// The method doesn't need to be decorated with the .
+ ///
+ /// The type containing the static method. + /// The name of the method. + /// Optional, The description of the method. + /// The tool for the method. + public static Tool GetOrCreateTool(Type type, string methodName, string description = null) + { + var method = type.GetMethod(methodName) ?? + throw new InvalidOperationException($"Failed to find a valid method for {type.FullName}.{methodName}()"); + + if (!method.IsStatic) + { + throw new InvalidOperationException($"Method {type.FullName}.{methodName}() must be static. Use GetOrCreateTool(object instance, string methodName) instead."); + } + + var functionName = $"{type.FullName}.{method.Name}".Replace('.', '_').Replace("+", "__"); + + if (TryGetTool(functionName, null, out var tool)) + { + return tool; + } + + tool = new Tool(new Function(functionName, description ?? string.Empty, method)); + toolCache.Add(tool); + return tool; + } + + /// + /// Get or create a tool from a method of an instance of an object. + /// + /// + /// If the tool already exists, it will be returned. Otherwise, a new tool will be created.
+ /// The method doesn't need to be decorated with the .
+ ///
+ /// The instance of the object containing the method. + /// The name of the method. + /// Optional, The description of the method. + /// The tool for the method. + public static Tool GetOrCreateTool(object instance, string methodName, string description = null) + { + var type = instance.GetType(); + var method = type.GetMethod(methodName) ?? + throw new InvalidOperationException($"Failed to find a valid method for {type.FullName}.{methodName}()"); + + var functionName = $"{type.FullName}.{method.Name}".Replace('.', '_').Replace("+", "__"); ; + + if (TryGetTool(functionName, instance, out var tool)) + { + return tool; + } + + tool = new Tool(new Function(functionName, description ?? string.Empty, method, instance)); + toolCache.Add(tool); + return tool; + } + + #region Func<,> Overloads + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, + string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, + string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + + public static Tool FromFunc(string name, Func function, string description = null) + { + if (TryGetTool(name, function, out var tool)) + { + return tool; + } + + tool = new Tool(Function.FromFunc(name, function, description)); + toolCache.Add(tool); + return tool; + } + + #endregion Func<,> Overloads + } +} diff --git a/Mistral.SDK/Completions/CompletionsEndpoint.cs b/Mistral.SDK/Completions/CompletionsEndpoint.cs index 4fa9b45..aeed83e 100644 --- a/Mistral.SDK/Completions/CompletionsEndpoint.cs +++ b/Mistral.SDK/Completions/CompletionsEndpoint.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Net.Mime; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -28,7 +29,28 @@ internal CompletionsEndpoint(MistralClient client) : base(client) { } public async Task GetCompletionAsync(ChatCompletionRequest request, CancellationToken cancellationToken = default) { request.Stream = false; + var response = await HttpRequest(Url, HttpMethod.Post, request, cancellationToken).ConfigureAwait(false); + + var toolCalls = new List(); + foreach (var message in response.Choices) + { + if (message.Message.ToolCalls is null) continue; + foreach (var returned_tool in message.Message.ToolCalls) + { + var tool = request.Tools?.FirstOrDefault(t => t.Function.Name == returned_tool.Function.Name); + if (tool != null) + { + tool.Function.Arguments = returned_tool.Function.Arguments; + tool.Function.Id = returned_tool.Id; + toolCalls.Add(tool.Function); + } + } + } + response.ToolCalls = toolCalls; + + + return response; } diff --git a/Mistral.SDK/Converters/MistalSdkJsonOption.cs b/Mistral.SDK/Converters/MistralSdkJsonOption.cs similarity index 68% rename from Mistral.SDK/Converters/MistalSdkJsonOption.cs rename to Mistral.SDK/Converters/MistralSdkJsonOption.cs index 9000173..e1db8dc 100644 --- a/Mistral.SDK/Converters/MistalSdkJsonOption.cs +++ b/Mistral.SDK/Converters/MistralSdkJsonOption.cs @@ -1,10 +1,11 @@ -using System.Text.Json; +using System.Collections.Generic; +using System.Text.Json; using System.Text.Json.Serialization; using Mistral.SDK.DTOs; namespace Mistral.SDK.Converters; -public static class MistalSdkJsonOption +public static class MistralSdkJsonOption { #if NET8_0_OR_GREATER @@ -29,6 +30,14 @@ public static class MistalSdkJsonOption [JsonSerializable(typeof(ModelList))] [JsonSerializable(typeof(ResponseFormat))] [JsonSerializable(typeof(Usage))] +[JsonSerializable(typeof(Common.Function))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(Common.Tool))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(decimal?))] +[JsonSerializable(typeof(bool?))] +[JsonSerializable(typeof(ToolChoiceType))] public sealed partial class JsonContext : JsonSerializerContext; #endif \ No newline at end of file diff --git a/Mistral.SDK/Converters/ToolChoiceTypeConverter.cs b/Mistral.SDK/Converters/ToolChoiceTypeConverter.cs new file mode 100644 index 0000000..b54fbfc --- /dev/null +++ b/Mistral.SDK/Converters/ToolChoiceTypeConverter.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json.Serialization; +using System.Text.Json; +using Mistral.SDK.DTOs; + +namespace Mistral.SDK.Converters +{ + public class ToolChoiceTypeConverter : JsonConverter + { + public override ToolChoiceType Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + string value = reader.GetString(); + return value switch + { + "auto" => ToolChoiceType.Auto, + "any" => ToolChoiceType.Any, + "none" => ToolChoiceType.none, + _ => throw new JsonException($"Unknown tool choice type: {value}") + }; + } + + public override void Write(Utf8JsonWriter writer, ToolChoiceType value, JsonSerializerOptions options) + { + string roleString = value switch + { + ToolChoiceType.Auto => "auto", + ToolChoiceType.Any => "any", + ToolChoiceType.none => "none", + _ => throw new InvalidOperationException("Invalid tool choice type") + }; + writer.WriteStringValue(roleString); + } + } +} diff --git a/Mistral.SDK/DTOs/ChatCompletionRequest.cs b/Mistral.SDK/DTOs/ChatCompletionRequest.cs index 646da5c..124da24 100644 --- a/Mistral.SDK/DTOs/ChatCompletionRequest.cs +++ b/Mistral.SDK/DTOs/ChatCompletionRequest.cs @@ -2,10 +2,13 @@ using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; +using System.Linq; using System.Text.Json.Serialization; +using Mistral.SDK.Extensions; namespace Mistral.SDK.DTOs { + [JsonConverter(typeof(ChatCompletionRequestConverter))] public class ChatCompletionRequest { /// @@ -110,10 +113,20 @@ public class ChatCompletionRequest /// /// [JsonPropertyName("response_format")] - public ResponseFormat ResponseFormat { get; set; } + public ResponseFormat ResponseFormat { get; set; } + + [JsonPropertyName("tool_choice")] + [JsonConverter(typeof(ToolChoiceTypeConverter))] + public ToolChoiceType ToolChoice { get; set; } = ToolChoiceType.none; - IEnumerable Validate() + [JsonIgnore] + public IList Tools { get; set; } + + [JsonPropertyName("tools")] + private List ToolsForMistral => Tools?.ToList(); + + IEnumerable Validate() { // Temperature (decimal?) maximum if (this.Temperature > (decimal?)1) diff --git a/Mistral.SDK/DTOs/ChatCompletionResponse.cs b/Mistral.SDK/DTOs/ChatCompletionResponse.cs index 565895b..0ea4e76 100644 --- a/Mistral.SDK/DTOs/ChatCompletionResponse.cs +++ b/Mistral.SDK/DTOs/ChatCompletionResponse.cs @@ -45,5 +45,8 @@ public class ChatCompletionResponse /// [JsonPropertyName("usage")] public Usage Usage { get; set; } + + [JsonIgnore] + public List ToolCalls { get; set; } = new List(); } } diff --git a/Mistral.SDK/DTOs/ChatMessage.cs b/Mistral.SDK/DTOs/ChatMessage.cs index 8e03ef0..6109a2d 100644 --- a/Mistral.SDK/DTOs/ChatMessage.cs +++ b/Mistral.SDK/DTOs/ChatMessage.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using System.Collections.Generic; +using System.Text.Json.Serialization; using Mistral.SDK.Converters; namespace Mistral.SDK.DTOs @@ -16,6 +17,20 @@ public class ChatMessage this.Content = content; } + public ChatMessage(Common.Function? toolCall, string content = default(string)) + { + this.Role = RoleEnum.Tool; + this.Name = toolCall.Name; + this.ToolCallId = toolCall.Id; + this.Content = content; + } + + public ChatMessage() + { + + } + + [JsonConverter(typeof(JsonPropertyNameEnumConverter))] public enum RoleEnum { @@ -38,7 +53,10 @@ public enum RoleEnum /// [JsonPropertyName("assistant")] //[EnumMember(Value = "assistant")] - Assistant = 3 + Assistant = 3, + + [JsonPropertyName("tool")] + Tool = 4 } /// @@ -47,12 +65,18 @@ public enum RoleEnum [JsonPropertyName("role")] public RoleEnum? Role { get; set; } + [JsonPropertyName("name")] + public string Name { get; set; } /// /// Gets or Sets Content /// [JsonPropertyName("content")] public string Content { get; set; } + [JsonPropertyName("tool_calls")] + public List ToolCalls { get; set; } + [JsonPropertyName("tool_call_id")] + public string ToolCallId { get; set; } } } diff --git a/Mistral.SDK/DTOs/Choice.cs b/Mistral.SDK/DTOs/Choice.cs index 00ae278..4b4d2de 100644 --- a/Mistral.SDK/DTOs/Choice.cs +++ b/Mistral.SDK/DTOs/Choice.cs @@ -24,7 +24,10 @@ public enum FinishReasonEnum /// Enum ModelLength for value: model_length /// [JsonPropertyName("model_length")] - ModelLength = 3 + ModelLength = 3, + + [JsonPropertyName("tool_calls")] + ToolCalls = 4 } /// diff --git a/Mistral.SDK/DTOs/Tool.cs b/Mistral.SDK/DTOs/Tool.cs new file mode 100644 index 0000000..4856335 --- /dev/null +++ b/Mistral.SDK/DTOs/Tool.cs @@ -0,0 +1,85 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json.Serialization; + +namespace Mistral.SDK.DTOs +{ + public class Function + { + [JsonPropertyName("name")] + public string Name { get; set; } + + [JsonPropertyName("description")] + public string Description { get; set; } + + [JsonPropertyName("parameters")] + public Parameter Parameters { get; set; } + } + + + /// + /// Parameter Class + /// + public class Parameter + { + /// + /// Type of the Schema, default is object + /// + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + /// + /// Properties of the Schema + /// + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } + + /// + /// Required Properties + /// + [JsonPropertyName("required")] + public IList Required { get; set; } + } + /// + /// Serializable Tool Class + /// + public class Tool + { + /// + /// Tool Type + /// + [JsonPropertyName("type")] + public string Type { get; set; } = "function"; + + /// + /// Tool Input Schema + /// + [JsonPropertyName("function")] + public Function Function { get; set; } + } + + /// + /// Property Definition Class + /// + public class Property + { + /// + /// Property Type + /// + [JsonPropertyName("type")] + public string Type { get; set; } + + /// + /// Enum Values as Strings (if applicable) + /// + [JsonPropertyName("enum")] + public string[] Enum { get; set; } + + /// + /// Description of the Property + /// + [JsonPropertyName("description")] + public string Description { get; set; } + } +} diff --git a/Mistral.SDK/DTOs/ToolCall.cs b/Mistral.SDK/DTOs/ToolCall.cs new file mode 100644 index 0000000..d6a0c37 --- /dev/null +++ b/Mistral.SDK/DTOs/ToolCall.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Mistral.SDK.DTOs +{ + public class ToolCall + { + [JsonPropertyName("id")] + public string Id { get; set; } + + [JsonPropertyName("type")] + public string Type { get; set; } + + [JsonPropertyName("function")] + public ToolCallParameter Function { get; set; } + } + + public class ToolCallParameter + { + [JsonPropertyName("name")] + public string Name { get; set; } + + [JsonPropertyName("arguments")] + public JsonNode Arguments { get; set; } + } +} diff --git a/Mistral.SDK/DTOs/ToolChoiceType.cs b/Mistral.SDK/DTOs/ToolChoiceType.cs new file mode 100644 index 0000000..66512c7 --- /dev/null +++ b/Mistral.SDK/DTOs/ToolChoiceType.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json.Serialization; +using Mistral.SDK.Converters; + +namespace Mistral.SDK.DTOs +{ + [JsonConverter(typeof(ToolChoiceTypeConverter))] + public enum ToolChoiceType + { + [EnumMember(Value = "auto")] + Auto, + [EnumMember(Value = "any")] + Any, + [EnumMember(Value = "none")] + none + } +} diff --git a/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs b/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs index 1003c68..84ebbc4 100644 --- a/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs +++ b/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs @@ -37,7 +37,7 @@ public async Task GetEmbeddingsAsync(EmbeddingRequest request #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistalSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) .ConfigureAwait(false); return res; diff --git a/Mistral.SDK/EndpointBase.cs b/Mistral.SDK/EndpointBase.cs index 55d4a0f..2a38c94 100644 --- a/Mistral.SDK/EndpointBase.cs +++ b/Mistral.SDK/EndpointBase.cs @@ -94,7 +94,7 @@ protected async Task HttpRequest(string url = null, Http #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistalSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) .ConfigureAwait(false); return res; @@ -117,7 +117,7 @@ protected async Task HttpRequestRaw(string url = null, Http } else { - string jsonContent = JsonSerializer.Serialize(postData, MistalSdkJsonOption.Options ?? + string jsonContent = JsonSerializer.Serialize(postData, MistralSdkJsonOption.Options ?? new JsonSerializerOptions() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }); var stringContent = new StringContent(jsonContent, Encoding.UTF8, "application/json"); req.Content = stringContent; @@ -201,14 +201,14 @@ protected async IAsyncEnumerable HttpStreamingRequest(st else if (currentEvent.EventType == null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistalSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) .ConfigureAwait(false); yield return res; } else if (currentEvent.EventType != null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistalSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) .ConfigureAwait(false); throw new Exception(res.Error.Message); } diff --git a/Mistral.SDK/Extensions/ChatCompletionRequestConverter.cs b/Mistral.SDK/Extensions/ChatCompletionRequestConverter.cs new file mode 100644 index 0000000..7a4d760 --- /dev/null +++ b/Mistral.SDK/Extensions/ChatCompletionRequestConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using Mistral.SDK.DTOs; + +namespace Mistral.SDK.Extensions +{ + public class ChatCompletionRequestConverter: JsonConverter + { + public override ChatCompletionRequest Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + // Implement the Read method if deserialization is needed + throw new NotImplementedException(); + } + + public override void Write(Utf8JsonWriter writer, ChatCompletionRequest value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + var properties = typeof(ChatCompletionRequest).GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + foreach (var property in properties) + { + if (property.GetCustomAttribute() != null) + continue; + + var jsonPropertyName = property.GetCustomAttribute()?.Name ?? property.Name; + var propertyValue = property.GetValue(value); + + if (options.DefaultIgnoreCondition == JsonIgnoreCondition.WhenWritingNull && propertyValue == null) + continue; + + writer.WritePropertyName(jsonPropertyName); + JsonSerializer.Serialize(writer, propertyValue, property.PropertyType, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/Mistral.SDK/Extensions/TypeExtensions.cs b/Mistral.SDK/Extensions/TypeExtensions.cs new file mode 100644 index 0000000..c98bf9a --- /dev/null +++ b/Mistral.SDK/Extensions/TypeExtensions.cs @@ -0,0 +1,276 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Text.Json; +using System.Threading; +using Mistral.SDK.Common; +using Mistral.SDK.Converters; + +namespace Mistral.SDK.Extensions +{ + internal static class TypeExtensions + { + public static JsonObject GenerateJsonSchema(this MethodInfo methodInfo) + { + var parameters = methodInfo.GetParameters(); + + if (parameters.Length == 0) + { + return new JsonObject + { + ["type"] = "object" + }; + } + + var schema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject() + }; + var requiredParameters = new JsonArray(); + + foreach (var parameter in parameters) + { + if (parameter.ParameterType == typeof(CancellationToken)) + { + continue; + } + + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + throw new InvalidOperationException($"Failed to find a valid parameter name for {methodInfo.DeclaringType}.{methodInfo.Name}()"); + } + + if (!parameter.HasDefaultValue) + { + requiredParameters.Add(parameter.Name); + } + + schema["properties"]![parameter.Name] = GenerateJsonSchema(parameter.ParameterType, schema); + + var functionParameterAttribute = parameter.GetCustomAttribute(); + + if (functionParameterAttribute != null) + { + schema["properties"]![parameter.Name]!["description"] = functionParameterAttribute.Description; + } + } + + if (requiredParameters.Count > 0) + { + schema["required"] = requiredParameters; + } + + return schema; + } + + public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchema) + { + var schema = new JsonObject(); + + if (!type.IsPrimitive && + type != typeof(Guid) && + type != typeof(DateTime) && + type != typeof(DateTimeOffset) && + rootSchema["definitions"] != null && + rootSchema["definitions"].AsObject().ContainsKey(type.FullName)) + { + return new JsonObject { ["$ref"] = $"#/definitions/{type.FullName}" }; + } + + if (type == typeof(string) || type == typeof(char)) + { + schema["type"] = "string"; + } + else if (type == typeof(int) || + type == typeof(long) || + type == typeof(uint) || + type == typeof(byte) || + type == typeof(sbyte) || + type == typeof(ulong) || + type == typeof(short) || + type == typeof(ushort)) + { + schema["type"] = "integer"; + } + else if (type == typeof(float) || + type == typeof(double) || + type == typeof(decimal)) + { + schema["type"] = "number"; + } + else if (type == typeof(bool)) + { + schema["type"] = "boolean"; + } + else if (type == typeof(DateTime) || type == typeof(DateTimeOffset)) + { + schema["type"] = "string"; + schema["format"] = "date-time"; + } + else if (type == typeof(Guid)) + { + schema["type"] = "string"; + schema["format"] = "uuid"; + } + else if (type.IsEnum) + { + schema["type"] = "string"; + schema["enum"] = new JsonArray(); + + foreach (var value in Enum.GetValues(type)) + { + schema["enum"].AsArray().Add(JsonNode.Parse(JsonSerializer.Serialize(value, MistralSdkJsonOption.Options))); + } + } + else if (type.IsArray || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>))) + { + schema["type"] = "array"; + var elementType = type.GetElementType() ?? type.GetGenericArguments()[0]; + + if (rootSchema["definitions"] != null && + rootSchema["definitions"].AsObject().ContainsKey(elementType.FullName)) + { + schema["items"] = new JsonObject { ["$ref"] = $"#/definitions/{elementType.FullName}" }; + } + else + { + schema["items"] = GenerateJsonSchema(elementType, rootSchema); + } + } + else + { + schema["type"] = "object"; + rootSchema["definitions"] ??= new JsonObject(); + rootSchema["definitions"][type.FullName] = new JsonObject(); + + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly); + var fields = type.GetFields(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly); + var members = new List(properties.Length + fields.Length); + members.AddRange(properties); + members.AddRange(fields); + + var memberInfo = new JsonObject(); + var memberProperties = new JsonArray(); + + foreach (var member in members) + { + var memberType = GetMemberType(member); + var functionPropertyAttribute = member.GetCustomAttribute(); + var jsonPropertyAttribute = member.GetCustomAttribute(); + var jsonIgnoreAttribute = member.GetCustomAttribute(); + var propertyName = jsonPropertyAttribute?.Name ?? member.Name; + + JsonObject propertyInfo; + + if (rootSchema["definitions"] != null && + rootSchema["definitions"].AsObject().ContainsKey(memberType.FullName)) + { + propertyInfo = new JsonObject { ["$ref"] = $"#/definitions/{memberType.FullName}" }; + } + else + { + propertyInfo = GenerateJsonSchema(memberType, rootSchema); + } + + // override properties with values from function property attribute + if (functionPropertyAttribute != null) + { + propertyInfo["description"] = functionPropertyAttribute.Description; + + if (functionPropertyAttribute.Required) + { + memberProperties.Add(propertyName); + } + + JsonNode defaultValue = null; + + if (functionPropertyAttribute.DefaultValue != null) + { + defaultValue = JsonNode.Parse(JsonSerializer.Serialize(functionPropertyAttribute.DefaultValue, MistralSdkJsonOption.Options)); + propertyInfo["default"] = defaultValue; + } + + if (functionPropertyAttribute.PossibleValues is { Length: > 0 }) + { + var enums = new JsonArray(); + + foreach (var value in functionPropertyAttribute.PossibleValues) + { + var @enum = JsonNode.Parse(JsonSerializer.Serialize(value, MistralSdkJsonOption.Options)); + + if (defaultValue == null) + { + enums.Add(@enum); + } + else + { + if (@enum != defaultValue) + { + enums.Add(@enum); + } + } + } + + if (defaultValue != null && !enums.Contains(defaultValue)) + { + enums.Add(JsonNode.Parse(defaultValue.ToJsonString(MistralSdkJsonOption.Options))); + } + + propertyInfo["enum"] = enums; + } + } + else if (jsonIgnoreAttribute != null) + { + // only add members that are required + switch (jsonIgnoreAttribute.Condition) + { + case JsonIgnoreCondition.Never: + case JsonIgnoreCondition.WhenWritingDefault: + memberProperties.Add(propertyName); + break; + case JsonIgnoreCondition.Always: + case JsonIgnoreCondition.WhenWritingNull: + default: + memberProperties.Remove(propertyName); + break; + } + } + else if (Nullable.GetUnderlyingType(memberType) == null) + { + memberProperties.Add(propertyName); + } + + memberInfo[propertyName] = propertyInfo; + } + + schema["properties"] = memberInfo; + + if (memberProperties.Count > 0) + { + schema["required"] = memberProperties; + } + + rootSchema["definitions"] ??= new JsonObject(); + rootSchema["definitions"][type.FullName] = schema; + return new JsonObject { ["$ref"] = $"#/definitions/{type.FullName}" }; + } + + return schema; + } + + private static Type GetMemberType(MemberInfo member) + => member switch + { + FieldInfo fieldInfo => fieldInfo.FieldType, + PropertyInfo propertyInfo => propertyInfo.PropertyType, + _ => throw new ArgumentException($"{nameof(MemberInfo)} must be of type {nameof(FieldInfo)}, {nameof(PropertyInfo)}", nameof(member)) + }; + + + + } +} diff --git a/Mistral.SDK/Models/ModelsEndpoint.cs b/Mistral.SDK/Models/ModelsEndpoint.cs index 2a7a50c..fa531f2 100644 --- a/Mistral.SDK/Models/ModelsEndpoint.cs +++ b/Mistral.SDK/Models/ModelsEndpoint.cs @@ -35,7 +35,7 @@ public async Task GetModelsAsync(CancellationToken cancellationToken #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistalSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) .ConfigureAwait(false); return res; From abf4297fd02361a2e091ca20e0f2b3cbaf497f7a Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 07:52:32 -0500 Subject: [PATCH 2/8] more integration tests --- Mistral.SDK.Tests/FunctionCalling.cs | 243 ++++++++++++++++++ .../Converters/MistralSdkJsonOption.cs | 2 +- 2 files changed, 244 insertions(+), 1 deletion(-) diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs index af3c666..4b5d8eb 100644 --- a/Mistral.SDK.Tests/FunctionCalling.cs +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -1,11 +1,117 @@ using Mistral.SDK.Common; using Mistral.SDK.DTOs; +using System.Globalization; namespace Mistral.SDK.Tests { [TestClass] public class FunctionCalling { + public enum TempType + { + Fahrenheit, + Celsius + } + + [Function("This function returns the weather for a given location")] + public static async Task GetWeather([FunctionParameter("Location of the weather", true)] string location, + [FunctionParameter("Unit of temperature, celsius or fahrenheit", true)] TempType tempType) + { + await Task.Yield(); + return "72 degrees and sunny"; + } + + [Function("Get the current user's name")] + public static async Task GetCurrentUser() + { + await Task.Yield(); + return "Mistral"; + } + + public static class StaticObjectTool + { + + public static string GetWeather(string location) + { + return "72 degrees and sunny"; + } + } + + public class InstanceObjectTool + { + + public string GetWeather(string location) + { + return "72 degrees and sunny"; + } + } + + [TestMethod] + public async Task TestStaticObjectTool() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + request.Tools = new List + { + Common.Tool.GetOrCreateTool(typeof(StaticObjectTool), nameof(GetWeather), "This function returns the weather for a given location") + }; + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); + } + + [TestMethod] + public async Task TestInstanceObjectTool() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + var objectInstance = new InstanceObjectTool(); + request.Tools = new List + { + Common.Tool.GetOrCreateTool(objectInstance, nameof(GetWeather), "This function returns the weather for a given location") + }; + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); + } + + [TestMethod] public async Task TestBasicFunction() { @@ -41,5 +147,142 @@ public async Task TestBasicFunction() Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); } + + [TestMethod] + public async Task TestBasicToolDeclaredGlobally() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "What is the current weather in San Francisco?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + request.Tools = Common.Tool.GetAllAvailableTools(includeDefaults: false, forceUpdate: true, clearCache: true).ToList(); + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = await toolCall.InvokeAsync(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); + } + + [TestMethod] + public async Task TestTestEmptyArgsAndMultiTool() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "What is the current user's name?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + request.Tools = Common.Tool.GetAllAvailableTools(includeDefaults: false, forceUpdate: true, clearCache: true).ToList(); + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = await toolCall.InvokeAsync(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("Mistral")); + } + + [TestMethod] + public async Task TestMathFuncTool() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User,"How many characters are in the word Christmas, multiply by 5, add 6, subtract 2, then divide by 2.1?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + request.Tools = new List + { + Common.Tool.FromFunc("ChristmasMathFunction", + ([FunctionParameter("word to start with", true)]string word, + [FunctionParameter("number to multiply word count by", true)]int multiplier, + [FunctionParameter("amount to add to word count", true)]int addition, + [FunctionParameter("amount to subtract from word count", true)]int subtraction, + [FunctionParameter("amount to divide word count by", true)]double divisor) => + { + return ((word.Length * multiplier + addition - subtraction) / divisor).ToString(CultureInfo.InvariantCulture); + }, "Function that can be used to determine the number of characters in a word combined with a mathematical formula") + }; + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("23")); + } + + + [TestMethod] + public async Task TestBoolTool() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User,"Should I roll the dice? Your answer should contain the word yes or no.") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Auto; + + request.Tools = new List + { + Common.Tool.FromFunc("Dice_Roller", + ([FunctionParameter("Decides whether to roll the dice", true)]bool rollDice)=> + { + return "no"; + }, "Decides whether the user should roll the dice") + }; + + var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + messages.Add(response.Choices.First().Message); + + foreach (var toolCall in response.ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("no")); + } + + } } diff --git a/Mistral.SDK/Converters/MistralSdkJsonOption.cs b/Mistral.SDK/Converters/MistralSdkJsonOption.cs index e1db8dc..5421e3f 100644 --- a/Mistral.SDK/Converters/MistralSdkJsonOption.cs +++ b/Mistral.SDK/Converters/MistralSdkJsonOption.cs @@ -9,7 +9,7 @@ public static class MistralSdkJsonOption { #if NET8_0_OR_GREATER - public static readonly JsonSerializerOptions Options = JsonContext.Default.Options; + public static readonly JsonSerializerOptions Options = null; #else public static readonly JsonSerializerOptions Options = null; #endif From 678ed6f2ce69988c1b6b46fad0566f7b9eb78e5b Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 08:42:00 -0500 Subject: [PATCH 3/8] streaming function calling --- Mistral.SDK.Tests/FunctionCalling.cs | 52 +++++++++++++++++++ .../Completions/CompletionsEndpoint.cs | 19 +++++++ 2 files changed, 71 insertions(+) diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs index 4b5d8eb..6869fc6 100644 --- a/Mistral.SDK.Tests/FunctionCalling.cs +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -111,6 +111,58 @@ public async Task TestInstanceObjectTool() Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); } + [TestMethod] + public async Task TestBasicFunctionStreaming() + { + var client = new MistralClient(); + var messages = new List() + { + new ChatMessage(ChatMessage.RoleEnum.User, "How many characters are in the word Christmas, multiply by 5, add 6, subtract 2, then divide by 2.1?") + }; + var request = new ChatCompletionRequest("mistral-large-latest", messages); + + request.ToolChoice = ToolChoiceType.Any; + + request.Tools = new List + { + Common.Tool.FromFunc("ChristmasMathFunction", + ([FunctionParameter("word to start with", true)]string word, + [FunctionParameter("number to multiply word count by", true)]int multiplier, + [FunctionParameter("amount to add to word count", true)]int addition, + [FunctionParameter("amount to subtract from word count", true)]int subtraction, + [FunctionParameter("amount to divide word count by", true)]double divisor) => + { + return ((word.Length * multiplier + addition - subtraction) / divisor).ToString(CultureInfo.InvariantCulture); + }, "Function that can be used to determine the number of characters in a word combined with a mathematical formula") + }; + var responses = new List(); + await foreach (var response in client.Completions.StreamCompletionAsync(request)) + { + responses.Add(response); + } + + messages.Add(responses.First(p => p.Choices.First().Delta.ToolCalls != null).Choices.First().Delta); + + foreach (var toolCall in responses.First(p => p.Choices.First().Delta.ToolCalls != null).ToolCalls) + { + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); + } + + request.ToolChoice = ToolChoiceType.none; + + var finalMessage = string.Empty; + await foreach (var response in client.Completions.StreamCompletionAsync(request)) + { + finalMessage += response.Choices.First().Delta.Content; + } + //var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + + Assert.IsTrue(finalMessage.Contains("23")); + } + + + [TestMethod] public async Task TestBasicFunction() diff --git a/Mistral.SDK/Completions/CompletionsEndpoint.cs b/Mistral.SDK/Completions/CompletionsEndpoint.cs index aeed83e..fdc247f 100644 --- a/Mistral.SDK/Completions/CompletionsEndpoint.cs +++ b/Mistral.SDK/Completions/CompletionsEndpoint.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Mistral.SDK.DTOs; +using ChatMessage = Mistral.SDK.DTOs.ChatMessage; namespace Mistral.SDK.Completions { @@ -64,6 +65,24 @@ public async IAsyncEnumerable StreamCompletionAsync(Chat request.Stream = true; await foreach (var result in HttpStreamingRequest(Url, HttpMethod.Post, request, cancellationToken).WithCancellation(cancellationToken).ConfigureAwait(false)) { + var toolCalls = new List(); + foreach (var message in result.Choices) + { + if (message.Delta.ToolCalls is null) continue; + foreach (var returned_tool in message.Delta.ToolCalls) + { + var tool = request.Tools?.FirstOrDefault(t => t.Function.Name == returned_tool.Function.Name); + if (tool != null) + { + tool.Function.Arguments = returned_tool.Function.Arguments; + tool.Function.Id = returned_tool.Id; + toolCalls.Add(tool.Function); + } + } + } + + result.Choices.First().Delta.Role = ChatMessage.RoleEnum.Assistant; + result.ToolCalls = toolCalls; yield return result; } } From 411c7b1c29ab4a588af80aabfd95eefbdb0c831f Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 11:56:11 -0500 Subject: [PATCH 4/8] hello world IChatClient function calling --- Mistral.SDK.Tests/ChatClient.cs | 26 ++ Mistral.SDK.Tests/Mistral.SDK.Tests.csproj | 2 + .../CompletionsEndpoint.ChatClient.cs | 284 ++++++++++++++++++ .../Completions/CompletionsEndpoint.cs | 135 +-------- Mistral.SDK/DTOs/ChatMessage.cs | 8 + 5 files changed, 323 insertions(+), 132 deletions(-) create mode 100644 Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs diff --git a/Mistral.SDK.Tests/ChatClient.cs b/Mistral.SDK.Tests/ChatClient.cs index e7e9cd8..26f7d11 100644 --- a/Mistral.SDK.Tests/ChatClient.cs +++ b/Mistral.SDK.Tests/ChatClient.cs @@ -86,5 +86,31 @@ public async Task TestMistralCompletionSafeWithOptions() Assert.IsTrue(!string.IsNullOrEmpty(response.Message.Text)); } + + [TestMethod] + public async Task TestNonStreamingFunctionCalls() + { + IChatClient client = new MistralClient().Completions + .AsBuilder() + .UseFunctionInvocation() + .Build(); + + ChatOptions options = new() + { + ModelId = "mistral-large-latest", + MaxOutputTokens = 512, + ToolMode = ChatToolMode.RequireAny, + Tools = [AIFunctionFactory.Create((string personName) => personName switch { + "Alice" => "25", + _ => "40" + }, "GetPersonAge", "Gets the age of the person whose name is specified.")] + }; + + var res = await client.CompleteAsync("How old is Alice?", options); + + Assert.IsTrue( + res.Message.Text?.Contains("25") is true, + res.Message.Text); + } } } \ No newline at end of file diff --git a/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj b/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj index 376ac32..30a1318 100644 --- a/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj +++ b/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj @@ -13,6 +13,8 @@ + + diff --git a/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs b/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs new file mode 100644 index 0000000..bf48e60 --- /dev/null +++ b/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs @@ -0,0 +1,284 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using System.Threading; +using Microsoft.Extensions.AI; +using Mistral.SDK.DTOs; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Text.Json; + +namespace Mistral.SDK.Completions +{ + public partial class CompletionsEndpoint : IChatClient + { + + async Task IChatClient.CompleteAsync( + IList chatMessages, ChatOptions options, CancellationToken cancellationToken) + { + var response = await GetCompletionAsync(CreateRequest(chatMessages, options), cancellationToken).ConfigureAwait(false); + + Microsoft.Extensions.AI.ChatMessage message = new(ChatRole.Assistant, ProcessResponseContent(response)); + + var completion = new ChatCompletion(message) + { + CompletionId = response.Id, + ModelId = response.Model + }; + + if (response.Usage is { } usage) + { + completion.Usage = new UsageDetails() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens + }; + } + + //foreach (Choice choice in response.Choices) + //{ + // ChatRole role = choice.Message.Role switch + // { + // DTOs.ChatMessage.RoleEnum.System => ChatRole.System, + // DTOs.ChatMessage.RoleEnum.Assistant => ChatRole.User, + // _ => ChatRole.User, + // }; + + // completion.Choices.Add(new Microsoft.Extensions.AI.ChatMessage(role, choice.Message.Content)); + + // if (completion.FinishReason is null && choice.FinishReason != null) + // { + // completion.FinishReason = choice.FinishReason switch + // { + // Choice.FinishReasonEnum.Length => ChatFinishReason.Length, + // Choice.FinishReasonEnum.ModelLength => ChatFinishReason.Length, + // _ => ChatFinishReason.Stop + // }; + // } + //} + + return completion; + } + + private static UsageDetails CreateUsageDetails(Usage usage) => + new() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + AdditionalProperties = new() + { + [nameof(usage.TotalTokens)] = usage.TotalTokens + } + }; + + async IAsyncEnumerable IChatClient.CompleteStreamingAsync( + IList chatMessages, ChatOptions options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var response in StreamCompletionAsync(CreateRequest(chatMessages, options), cancellationToken).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (var choice in response.Choices) + { + yield return new StreamingChatCompletionUpdate + { + ChoiceIndex = choice.Index, + CompletionId = response.Id, + ModelId = response.Model, + Role = choice.Delta?.Role switch + { + DTOs.ChatMessage.RoleEnum.System => ChatRole.System, + DTOs.ChatMessage.RoleEnum.Assistant => ChatRole.User, + _ => ChatRole.User, + }, + FinishReason = choice.FinishReason switch + { + Choice.FinishReasonEnum.Length => ChatFinishReason.Length, + Choice.FinishReasonEnum.ModelLength => ChatFinishReason.Length, + _ => ChatFinishReason.Stop + }, + Text = choice.Delta?.Content, + }; + } + + if (response.Usage is { } usage) + { + yield return new StreamingChatCompletionUpdate() + { + CompletionId = response.Id, + ModelId = response.Model, + Contents = new List() + { + new UsageContent(new UsageDetails() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens + }) + }, + }; + } + } + } + + private static ChatCompletionRequest CreateRequest(IList chatMessages, ChatOptions options) + { + var messages = chatMessages.Select(m => + { + DTOs.ChatMessage.RoleEnum role = + m.Role == ChatRole.System ? DTOs.ChatMessage.RoleEnum.System : + m.Role == ChatRole.User ? DTOs.ChatMessage.RoleEnum.User : + m.Role == ChatRole.Tool ? DTOs.ChatMessage.RoleEnum.Tool : + DTOs.ChatMessage.RoleEnum.Assistant; + + foreach (var content in m.Contents) + { + if (content is Microsoft.Extensions.AI.FunctionResultContent frc) + { + + return new DTOs.ChatMessage(frc.CallId, frc.Name, frc.Result?.ToString()); + } + else if (content is Microsoft.Extensions.AI.FunctionCallContent fcc) + { + return new DTOs.ChatMessage() + { + Role = DTOs.ChatMessage.RoleEnum.Assistant, + ToolCalls = new List() + { + new ToolCall() + { + Id = fcc.CallId, + Function = new ToolCallParameter() + { + Arguments = JsonSerializer.SerializeToNode(fcc.Arguments), + Name = fcc.Name, + } + } + } + }; + } + + } + + return new DTOs.ChatMessage(role, string.Concat(m.Contents.OfType())); + }).ToList(); + + var request = new ChatCompletionRequest( + model: options?.ModelId, + messages: messages, + temperature: (decimal?)options?.Temperature, + topP: (decimal?)options?.TopP, + maxTokens: options?.MaxOutputTokens, + safePrompt: options?.AdditionalProperties?.TryGetValue(nameof(ChatCompletionRequest.SafePrompt), out bool safePrompt) is true, + randomSeed: (int?)options?.Seed); + + if (options.ResponseFormat is ChatResponseFormatJson) + { + request.ResponseFormat = new ResponseFormat() { Type = ResponseFormat.ResponseFormatEnum.JSON }; + } + + if (options.Tools is { Count: > 0 }) + { + + if (options.ToolMode is RequiredChatToolMode r) + { + request.ToolChoice = ToolChoiceType.Any; + } + else if (options.ToolMode is AutoChatToolMode a) + { + request.ToolChoice = ToolChoiceType.Auto; + } + + request.Tools = options + .Tools + .OfType() + .Select(f => new Common.Tool(new Common.Function(f.Metadata.Name, f.Metadata.Description, FunctionParameters.CreateSchema(f)))) + .ToList(); + } + + return request; + } + + private static List ProcessResponseContent(ChatCompletionResponse response) + { + List contents = new(); + + foreach (var content in response.Choices) + { + if (content.Message.ToolCalls is not null) + { + contents.Add(new Microsoft.Extensions.AI.TextContent(content.Message.Content)); + + foreach (var toolCall in content.Message.ToolCalls) + { + Dictionary arguments = null; + if (toolCall.Function.Arguments is not null) + { + string jsonString = toolCall.Function.Arguments.AsValue().ToJsonString(); + jsonString = System.Text.RegularExpressions.Regex.Unescape(jsonString); // Decode Unicode escape sequences + arguments = JsonSerializer.Deserialize>(toolCall.Function.Arguments.ToString()); + } + + contents.Add(new FunctionCallContent( + toolCall.Id, + toolCall.Function.Name, + arguments)); + } + } + else + { + contents.Add(new Microsoft.Extensions.AI.TextContent(content.Message.Content)); + } + } + + return contents; + } + + void IDisposable.Dispose() { } + + object IChatClient.GetService(Type serviceType, object key) => + key is null && serviceType?.IsInstanceOfType(this) is true ? this : null; + + ChatClientMetadata IChatClient.Metadata => _metadata ??= new ChatClientMetadata(nameof(MistralClient), new Uri(Url)); + + private ChatClientMetadata _metadata; + + + private sealed class FunctionParameters + { + private static readonly JsonElement s_defaultParameterSchema = JsonDocument.Parse("{}").RootElement; + + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + [JsonPropertyName("required")] + public List Required { get; set; } = []; + + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + + public static JsonNode CreateSchema(AIFunction f) + { + var parameters = f.Metadata.Parameters; + + FunctionParameters schema = new(); + + foreach (AIFunctionParameterMetadata parameter in parameters) + { + schema.Properties.Add(parameter.Name, parameter.Schema is JsonElement e ? e : s_defaultParameterSchema); + + if (parameter.IsRequired) + { + schema.Required.Add(parameter.Name); + } + } + + return JsonSerializer.SerializeToNode(schema); + } + } + } + + +} diff --git a/Mistral.SDK/Completions/CompletionsEndpoint.cs b/Mistral.SDK/Completions/CompletionsEndpoint.cs index fdc247f..c842a60 100644 --- a/Mistral.SDK/Completions/CompletionsEndpoint.cs +++ b/Mistral.SDK/Completions/CompletionsEndpoint.cs @@ -6,13 +6,12 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.AI; using Mistral.SDK.DTOs; using ChatMessage = Mistral.SDK.DTOs.ChatMessage; namespace Mistral.SDK.Completions { - public class CompletionsEndpoint : EndpointBase, IChatClient + public partial class CompletionsEndpoint : EndpointBase { /// /// Constructor of the api endpoint. Rather than instantiating this yourself, access it through an instance of as . @@ -87,136 +86,8 @@ public async IAsyncEnumerable StreamCompletionAsync(Chat } } - async Task IChatClient.CompleteAsync( - IList chatMessages, ChatOptions options, CancellationToken cancellationToken) - { - var response = await GetCompletionAsync(CreateRequest(chatMessages, options), cancellationToken).ConfigureAwait(false); - - var completion = new ChatCompletion(new List()) - { - CompletionId = response.Id, - ModelId = response.Model - }; - - if (response.Usage is { } usage) - { - completion.Usage = new UsageDetails() - { - InputTokenCount = usage.PromptTokens, - OutputTokenCount = usage.CompletionTokens, - TotalTokenCount = usage.TotalTokens - }; - } - - foreach (Choice choice in response.Choices) - { - ChatRole role = choice.Message.Role switch - { - DTOs.ChatMessage.RoleEnum.System => ChatRole.System, - DTOs.ChatMessage.RoleEnum.Assistant => ChatRole.User, - _ => ChatRole.User, - }; - - completion.Choices.Add(new Microsoft.Extensions.AI.ChatMessage(role, choice.Message.Content)); - - if (completion.FinishReason is null && choice.FinishReason != null) - { - completion.FinishReason = choice.FinishReason switch - { - Choice.FinishReasonEnum.Length => ChatFinishReason.Length, - Choice.FinishReasonEnum.ModelLength => ChatFinishReason.Length, - _ => ChatFinishReason.Stop - }; - } - } - - return completion; - } - - async IAsyncEnumerable IChatClient.CompleteStreamingAsync( - IList chatMessages, ChatOptions options, [EnumeratorCancellation] CancellationToken cancellationToken) - { - await foreach (var response in StreamCompletionAsync(CreateRequest(chatMessages, options), cancellationToken).WithCancellation(cancellationToken).ConfigureAwait(false)) - { - foreach (var choice in response.Choices) - { - yield return new StreamingChatCompletionUpdate - { - ChoiceIndex = choice.Index, - CompletionId = response.Id, - ModelId = response.Model, - Role = choice.Delta?.Role switch - { - DTOs.ChatMessage.RoleEnum.System => ChatRole.System, - DTOs.ChatMessage.RoleEnum.Assistant => ChatRole.User, - _ => ChatRole.User, - }, - FinishReason = choice.FinishReason switch - { - Choice.FinishReasonEnum.Length => ChatFinishReason.Length, - Choice.FinishReasonEnum.ModelLength => ChatFinishReason.Length, - _ => ChatFinishReason.Stop - }, - Text = choice.Delta?.Content, - }; - } - - if (response.Usage is { } usage) - { - yield return new StreamingChatCompletionUpdate() - { - CompletionId = response.Id, - ModelId = response.Model, - Contents = new List() - { - new UsageContent(new UsageDetails() - { - InputTokenCount = usage.PromptTokens, - OutputTokenCount = usage.CompletionTokens, - TotalTokenCount = usage.TotalTokens - }) - }, - }; - } - } - } - - private static ChatCompletionRequest CreateRequest(IList chatMessages, ChatOptions options) - { - var messages = chatMessages.Select(m => - { - DTOs.ChatMessage.RoleEnum role = - m.Role == ChatRole.System ? DTOs.ChatMessage.RoleEnum.System : - m.Role == ChatRole.User ? DTOs.ChatMessage.RoleEnum.User : - DTOs.ChatMessage.RoleEnum.Assistant; - - return new DTOs.ChatMessage(role, string.Concat(m.Contents.OfType())); - }).ToList(); - - var request = new ChatCompletionRequest( - model: options?.ModelId, - messages: messages, - temperature: (decimal?)options?.Temperature, - topP: (decimal?)options?.TopP, - maxTokens: options?.MaxOutputTokens, - safePrompt: options?.AdditionalProperties?.TryGetValue(nameof(ChatCompletionRequest.SafePrompt), out bool safePrompt) is true, - randomSeed: (int?)options?.Seed); - - if (options.ResponseFormat is ChatResponseFormatJson) - { - request.ResponseFormat = new ResponseFormat() { Type = ResponseFormat.ResponseFormatEnum.JSON }; - } - - return request; - } - - void IDisposable.Dispose() { } - - object IChatClient.GetService(Type serviceType, object key) => - key is null && serviceType?.IsInstanceOfType(this) is true ? this : null; - - ChatClientMetadata IChatClient.Metadata => _metadata ??= new ChatClientMetadata(nameof(MistralClient), new Uri(Url)); + - private ChatClientMetadata _metadata; + } } diff --git a/Mistral.SDK/DTOs/ChatMessage.cs b/Mistral.SDK/DTOs/ChatMessage.cs index 6109a2d..53377dd 100644 --- a/Mistral.SDK/DTOs/ChatMessage.cs +++ b/Mistral.SDK/DTOs/ChatMessage.cs @@ -25,6 +25,14 @@ public class ChatMessage this.Content = content; } + public ChatMessage(string toolCallId, string name, string content = default(string)) + { + this.Role = RoleEnum.Tool; + this.Name = name; + this.ToolCallId = toolCallId; + this.Content = content; + } + public ChatMessage() { From 5d0e48a10c51b6d2c3dee1a32c86718308bb144e Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 12:59:07 -0500 Subject: [PATCH 5/8] cleanup and works in streaming with IChatClient --- Mistral.SDK.Tests/ChatClient.cs | 32 +++++- .../CompletionsEndpoint.ChatClient.cs | 99 ++++++++----------- 2 files changed, 74 insertions(+), 57 deletions(-) diff --git a/Mistral.SDK.Tests/ChatClient.cs b/Mistral.SDK.Tests/ChatClient.cs index 26f7d11..23c7eab 100644 --- a/Mistral.SDK.Tests/ChatClient.cs +++ b/Mistral.SDK.Tests/ChatClient.cs @@ -99,7 +99,7 @@ public async Task TestNonStreamingFunctionCalls() { ModelId = "mistral-large-latest", MaxOutputTokens = 512, - ToolMode = ChatToolMode.RequireAny, + ToolMode = ChatToolMode.Auto, Tools = [AIFunctionFactory.Create((string personName) => personName switch { "Alice" => "25", _ => "40" @@ -112,5 +112,35 @@ public async Task TestNonStreamingFunctionCalls() res.Message.Text?.Contains("25") is true, res.Message.Text); } + + [TestMethod] + public async Task TestStreamingFunctionCalls() + { + IChatClient client = new MistralClient().Completions + .AsBuilder() + .UseFunctionInvocation() + .Build(); + + ChatOptions options = new() + { + ModelId = "mistral-large-latest", + MaxOutputTokens = 512, + ToolMode = ChatToolMode.Auto, + Tools = [AIFunctionFactory.Create((string personName) => personName switch { + "Alice" => "25", + _ => "40" + }, "GetPersonAge", "Gets the age of the person whose name is specified.")] + }; + + StringBuilder sb = new(); + await foreach (var update in client.CompleteStreamingAsync("How old is Alice?", options)) + { + sb.Append(update); + } + + Assert.IsTrue( + sb.ToString().Contains("25") is true, + sb.ToString()); + } } } \ No newline at end of file diff --git a/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs b/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs index bf48e60..6d81516 100644 --- a/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs +++ b/Mistral.SDK/Completions/CompletionsEndpoint.ChatClient.cs @@ -39,42 +39,9 @@ async Task IChatClient.CompleteAsync( }; } - //foreach (Choice choice in response.Choices) - //{ - // ChatRole role = choice.Message.Role switch - // { - // DTOs.ChatMessage.RoleEnum.System => ChatRole.System, - // DTOs.ChatMessage.RoleEnum.Assistant => ChatRole.User, - // _ => ChatRole.User, - // }; - - // completion.Choices.Add(new Microsoft.Extensions.AI.ChatMessage(role, choice.Message.Content)); - - // if (completion.FinishReason is null && choice.FinishReason != null) - // { - // completion.FinishReason = choice.FinishReason switch - // { - // Choice.FinishReasonEnum.Length => ChatFinishReason.Length, - // Choice.FinishReasonEnum.ModelLength => ChatFinishReason.Length, - // _ => ChatFinishReason.Stop - // }; - // } - //} - return completion; } - private static UsageDetails CreateUsageDetails(Usage usage) => - new() - { - InputTokenCount = usage.PromptTokens, - OutputTokenCount = usage.CompletionTokens, - AdditionalProperties = new() - { - [nameof(usage.TotalTokens)] = usage.TotalTokens - } - }; - async IAsyncEnumerable IChatClient.CompleteStreamingAsync( IList chatMessages, ChatOptions options, [EnumeratorCancellation] CancellationToken cancellationToken) { @@ -82,11 +49,11 @@ async IAsyncEnumerable IChatClient.CompleteStream { foreach (var choice in response.Choices) { - yield return new StreamingChatCompletionUpdate - { + var update = new StreamingChatCompletionUpdate { ChoiceIndex = choice.Index, CompletionId = response.Id, ModelId = response.Model, + RawRepresentation = response, Role = choice.Delta?.Role switch { DTOs.ChatMessage.RoleEnum.System => ChatRole.System, @@ -100,7 +67,32 @@ async IAsyncEnumerable IChatClient.CompleteStream _ => ChatFinishReason.Stop }, Text = choice.Delta?.Content, + }; + + if (choice.Delta?.ToolCalls is { Count: > 0 }) + { + update.Contents = new List + { + new TextContent(choice.Delta.Content) + }; + + foreach (var toolCall in choice.Delta.ToolCalls) + { + Dictionary arguments = null; + if (toolCall.Function.Arguments is not null) + { + arguments = JsonSerializer.Deserialize>(toolCall.Function.Arguments.ToString()); + } + + update.Contents.Add(new FunctionCallContent( + toolCall.Id, + toolCall.Function.Name, + arguments)); + } + } + + yield return update; } if (response.Usage is { } usage) @@ -133,33 +125,30 @@ private static ChatCompletionRequest CreateRequest(IList() + case Microsoft.Extensions.AI.FunctionResultContent frc: + return new DTOs.ChatMessage(frc.CallId, frc.Name, frc.Result?.ToString()); + case Microsoft.Extensions.AI.FunctionCallContent fcc: + return new DTOs.ChatMessage() { - new ToolCall() + Role = DTOs.ChatMessage.RoleEnum.Assistant, + ToolCalls = new List() { - Id = fcc.CallId, - Function = new ToolCallParameter() + new ToolCall() { - Arguments = JsonSerializer.SerializeToNode(fcc.Arguments), - Name = fcc.Name, + Id = fcc.CallId, + Function = new ToolCallParameter() + { + Arguments = JsonSerializer.SerializeToNode(fcc.Arguments), + Name = fcc.Name, + } } } - } - }; + }; } - } return new DTOs.ChatMessage(role, string.Concat(m.Contents.OfType())); @@ -216,8 +205,6 @@ private static List ProcessResponseContent(ChatCompletionResponse res Dictionary arguments = null; if (toolCall.Function.Arguments is not null) { - string jsonString = toolCall.Function.Arguments.AsValue().ToJsonString(); - jsonString = System.Text.RegularExpressions.Regex.Unescape(jsonString); // Decode Unicode escape sequences arguments = JsonSerializer.Deserialize>(toolCall.Function.Arguments.ToString()); } From 83bc8c35161691888b13e5ce01a4d36ab27b3821 Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 20:40:02 -0500 Subject: [PATCH 6/8] eliminates source generation conflicts with function calling --- Mistral.SDK.Tests/FunctionCalling.cs | 22 +++++----- Mistral.SDK/Common/Function.cs | 2 +- .../Converters/MistralSdkJsonOption.cs | 43 ------------------- Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs | 2 +- Mistral.SDK/EndpointBase.cs | 9 ++-- Mistral.SDK/Extensions/TypeExtensions.cs | 8 ++-- Mistral.SDK/Models/ModelsEndpoint.cs | 2 +- 7 files changed, 23 insertions(+), 65 deletions(-) delete mode 100644 Mistral.SDK/Converters/MistralSdkJsonOption.cs diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs index 6869fc6..5b61a79 100644 --- a/Mistral.SDK.Tests/FunctionCalling.cs +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -206,10 +206,11 @@ public async Task TestBasicToolDeclaredGlobally() var client = new MistralClient(); var messages = new List() { - new ChatMessage(ChatMessage.RoleEnum.User, "What is the current weather in San Francisco?") + new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA in Fahrenheit?") }; var request = new ChatCompletionRequest("mistral-large-latest", messages); - + request.MaxTokens = 1024; + request.Temperature = 0.0m; request.ToolChoice = ToolChoiceType.Auto; request.Tools = Common.Tool.GetAllAvailableTools(includeDefaults: false, forceUpdate: true, clearCache: true).ToList(); @@ -223,7 +224,8 @@ public async Task TestBasicToolDeclaredGlobally() var resp = await toolCall.InvokeAsync(); messages.Add(new ChatMessage(toolCall, resp)); } - + //request.ToolChoice = ToolChoiceType.none; + //request.Tools = null; var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); @@ -305,19 +307,19 @@ public async Task TestBoolTool() var client = new MistralClient(); var messages = new List() { - new ChatMessage(ChatMessage.RoleEnum.User,"Should I roll the dice? Your answer should contain the word yes or no.") + new ChatMessage(ChatMessage.RoleEnum.User,"Should I wear a hat? It's warm outside.") }; var request = new ChatCompletionRequest("mistral-large-latest", messages); - request.ToolChoice = ToolChoiceType.Auto; + request.ToolChoice = ToolChoiceType.Any; request.Tools = new List { - Common.Tool.FromFunc("Dice_Roller", - ([FunctionParameter("Decides whether to roll the dice", true)]bool rollDice)=> + Common.Tool.FromFunc("Hat_Determiner", + ([FunctionParameter("Is it cold outside", true)]bool isItCold)=> { return "no"; - }, "Decides whether the user should roll the dice") + }, "Determines whether you should wear a heat based on whether it's cold outside.") }; var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); @@ -329,10 +331,10 @@ public async Task TestBoolTool() var resp = toolCall.Invoke(); messages.Add(new ChatMessage(toolCall, resp)); } - + request.ToolChoice = ToolChoiceType.none; var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); - Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("no")); + Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("warm")); } diff --git a/Mistral.SDK/Common/Function.cs b/Mistral.SDK/Common/Function.cs index 3a7cf2f..043e3c6 100644 --- a/Mistral.SDK/Common/Function.cs +++ b/Mistral.SDK/Common/Function.cs @@ -434,7 +434,7 @@ public async Task InvokeAsync(CancellationToken cancellationToken = defaul } else if (value is JsonElement element) { - invokeArgs[i] = JsonSerializer.Deserialize(element.GetRawText(), parameter.ParameterType, MistralSdkJsonOption.Options); + invokeArgs[i] = JsonSerializer.Deserialize(element.GetRawText(), parameter.ParameterType); } else { diff --git a/Mistral.SDK/Converters/MistralSdkJsonOption.cs b/Mistral.SDK/Converters/MistralSdkJsonOption.cs deleted file mode 100644 index 5421e3f..0000000 --- a/Mistral.SDK/Converters/MistralSdkJsonOption.cs +++ /dev/null @@ -1,43 +0,0 @@ -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Serialization; -using Mistral.SDK.DTOs; - -namespace Mistral.SDK.Converters; - -public static class MistralSdkJsonOption -{ - -#if NET8_0_OR_GREATER - public static readonly JsonSerializerOptions Options = null; -#else - public static readonly JsonSerializerOptions Options = null; -#endif -} - -#if NET8_0_OR_GREATER - -[JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] -[JsonSerializable(typeof(ChatCompletionRequest))] -[JsonSerializable(typeof(ChatCompletionResponse))] -[JsonSerializable(typeof(ChatMessage))] -[JsonSerializable(typeof(Choice))] -[JsonSerializable(typeof(EmbeddingRequest))] -[JsonSerializable(typeof(EmbeddingResponse))] -[JsonSerializable(typeof(EmbeddingResult))] -[JsonSerializable(typeof(Error))] -[JsonSerializable(typeof(ErrorResponse))] -[JsonSerializable(typeof(ModelList))] -[JsonSerializable(typeof(ResponseFormat))] -[JsonSerializable(typeof(Usage))] -[JsonSerializable(typeof(Common.Function))] -[JsonSerializable(typeof(List))] -[JsonSerializable(typeof(Common.Tool))] -[JsonSerializable(typeof(List))] -[JsonSerializable(typeof(List))] -[JsonSerializable(typeof(decimal?))] -[JsonSerializable(typeof(bool?))] -[JsonSerializable(typeof(ToolChoiceType))] -public sealed partial class JsonContext : JsonSerializerContext; - -#endif \ No newline at end of file diff --git a/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs b/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs index 84ebbc4..aa1ce0e 100644 --- a/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs +++ b/Mistral.SDK/Embeddings/EmbeddingsEndpoint.cs @@ -37,7 +37,7 @@ public async Task GetEmbeddingsAsync(EmbeddingRequest request #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), cancellationToken: cancellationToken) .ConfigureAwait(false); return res; diff --git a/Mistral.SDK/EndpointBase.cs b/Mistral.SDK/EndpointBase.cs index 2a38c94..c609322 100644 --- a/Mistral.SDK/EndpointBase.cs +++ b/Mistral.SDK/EndpointBase.cs @@ -94,7 +94,7 @@ protected async Task HttpRequest(string url = null, Http #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), cancellationToken: cancellationToken) .ConfigureAwait(false); return res; @@ -117,8 +117,7 @@ protected async Task HttpRequestRaw(string url = null, Http } else { - string jsonContent = JsonSerializer.Serialize(postData, MistralSdkJsonOption.Options ?? - new JsonSerializerOptions() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }); + string jsonContent = JsonSerializer.Serialize(postData, new JsonSerializerOptions() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }); var stringContent = new StringContent(jsonContent, Encoding.UTF8, "application/json"); req.Content = stringContent; } @@ -201,14 +200,14 @@ protected async IAsyncEnumerable HttpStreamingRequest(st else if (currentEvent.EventType == null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken) .ConfigureAwait(false); yield return res; } else if (currentEvent.EventType != null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken) .ConfigureAwait(false); throw new Exception(res.Error.Message); } diff --git a/Mistral.SDK/Extensions/TypeExtensions.cs b/Mistral.SDK/Extensions/TypeExtensions.cs index c98bf9a..7e9cf7d 100644 --- a/Mistral.SDK/Extensions/TypeExtensions.cs +++ b/Mistral.SDK/Extensions/TypeExtensions.cs @@ -123,7 +123,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem foreach (var value in Enum.GetValues(type)) { - schema["enum"].AsArray().Add(JsonNode.Parse(JsonSerializer.Serialize(value, MistralSdkJsonOption.Options))); + schema["enum"].AsArray().Add(JsonNode.Parse(JsonSerializer.Serialize(value))); } } else if (type.IsArray || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>))) @@ -190,7 +190,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem if (functionPropertyAttribute.DefaultValue != null) { - defaultValue = JsonNode.Parse(JsonSerializer.Serialize(functionPropertyAttribute.DefaultValue, MistralSdkJsonOption.Options)); + defaultValue = JsonNode.Parse(JsonSerializer.Serialize(functionPropertyAttribute.DefaultValue)); propertyInfo["default"] = defaultValue; } @@ -200,7 +200,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem foreach (var value in functionPropertyAttribute.PossibleValues) { - var @enum = JsonNode.Parse(JsonSerializer.Serialize(value, MistralSdkJsonOption.Options)); + var @enum = JsonNode.Parse(JsonSerializer.Serialize(value)); if (defaultValue == null) { @@ -217,7 +217,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem if (defaultValue != null && !enums.Contains(defaultValue)) { - enums.Add(JsonNode.Parse(defaultValue.ToJsonString(MistralSdkJsonOption.Options))); + enums.Add(JsonNode.Parse(defaultValue.ToJsonString())); } propertyInfo["enum"] = enums; diff --git a/Mistral.SDK/Models/ModelsEndpoint.cs b/Mistral.SDK/Models/ModelsEndpoint.cs index fa531f2..f8332cb 100644 --- a/Mistral.SDK/Models/ModelsEndpoint.cs +++ b/Mistral.SDK/Models/ModelsEndpoint.cs @@ -35,7 +35,7 @@ public async Task GetModelsAsync(CancellationToken cancellationToken #endif var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), MistralSdkJsonOption.Options, cancellationToken: cancellationToken) + new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), cancellationToken: cancellationToken) .ConfigureAwait(false); return res; From d9983d7cb4c3baa6dca8067a83ef543f58777668 Mon Sep 17 00:00:00 2001 From: tghamm Date: Sun, 1 Dec 2024 21:43:30 -0500 Subject: [PATCH 7/8] sk integration test, serialization quirks --- Mistral.SDK.Tests/FunctionCalling.cs | 11 ++-- Mistral.SDK.Tests/Mistral.SDK.Tests.csproj | 1 + .../SemanticKernelInitializationTests.cs | 66 +++++++++++++++++++ Mistral.SDK/Common/Function.cs | 2 +- Mistral.SDK/Common/Tool.cs | 20 +++++- Mistral.SDK/Extensions/TypeExtensions.cs | 8 +-- Mistral.SDK/MistralClient.cs | 9 +++ 7 files changed, 105 insertions(+), 12 deletions(-) create mode 100644 Mistral.SDK.Tests/SemanticKernelInitializationTests.cs diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs index 5b61a79..2175410 100644 --- a/Mistral.SDK.Tests/FunctionCalling.cs +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -86,7 +86,7 @@ public async Task TestInstanceObjectTool() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest("mistral-small-latest", messages); request.ToolChoice = ToolChoiceType.Auto; @@ -208,7 +208,7 @@ public async Task TestBasicToolDeclaredGlobally() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA in Fahrenheit?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest("mistral-small-latest", messages); request.MaxTokens = 1024; request.Temperature = 0.0m; request.ToolChoice = ToolChoiceType.Auto; @@ -224,8 +224,7 @@ public async Task TestBasicToolDeclaredGlobally() var resp = await toolCall.InvokeAsync(); messages.Add(new ChatMessage(toolCall, resp)); } - //request.ToolChoice = ToolChoiceType.none; - //request.Tools = null; + var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); @@ -309,7 +308,7 @@ public async Task TestBoolTool() { new ChatMessage(ChatMessage.RoleEnum.User,"Should I wear a hat? It's warm outside.") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest("mistral-small-latest", messages); request.ToolChoice = ToolChoiceType.Any; @@ -334,7 +333,7 @@ public async Task TestBoolTool() request.ToolChoice = ToolChoiceType.none; var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); - Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("warm")); + Assert.IsTrue(finalResult.Choices.First().Message.Content.ToLower().Contains("no")); } diff --git a/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj b/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj index 30a1318..5e795d8 100644 --- a/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj +++ b/Mistral.SDK.Tests/Mistral.SDK.Tests.csproj @@ -15,6 +15,7 @@ + diff --git a/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs b/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs new file mode 100644 index 0000000..a016449 --- /dev/null +++ b/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs @@ -0,0 +1,66 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +#pragma warning disable SKEXP0001 + +namespace Mistral.SDK.Tests +{ + [TestClass] + public class SemanticKernelInitializationTests + { + [TestMethod] + public async Task TestSKInit() + { + var skChatService = + new ChatClientBuilder(new MistralClient().Completions) + .UseFunctionInvocation() + .Build() + .AsChatCompletionService(); + + + var sk = Kernel.CreateBuilder(); + sk.Plugins.AddFromType("Weather"); + sk.Services.AddSingleton(skChatService); + + var kernel = sk.Build(); + var chatCompletionService = kernel.Services.GetRequiredService(); + // Create chat history + var history = new ChatHistory(); + history.AddUserMessage("What is the weather like in San Francisco right now?"); + OpenAIPromptExecutionSettings promptExecutionSettings = new() + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(), + ModelId = "mistral-small-latest", + MaxTokens = 1024, + Temperature = 0.0, + }; + + // Get the response from the AI + var result = await chatCompletionService.GetChatMessageContentAsync( + history, + executionSettings: promptExecutionSettings, + kernel: kernel + ); ; + + + Assert.IsTrue(result.Content.Contains("72")); + } + } + + public class SkPlugins + { + [KernelFunction("GetWeather")] + [Description("Gets the weather for a given location")] + public async Task GetWeather(string location) + { + return "It is 72 degrees and sunny in " + location; + } + } +} diff --git a/Mistral.SDK/Common/Function.cs b/Mistral.SDK/Common/Function.cs index 043e3c6..883ebaf 100644 --- a/Mistral.SDK/Common/Function.cs +++ b/Mistral.SDK/Common/Function.cs @@ -434,7 +434,7 @@ public async Task InvokeAsync(CancellationToken cancellationToken = defaul } else if (value is JsonElement element) { - invokeArgs[i] = JsonSerializer.Deserialize(element.GetRawText(), parameter.ParameterType); + invokeArgs[i] = JsonSerializer.Deserialize(element.GetRawText(), parameter.ParameterType, MistralClient.JsonSerializationOptions); } else { diff --git a/Mistral.SDK/Common/Tool.cs b/Mistral.SDK/Common/Tool.cs index 25f5bdb..44d14fd 100644 --- a/Mistral.SDK/Common/Tool.cs +++ b/Mistral.SDK/Common/Tool.cs @@ -168,6 +168,24 @@ private static bool TryGetTool(string name, object instance, out Tool tool) return false; } + private static IEnumerable GetAssemblyTypes(Assembly assembly) + { + try + { + return assembly.GetTypes(); + } + catch (ReflectionTypeLoadException ex) + { + // Return the types that could be loaded + return ex.Types.Where(t => t != null); + } + catch + { + // Return an empty sequence if any other exception occurs + return Enumerable.Empty(); + } + } + /// /// Gets a list of all available tools. /// @@ -190,7 +208,7 @@ public static IReadOnlyList GetAllAvailableTools(bool includeDefaults = tr var tools = new List(); tools.AddRange( from assembly in AppDomain.CurrentDomain.GetAssemblies() - from type in assembly.GetTypes() + from type in GetAssemblyTypes(assembly) from method in type.GetMethods() where method.IsStatic let functionAttribute = method.GetCustomAttribute() diff --git a/Mistral.SDK/Extensions/TypeExtensions.cs b/Mistral.SDK/Extensions/TypeExtensions.cs index 7e9cf7d..7ca6869 100644 --- a/Mistral.SDK/Extensions/TypeExtensions.cs +++ b/Mistral.SDK/Extensions/TypeExtensions.cs @@ -123,7 +123,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem foreach (var value in Enum.GetValues(type)) { - schema["enum"].AsArray().Add(JsonNode.Parse(JsonSerializer.Serialize(value))); + schema["enum"].AsArray().Add(JsonNode.Parse(JsonSerializer.Serialize(value, MistralClient.JsonSerializationOptions))); } } else if (type.IsArray || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>))) @@ -190,7 +190,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem if (functionPropertyAttribute.DefaultValue != null) { - defaultValue = JsonNode.Parse(JsonSerializer.Serialize(functionPropertyAttribute.DefaultValue)); + defaultValue = JsonNode.Parse(JsonSerializer.Serialize(functionPropertyAttribute.DefaultValue, MistralClient.JsonSerializationOptions)); propertyInfo["default"] = defaultValue; } @@ -200,7 +200,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem foreach (var value in functionPropertyAttribute.PossibleValues) { - var @enum = JsonNode.Parse(JsonSerializer.Serialize(value)); + var @enum = JsonNode.Parse(JsonSerializer.Serialize(value, MistralClient.JsonSerializationOptions)); if (defaultValue == null) { @@ -217,7 +217,7 @@ public static JsonObject GenerateJsonSchema(this Type type, JsonObject rootSchem if (defaultValue != null && !enums.Contains(defaultValue)) { - enums.Add(JsonNode.Parse(defaultValue.ToJsonString())); + enums.Add(JsonNode.Parse(defaultValue.ToJsonString(MistralClient.JsonSerializationOptions))); } propertyInfo["enum"] = enums; diff --git a/Mistral.SDK/MistralClient.cs b/Mistral.SDK/MistralClient.cs index 2b80a3c..6bfb6b0 100644 --- a/Mistral.SDK/MistralClient.cs +++ b/Mistral.SDK/MistralClient.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Net.Http; using System.Text; +using System.Text.Json.Serialization; +using System.Text.Json; using Mistral.SDK.Completions; using Mistral.SDK.Embeddings; using Mistral.SDK.Models; @@ -51,6 +53,13 @@ public MistralClient(APIAuthentication apiKeys = null, HttpClient client = null) Embeddings = new EmbeddingsEndpoint(this); } + internal static JsonSerializerOptions JsonSerializationOptions { get; } = new() + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + Converters = { new JsonStringEnumConverter() }, + ReferenceHandler = ReferenceHandler.IgnoreCycles, + }; + private HttpClient SetupClient(HttpClient client) { if (client is not null) From 6f3c7e647751eae544849eb8ee98f19524431ecc Mon Sep 17 00:00:00 2001 From: tghamm Date: Mon, 2 Dec 2024 06:52:09 -0500 Subject: [PATCH 8/8] Version Bump, Readme Update --- Mistral.SDK.Tests/ChatClient.cs | 4 +- Mistral.SDK.Tests/FunctionCalling.cs | 22 +-- .../SemanticKernelInitializationTests.cs | 2 +- Mistral.SDK/Mistral.SDK.csproj | 8 +- README.md | 137 +++++++++++++++++- 5 files changed, 151 insertions(+), 22 deletions(-) diff --git a/Mistral.SDK.Tests/ChatClient.cs b/Mistral.SDK.Tests/ChatClient.cs index 23c7eab..b72e426 100644 --- a/Mistral.SDK.Tests/ChatClient.cs +++ b/Mistral.SDK.Tests/ChatClient.cs @@ -97,7 +97,7 @@ public async Task TestNonStreamingFunctionCalls() ChatOptions options = new() { - ModelId = "mistral-large-latest", + ModelId = ModelDefinitions.MistralSmall, MaxOutputTokens = 512, ToolMode = ChatToolMode.Auto, Tools = [AIFunctionFactory.Create((string personName) => personName switch { @@ -123,7 +123,7 @@ public async Task TestStreamingFunctionCalls() ChatOptions options = new() { - ModelId = "mistral-large-latest", + ModelId = ModelDefinitions.MistralSmall, MaxOutputTokens = 512, ToolMode = ChatToolMode.Auto, Tools = [AIFunctionFactory.Create((string personName) => personName switch { diff --git a/Mistral.SDK.Tests/FunctionCalling.cs b/Mistral.SDK.Tests/FunctionCalling.cs index 2175410..4ff744b 100644 --- a/Mistral.SDK.Tests/FunctionCalling.cs +++ b/Mistral.SDK.Tests/FunctionCalling.cs @@ -54,7 +54,7 @@ public async Task TestStaticObjectTool() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Auto; @@ -86,7 +86,7 @@ public async Task TestInstanceObjectTool() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA?") }; - var request = new ChatCompletionRequest("mistral-small-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Auto; @@ -119,7 +119,7 @@ public async Task TestBasicFunctionStreaming() { new ChatMessage(ChatMessage.RoleEnum.User, "How many characters are in the word Christmas, multiply by 5, add 6, subtract 2, then divide by 2.1?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Any; @@ -156,7 +156,6 @@ public async Task TestBasicFunctionStreaming() { finalMessage += response.Choices.First().Delta.Content; } - //var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); Assert.IsTrue(finalMessage.Contains("23")); } @@ -172,7 +171,7 @@ public async Task TestBasicFunction() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the current weather in San Francisco?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Auto; @@ -208,7 +207,7 @@ public async Task TestBasicToolDeclaredGlobally() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA in Fahrenheit?") }; - var request = new ChatCompletionRequest("mistral-small-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.MaxTokens = 1024; request.Temperature = 0.0m; request.ToolChoice = ToolChoiceType.Auto; @@ -238,7 +237,7 @@ public async Task TestTestEmptyArgsAndMultiTool() { new ChatMessage(ChatMessage.RoleEnum.User, "What is the current user's name?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Auto; @@ -267,7 +266,7 @@ public async Task TestMathFuncTool() { new ChatMessage(ChatMessage.RoleEnum.User,"How many characters are in the word Christmas, multiply by 5, add 6, subtract 2, then divide by 2.1?") }; - var request = new ChatCompletionRequest("mistral-large-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Auto; @@ -308,7 +307,7 @@ public async Task TestBoolTool() { new ChatMessage(ChatMessage.RoleEnum.User,"Should I wear a hat? It's warm outside.") }; - var request = new ChatCompletionRequest("mistral-small-latest", messages); + var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); request.ToolChoice = ToolChoiceType.Any; @@ -324,16 +323,17 @@ public async Task TestBoolTool() var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); messages.Add(response.Choices.First().Message); - + var hitTool = false; foreach (var toolCall in response.ToolCalls) { var resp = toolCall.Invoke(); messages.Add(new ChatMessage(toolCall, resp)); + hitTool = true; } request.ToolChoice = ToolChoiceType.none; var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); - Assert.IsTrue(finalResult.Choices.First().Message.Content.ToLower().Contains("no")); + Assert.IsTrue(hitTool); } diff --git a/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs b/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs index a016449..9d920ee 100644 --- a/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs +++ b/Mistral.SDK.Tests/SemanticKernelInitializationTests.cs @@ -37,7 +37,7 @@ public async Task TestSKInit() OpenAIPromptExecutionSettings promptExecutionSettings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(), - ModelId = "mistral-small-latest", + ModelId = ModelDefinitions.MistralSmall, MaxTokens = 1024, Temperature = 0.0, }; diff --git a/Mistral.SDK/Mistral.SDK.csproj b/Mistral.SDK/Mistral.SDK.csproj index e3b25bd..b1b2db7 100644 --- a/Mistral.SDK/Mistral.SDK.csproj +++ b/Mistral.SDK/Mistral.SDK.csproj @@ -14,12 +14,12 @@ Mistral, AI, ML, API, C#, .NET, Mixtral Mistral API - Improves Cancellation Support + Function Calling Support Mistral.SDK - 1.3.3 - 1.3.3.0 - 1.3.3.0 + 2.0.0 + 2.0.0.0 + 2.0.0.0 True README.md True diff --git a/README.md b/README.md index ae35dfc..9ef52e1 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Mistral.SDK is an unofficial C# client designed for interacting with the Mistral - [IChatClient](#ichatclient) - [List Models](#list-models) - [Embeddings](#embeddings) + - [Function Calling](#function-calling) - [Contributing](#contributing) - [License](#license) @@ -37,7 +38,7 @@ The `MistralClient` can optionally take a custom `HttpClient` in the `MistralCli ## Usage -There are two ways to start using the `MistralClient`. The first is to simply new up an instance of the `MistralClient` and start using it, the second is to use the messaging/Embedding client with the new `Microsoft.Extensions.AI.Abstractions` builder. +There are three ways to start using the `MistralClient`. The first is to simply new up an instance of the `MistralClient` and start using it, the second is to use the messaging/Embedding client with the new `Microsoft.Extensions.AI.Abstractions` builder. The third is to use the Completions client with `Microsoft.SemanticKernel`. Brief examples of each are below. Option 1: @@ -56,7 +57,25 @@ IChatClient client = new MistralClient().Completions; IEmbeddingGenerator> client = new MistralClient().Embeddings; ``` -Both support all the core features of the `MistralClient's` Messaging and Embedding capabilities, but the latter will be fully featured in .NET 9 and provide built in telemetry and DI and make it easier to choose which SDK you are using. +Option 3: + +```csharp +using Microsoft.SemanticKernel; + +var skChatService = + new ChatClientBuilder(new MistralClient().Completions) + .UseFunctionInvocation() + .Build() + .AsChatCompletionService(); + + +var sk = Kernel.CreateBuilder(); +sk.Plugins.AddFromType("Weather"); +sk.Services.AddSingleton(skChatService); +``` +See integration tests for a more complete example. + +All support all the core features of the `MistralClient's` Messaging and Embedding capabilities, but the latter will be fully featured in .NET 9 and provide built in telemetry and DI and make it easier to choose which SDK you are using. ## Examples @@ -150,8 +169,31 @@ EmbeddingGenerator> client = new MistralClient().Embedd var response = await client.GenerateEmbeddingVectorAsync("hello world", new() { ModelId = ModelDefinitions.MistralEmbed }); Assert.IsTrue(!response.IsEmpty); +//Functions call +IChatClient client = new MistralClient().Completions + .AsBuilder() + .UseFunctionInvocation() + .Build(); + +ChatOptions options = new() +{ + ModelId = ModelDefinitions.MistralSmall, + MaxOutputTokens = 512, + ToolMode = ChatToolMode.Auto, + Tools = [AIFunctionFactory.Create((string personName) => personName switch { + "Alice" => "25", + _ => "40" + }, "GetPersonAge", "Gets the age of the person whose name is specified.")] +}; + +var res = await client.CompleteAsync("How old is Alice?", options); + +Assert.IsTrue( + res.Message.Text?.Contains("25") is true, + res.Message.Text); + ``` -Please see the unit tests for even more examples. +Please see the integration tests for even more examples. ### List Models @@ -176,9 +218,96 @@ var request = new EmbeddingRequest( var response = await client.Embeddings.GetEmbeddingsAsync(request); ``` +### Function Calling + +The `MistralClient` supports Function Calling through a variety of mechanisms. It's worth noting that currently some models seem to hallucinate function calling behavior more than others, and this is a known issue with Mistral. + +```csharp +public enum TempType +{ + Fahrenheit, + Celsius +} + +[Function("This function returns the weather for a given location")] +public static async Task GetWeather([FunctionParameter("Location of the weather", true)] string location, + [FunctionParameter("Unit of temperature, celsius or fahrenheit", true)] TempType tempType) +{ + await Task.Yield(); + return "72 degrees and sunny"; +} + +//declared globally +var client = new MistralClient(); +var messages = new List() +{ + new ChatMessage(ChatMessage.RoleEnum.User, "What is the weather in San Francisco, CA in Fahrenheit?") +}; +var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); +request.MaxTokens = 1024; +request.Temperature = 0.0m; +request.ToolChoice = ToolChoiceType.Auto; + +request.Tools = Common.Tool.GetAllAvailableTools(includeDefaults: false, forceUpdate: true, clearCache: true).ToList(); + +var response = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + +messages.Add(response.Choices.First().Message); + +foreach (var toolCall in response.ToolCalls) +{ + var resp = await toolCall.InvokeAsync(); + messages.Add(new ChatMessage(toolCall, resp)); +} + +var finalResult = await client.Completions.GetCompletionAsync(request).ConfigureAwait(false); + +Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("72")); + +//from a func +var client = new MistralClient(); +var messages = new List() +{ + new ChatMessage(ChatMessage.RoleEnum.User,"How many characters are in the word Christmas, multiply by 5, add 6, subtract 2, then divide by 2.1?") +}; +var request = new ChatCompletionRequest(ModelDefinitions.MistralSmall, messages); + +request.ToolChoice = ToolChoiceType.Auto; + +request.Tools = new List +{ + Common.Tool.FromFunc("ChristmasMathFunction", + ([FunctionParameter("word to start with", true)]string word, + [FunctionParameter("number to multiply word count by", true)]int multiplier, + [FunctionParameter("amount to add to word count", true)]int addition, + [FunctionParameter("amount to subtract from word count", true)]int subtraction, + [FunctionParameter("amount to divide word count by", true)]double divisor) => + { + return ((word.Length * multiplier + addition - subtraction) / divisor).ToString(CultureInfo.InvariantCulture); + }, "Function that can be used to determine the number of characters in a word combined with a mathematical formula") +}; + +var response = await client.Completions.GetCompletionAsync(request); + +messages.Add(response.Choices.First().Message); + +foreach (var toolCall in response.ToolCalls) +{ + var resp = toolCall.Invoke(); + messages.Add(new ChatMessage(toolCall, resp)); +} + +var finalResult = await client.Completions.GetCompletionAsync(request); + +Assert.IsTrue(finalResult.Choices.First().Message.Content.Contains("23")); + +//see integration tests for examples like streaming function calls, calling a static or instance based function, and more. + +``` + ## Contributing -Pull requests are welcome. If you're planning to make a major change, please open an issue first to discuss your proposed changes. +Pull requests are welcome with associated integration tests. If you're planning to make a major change, please open an issue first to discuss your proposed changes. ## License