diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs index 689bca9c4da..296c089ce1f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs @@ -61,6 +61,16 @@ public ChatMessage Message /// Gets or sets the ID of the chat completion. public string? CompletionId { get; set; } + /// Gets or sets the chat thread ID associated with this chat completion. + /// + /// Some implementations are capable of storing the state for a chat thread, such that + /// the input messages supplied to need only be the additional messages beyond + /// what's already stored. If this property is non-, it represents an identifier for that state, + /// and it should be used in a subsequent instead of supplying the same messages + /// (and this 's message) as part of the chatMessages parameter. + /// + public string? ChatThreadId { get; set; } + /// Gets or sets the model ID used in the creation of the chat completion. public string? 
ModelId { get; set; } @@ -133,6 +143,7 @@ public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates() ChatMessage choice = Choices[choiceIndex]; updates[choiceIndex] = new StreamingChatCompletionUpdate { + ChatThreadId = ChatThreadId, ChoiceIndex = choiceIndex, AdditionalProperties = choice.AdditionalProperties, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index d52cc36cdbb..e7faeba0ee1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -39,6 +39,22 @@ public ChatMessage( _contents = Throw.IfNull(contents); } + /// Clones the to a new instance. + /// A shallow clone of the original message object. + /// + /// This is a shallow clone. The returned instance is different from the original, but all properties + /// refer to the same objects as the original. + /// + public ChatMessage Clone() => + new() + { + AdditionalProperties = AdditionalProperties, + _authorName = _authorName, + _contents = _contents, + RawRepresentation = RawRepresentation, + Role = Role, + }; + /// Gets or sets the name of the author of the message. public string? AuthorName { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs index 69adc8392fd..12ce2f56860 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs @@ -9,6 +9,9 @@ namespace Microsoft.Extensions.AI; /// Represents the options for a chat request. public class ChatOptions { + /// Gets or sets an optional identifier used to associate a request with an existing chat thread. + public string? 
ChatThreadId { get; set; } + /// Gets or sets the temperature for generating chat responses. public float? Temperature { get; set; } @@ -72,6 +75,7 @@ public virtual ChatOptions Clone() { ChatOptions options = new() { + ChatThreadId = ChatThreadId, Temperature = Temperature, MaxOutputTokens = MaxOutputTokens, TopP = TopP, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs index 36ae500e138..e50fd42169b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs @@ -102,6 +102,16 @@ public IList Contents /// Gets or sets the ID of the completion of which this update is a part. public string? CompletionId { get; set; } + /// Gets or sets the chat thread ID associated with the chat completion of which this update is a part. + /// + /// Some implementations are capable of storing the state for a chat thread, such that + /// the input messages supplied to need only be the additional messages beyond + /// what's already stored. If this property is non-, it represents an identifier for that state, + /// and it should be used in a subsequent instead of supplying the same messages + /// (and this streaming message) as part of the chatMessages parameter. + /// + public string? ChatThreadId { get; set; } + /// Gets or sets a timestamp for the completion update. public DateTimeOffset? 
CreatedAt { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs index b70d7471b80..9694b0e4dc0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs @@ -13,6 +13,7 @@ #pragma warning disable S109 // Magic numbers should not be used #pragma warning disable S127 // "for" loop stop conditions should be invariant +#pragma warning disable S1121 // Assignments should not be made from within sub-expressions namespace Microsoft.Extensions.AI; @@ -103,7 +104,21 @@ private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictiona } #endif - ((List)message.Contents).AddRange(update.Contents); + // Incorporate all content from the update into the completion. + foreach (var content in update.Contents) + { + switch (content) + { + // Usage content is treated specially and propagated to the completion's Usage. 
+ case UsageContent usage: + (completion.Usage ??= new()).Add(usage.Details); + break; + + default: + message.Contents.Add(content); + break; + } + } message.AuthorName ??= update.AuthorName; if (update.Role is ChatRole role && message.Role == default) @@ -178,20 +193,6 @@ static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValue } completion.Choices.Add(entry.Value); - - if (completion.Usage is null) - { - foreach (var content in entry.Value.Contents) - { - if (content is UsageContent c) - { - completion.Usage = c.Details; - entry.Value.Contents = entry.Value.Contents.ToList(); - _ = entry.Value.Contents.Remove(c); - break; - } - } - } } } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs new file mode 100644 index 00000000000..b0f36e43313 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -0,0 +1,351 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; +using OpenAI; +using OpenAI.Assistants; +using OpenAI.Chat; + +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. 
+#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable S1751 // Loops with at most one iteration should be refactored +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1108 // Block statements should not contain embedded comments + +namespace Microsoft.Extensions.AI; + +/// Represents an for an OpenAI or . +internal sealed class OpenAIAssistantClient : IChatClient +{ + /// Metadata for the client. + private readonly ChatClientMetadata _metadata; + + /// The underlying . + private readonly AssistantClient _assistantClient; + + /// The ID of the assistant to use. + private readonly string _assistantId; + + /// The thread ID to use if none is supplied in . + private readonly string? _threadId; + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + /// The ID of the assistant to use. + /// + /// The ID of the thread to use. If not supplied here, it should be supplied per request in . + /// If none is supplied, a new thread will be created for a request. + /// + public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId, string? threadId) + { + _assistantClient = Throw.IfNull(assistantClient); + _assistantId = Throw.IfNull(assistantId); + _threadId = threadId; + + _metadata = new("openai"); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(ChatClientMetadata) ? _metadata : + serviceType == typeof(AssistantClient) ? _assistantClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } + + /// + public Task CompleteAsync( + IList chatMessages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) => + CompleteStreamingAsync(chatMessages, options, cancellationToken).ToChatCompletionAsync(coalesceContent: true, cancellationToken); + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Extract necessary state from chatMessages and options. + (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(chatMessages, options); + + // Get the thread ID. + string? threadId = options?.ChatThreadId ?? _threadId; + if (threadId is null && toolResults is not null) + { + Throw.ArgumentException(nameof(chatMessages), "No thread ID was provided, but chat messages includes tool results."); + } + + // Get the updates to process from the assistant. If we have any tool results, this means submitting those and ignoring + // our runOptions. Otherwise, create a run, and a thread if we don't have one. + IAsyncEnumerable updates; + if (GetRunId(toolResults, out List? toolOutputs) is string existingRunId) + { + updates = _assistantClient.SubmitToolOutputsToRunStreamingAsync(threadId, existingRunId, toolOutputs, cancellationToken); + } + else if (threadId is null) + { + ThreadCreationOptions creationOptions = new(); + foreach (var message in runOptions.AdditionalMessages) + { + creationOptions.InitialMessages.Add(message); + } + + runOptions.AdditionalMessages.Clear(); + + updates = _assistantClient.CreateThreadAndRunStreamingAsync(_assistantId, creationOptions, runOptions, cancellationToken: cancellationToken); + } + else + { + updates = _assistantClient.CreateRunStreamingAsync(threadId, _assistantId, runOptions, cancellationToken); + } + + // Process each update. 
+ await foreach (var update in updates.ConfigureAwait(false)) + { + switch (update) + { + case MessageContentUpdate mcu: + yield return new() + { + ChatThreadId = threadId, + RawRepresentation = mcu, + Role = mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, + Text = mcu.Text, + }; + break; + + case ThreadUpdate tu when options is not null: + threadId ??= tu.Value.Id; + break; + + case RunUpdate ru: + threadId ??= ru.Value.ThreadId; + + StreamingChatCompletionUpdate ruUpdate = new() + { + AuthorName = ru.Value.AssistantId, + ChatThreadId = threadId, + CompletionId = ru.Value.Id, + CreatedAt = ru.Value.CreatedAt, + ModelId = ru.Value.Model, + RawRepresentation = ru, + Role = ChatRole.Assistant, + }; + + if (ru.Value.Usage is { } usage) + { + ruUpdate.Contents.Add(new UsageContent(new() + { + InputTokenCount = usage.InputTokenCount, + OutputTokenCount = usage.OutputTokenCount, + TotalTokenCount = usage.TotalTokenCount, + })); + } + + if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) + { + ruUpdate.Contents.Add( + new FunctionCallContent( + JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!), + functionName, + JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!)); + } + + yield return ruUpdate; + break; + } + } + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IChatClient interface. + } + + /// Adds the provided messages to the thread and returns the options to use for the request. + private static (RunCreationOptions RunOptions, List? ToolResults) CreateRunOptions(IList chatMessages, ChatOptions? options) + { + _ = Throw.IfNull(chatMessages); + + RunCreationOptions runOptions = new(); + + // Handle ChatOptions. + if (options is not null) + { + // Propagate the simple properties that have a 1:1 correspondence. 
+ runOptions.MaxOutputTokenCount = options.MaxOutputTokens; + runOptions.ModelOverride = options.ModelId; + runOptions.NucleusSamplingFactor = options.TopP; + runOptions.Temperature = options.Temperature; + + // Propagate additional properties from AdditionalProperties. + if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.AllowParallelToolCalls), out bool allowParallelToolCalls) is true) + { + runOptions.AllowParallelToolCalls = allowParallelToolCalls; + } + + if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.MaxInputTokenCount), out int maxInputTokenCount) is true) + { + runOptions.MaxInputTokenCount = maxInputTokenCount; + } + + if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.TruncationStrategy), out RunTruncationStrategy? truncationStrategy) is true) + { + runOptions.TruncationStrategy = truncationStrategy; + } + + // Store all the tools to use. + if (options.Tools is { Count: > 0 } tools) + { + foreach (AITool tool in tools) + { + if (tool is AIFunction aiFunction) + { + bool? strict = + aiFunction.Metadata.AdditionalProperties.TryGetValue("Strict", out object? strictObj) && + strictObj is bool strictValue ? + strictValue : null; + + var functionParameters = BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes( + JsonSerializer.Deserialize(aiFunction.Metadata.Schema, OpenAIJsonContext.Default.OpenAIChatToolJson)!, + OpenAIJsonContext.Default.OpenAIChatToolJson)); + + runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Metadata.Name, aiFunction.Metadata.Description, functionParameters, strict)); + } + } + } + + // Store the tool mode. + switch (options.ToolMode) + { + case AutoChatToolMode: + runOptions.ToolConstraint = ToolConstraint.Auto; + break; + + case RequiredChatToolMode required: + runOptions.ToolConstraint = required.RequiredFunctionName is not null ? 
+ new ToolConstraint(ToolDefinition.CreateFunction(required.RequiredFunctionName)) : + ToolConstraint.Required; + break; + } + + // Store the response format. + if (options.ResponseFormat is ChatResponseFormatText) + { + runOptions.ResponseFormat = AssistantResponseFormat.Text; + } + else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + runOptions.ResponseFormat = jsonFormat.Schema is { } jsonSchema ? + AssistantResponseFormat.CreateJsonSchemaFormat( + jsonFormat.SchemaName ?? "json_schema", + BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(jsonSchema, OpenAIJsonContext.Default.JsonElement)), + jsonFormat.SchemaDescription) : + AssistantResponseFormat.JsonObject; + } + } + + // Handle ChatMessages. System messages are turned into additional instructions. + StringBuilder? instructions = null; + List? functionResults = null; + foreach (var chatMessage in chatMessages) + { + List messageContents = []; + + if (chatMessage.Role == ChatRole.System) + { + instructions ??= new(); + foreach (var textContent in chatMessage.Contents.OfType()) + { + _ = instructions.Append(textContent); + } + + continue; + } + + foreach (AIContent content in chatMessage.Contents) + { + switch (content) + { + case TextContent tc: + messageContents.Add(MessageContent.FromText(tc.Text)); + break; + + case DataContent dc when dc.MediaTypeStartsWith("image/"): + messageContents.Add(MessageContent.FromImageUri(new(dc.Uri))); + break; + + case FunctionResultContent frc: + (functionResults ??= []).Add(frc); + break; + } + } + + if (messageContents.Count > 0) + { + runOptions.AdditionalMessages.Add(new( + chatMessage.Role == ChatRole.Assistant ? MessageRole.Assistant : MessageRole.User, + messageContents)); + } + } + + if (instructions is not null) + { + runOptions.AdditionalInstructions = instructions.ToString(); + } + + return (runOptions, functionResults); + } + + private static string? GetRunId(List? toolResults, out List? toolOutputs) + { + string? 
runId = null; + toolOutputs = null; + if (toolResults?.Count > 0) + { + foreach (var frc in toolResults) + { + // When creating the FunctionCallContext, we created it with a CallId == [runId, callId]. + // We need to extract the run ID and ensure that the ToolOutput we send back to OpenAI + // is only the call ID. + string[]? runAndCallIDs; + try + { + runAndCallIDs = JsonSerializer.Deserialize(frc.CallId, OpenAIJsonContext.Default.StringArray); + } + catch + { + continue; + } + + if (runAndCallIDs is null || + runAndCallIDs.Length != 2 || + string.IsNullOrWhiteSpace(runAndCallIDs[0]) || // run ID + string.IsNullOrWhiteSpace(runAndCallIDs[1]) || // call ID + (runId is not null && runId != runAndCallIDs[0])) + { + continue; + } + + runId = runAndCallIDs[0]; + (toolOutputs ??= []).Add(new(runAndCallIDs[1], frc.Result?.ToString() ?? string.Empty)); + } + } + + return runId; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index 61f4fedcb46..6e3aa019c77 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -21,8 +21,8 @@ namespace Microsoft.Extensions.AI; /// Represents an for an OpenAI or . public sealed class OpenAIChatClient : IChatClient { - /// Default OpenAI endpoint. - private static readonly Uri _defaultOpenAIEndpoint = new("https://api.openai.com/v1"); + /// Gets the default OpenAI endpoint. + internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1"); /// Metadata about the client. private readonly ChatClientMetadata _metadata; @@ -52,7 +52,7 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId) // implement the abstractions directly rather than providing adapters on top of the public APIs, // the package can provide such implementations separate from what's exposed in the public API. 
Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) - ?.GetValue(openAIClient) as Uri ?? _defaultOpenAIEndpoint; + ?.GetValue(openAIClient) as Uri ?? DefaultOpenAIEndpoint; _metadata = new("openai", providerUrl, modelId); } @@ -70,7 +70,7 @@ public OpenAIChatClient(ChatClient chatClient) // implement the abstractions directly rather than providing adapters on top of the public APIs, // the package can provide such implementations separate from what's exposed in the public API. Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) - ?.GetValue(chatClient) as Uri ?? _defaultOpenAIEndpoint; + ?.GetValue(chatClient) as Uri ?? DefaultOpenAIEndpoint; string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(chatClient) as string; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs index 2bea9264730..49f78518015 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using OpenAI; +using OpenAI.Assistants; using OpenAI.Chat; using OpenAI.Embeddings; @@ -23,6 +24,19 @@ public static IChatClient AsChatClient(this OpenAIClient openAIClient, string mo public static IChatClient AsChatClient(this ChatClient chatClient) => new OpenAIChatClient(chatClient); +#pragma warning disable OPENAI001 // Type is for evaluation purposes only + /// Gets an for use with this . + /// The client. + /// The ID of the assistant to use. + /// + /// The ID of the thread to use. If not supplied here, it should be supplied per request in . 
+ /// If none is supplied, a new thread will be created for a request. + /// + /// An that can be used to converse via the . + public static IChatClient AsChatClient(this AssistantClient assistantClient, string assistantId, string? threadId = null) => + new OpenAIAssistantClient(assistantClient, assistantId, threadId); +#pragma warning restore OPENAI001 + /// Gets an for use with this . /// The client. /// The model to use. diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs index 69f610b4818..c75b2a4c644 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs @@ -15,4 +15,5 @@ namespace Microsoft.Extensions.AI; [JsonSerializable(typeof(OpenAIRealtimeExtensions.ConversationFunctionToolParametersSchema))] [JsonSerializable(typeof(OpenAIModelMappers.OpenAIChatToolJson))] [JsonSerializable(typeof(IDictionary))] +[JsonSerializable(typeof(string[]))] internal sealed partial class OpenAIJsonContext : JsonSerializerContext; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index b5671232f8d..2612980b34f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -23,7 +23,7 @@ namespace Microsoft.Extensions.AI; internal static partial class OpenAIModelMappers { - private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement; + internal static JsonElement DefaultParameterSchema { get; } = JsonDocument.Parse("{}").RootElement; public static OpenAI.Chat.ChatCompletion ToOpenAIChatCompletion(ChatCompletion chatCompletion, JsonSerializerOptions options) { @@ -382,7 +382,7 @@ private static AITool 
FromOpenAIChatTool(ChatTool chatTool) ReturnParameter = new() { Description = "Return parameter", - Schema = _defaultParameterSchema, + Schema = DefaultParameterSchema, } }; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 59716092b7a..6e3923ed1ad 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -140,7 +140,8 @@ public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = nul public bool ConcurrentInvocation { get; set; } /// - /// Gets or sets a value indicating whether to keep intermediate messages in the chat history. + /// Gets or sets a value indicating whether to keep intermediate function calling request + /// and response messages in the chat history. /// /// /// if intermediate messages persist in the list provided @@ -155,14 +156,20 @@ public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = nul /// those messages to the list of messages, along with instances /// it creates with the results of invoking the requested functions. The resulting augmented /// list of messages is then passed to the inner client in order to send the results back. - /// By default, those messages persist in the list provided to - /// and by the caller. Set - /// to to remove those messages prior to completing the operation. + /// By default, those messages persist in the list provided to + /// and by the caller, such that those + /// messages are available to the caller. Set to avoid including + /// those messages in the caller-provided . /// /// /// Changing the value of this property while the client is in use might result in inconsistencies /// as to whether function calling messages are kept during an in-flight request. 
/// + /// + /// If the underlying responds with + /// set to a non- value, this property may be ignored and behave as if it is + /// , with any such intermediate messages not stored in the messages list. + /// /// public bool KeepFunctionCallingMessages { get; set; } = true; @@ -211,10 +218,8 @@ public override async Task CompleteAsync(IList chat using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); ChatCompletion? response = null; - HashSet? messagesToRemove = null; - HashSet? contentsToRemove = null; UsageDetails? totalUsage = null; - + IList originalChatMessages = chatMessages; try { for (int iteration = 0; ; iteration++) @@ -256,54 +261,52 @@ public override async Task CompleteAsync(IList chat break; } - // Track all added messages in order to remove them, if requested. - if (!KeepFunctionCallingMessages) - { - messagesToRemove ??= []; - } - - // Add the original response message into the history and track the message for removal. - chatMessages.Add(response.Message); - if (messagesToRemove is not null) + // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending + // what we already sent as well as this response message, so create a new list to store the response message(s). + if (response.ChatThreadId is not null) { - if (functionCallContents.Length == response.Message.Contents.Count) + if (chatMessages == originalChatMessages) { - // The most common case is that the response message contains only function calling content. - // In that case, we can just track the whole message for removal. - _ = messagesToRemove.Add(response.Message); + chatMessages = []; } else { - // In the less likely case where some content is function calling and some isn't, we don't want to remove - // the non-function calling content by removing the whole message. So we track the content directly. 
- (contentsToRemove ??= []).UnionWith(functionCallContents); + chatMessages.Clear(); } } - - // Add the responses from the function calls into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (modeAndMessages.MessagesAdded is not null) + else { - messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + // Otherwise, we need to add the response message to the history we're sending back. However, if the caller + // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. + if (!KeepFunctionCallingMessages) + { + // Create a new list that will include the message with the function call contents. + if (chatMessages == originalChatMessages) + { + chatMessages = [.. chatMessages]; + } + + // We want to include any non-functional calling content, if there is any, + // in the caller's list so that they don't lose out on actual content. + // This can happen but is relatively rare. + if (response.Message.Contents.Any(c => c is not FunctionCallContent)) + { + var clone = response.Message.Clone(); + clone.Contents = clone.Contents.Where(c => c is not FunctionCallContent).ToList(); + originalChatMessages.Add(clone); + } + } + + // Add the original response message into the history. + chatMessages.Add(response.Message); } - switch (modeAndMessages.Mode) + // Add the responses from the function calls into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) { - case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: - // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. 
- options = options.Clone(); - options.ToolMode = null; - break; - - case ContinueMode.AllowOneMoreRoundtrip: - // The LLM gets one further chance to answer, but cannot use tools. - options = options.Clone(); - options.Tools = null; - break; - - case ContinueMode.Terminate: - // Bail immediately. - return response; + // Terminate + return response; } } @@ -311,8 +314,6 @@ public override async Task CompleteAsync(IList chat } finally { - RemoveMessagesAndContentFromList(messagesToRemove, contentsToRemove, chatMessages); - if (response is not null) { response.Usage = totalUsage; @@ -330,102 +331,94 @@ public override async IAsyncEnumerable CompleteSt // Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - HashSet? messagesToRemove = null; List functionCallContents = []; int? choice; - try + IList originalChatMessages = chatMessages; + for (int iteration = 0; ; iteration++) { - for (int iteration = 0; ; iteration++) + choice = null; + string? chatThreadId = null; + functionCallContents.Clear(); + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) { - choice = null; - functionCallContents.Clear(); - await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + // We're going to emit all StreamingChatMessage items upstream, even ones that represent + // function calls, because a given StreamingChatMessage can contain other content, too. + // And if we yield the function calls, and the consumer adds all the content into a message + // that's then added into history, they'll end up with function call contents that aren't + // directly paired with function result contents, which may cause issues for some models + // when the history is later sent again. + + // Find all the FCCs. 
We need to track these separately in order to be able to process them later. + int preFccCount = functionCallContents.Count; + functionCallContents.AddRange(update.Contents.OfType()); + + // If there were any, remove them from the update. We do this before yielding the update so + // that we're not modifying an instance already provided back to the caller. + int addedFccs = functionCallContents.Count - preFccCount; + if (addedFccs > 0) { - // We're going to emit all StreamingChatMessage items upstream, even ones that represent - // function calls, because a given StreamingChatMessage can contain other content, too. - // And if we yield the function calls, and the consumer adds all the content into a message - // that's then added into history, they'll end up with function call contents that aren't - // directly paired with function result contents, which may cause issues for some models - // when the history is later sent again. - - // Find all the FCCs. We need to track these separately in order to be able to process them later. - int preFccCount = functionCallContents.Count; - functionCallContents.AddRange(update.Contents.OfType()); - - // If there were any, remove them from the update. We do this before yielding the update so - // that we're not modifying an instance already provided back to the caller. - int addedFccs = functionCallContents.Count - preFccCount; - if (addedFccs > 0) - { - update.Contents = addedFccs == update.Contents.Count ? - [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); - } - - // Only one choice is allowed with automatic function calling. - if (choice is null) - { - choice = update.ChoiceIndex; - } - else if (choice != update.ChoiceIndex) - { - ThrowForMultipleChoices(); - } - - yield return update; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + update.Contents = addedFccs == update.Contents.Count ? 
+ [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); } - // If there are no tools to call, or for any other reason we should stop, return the response. - if (options is null - || options.Tools is not { Count: > 0 } - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) - || functionCallContents is not { Count: > 0 }) + // Only one choice is allowed with automatic function calling. + if (choice is null) { - break; + choice = update.ChoiceIndex; } - - // Track all added messages in order to remove them, if requested. - if (!KeepFunctionCallingMessages) + else if (choice != update.ChoiceIndex) { - messagesToRemove ??= []; + ThrowForMultipleChoices(); } - // Add a manufactured response message containing the function call contents to the chat history. - ChatMessage functionCallMessage = new(ChatRole.Assistant, [.. functionCallContents]); - chatMessages.Add(functionCallMessage); - _ = messagesToRemove?.Add(functionCallMessage); + chatThreadId ??= update.ChatThreadId; - // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (modeAndMessages.MessagesAdded is not null) + yield return update; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (options is null + || options.Tools is not { Count: > 0 } + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) + || functionCallContents is not { Count: > 0 }) + { + break; + } + + // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending + // what we already sent as well as this response message, so create a new list to store the response message(s). 
+ if (chatThreadId is not null) + { + if (chatMessages == originalChatMessages) { - messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + chatMessages = []; } - - // Decide how to proceed based on the result of the function calls. - switch (modeAndMessages.Mode) + else { - case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: - // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. - options = options.Clone(); - options.ToolMode = null; - break; - - case ContinueMode.AllowOneMoreRoundtrip: - // The LLM gets one further chance to answer, but cannot use tools. - options = options.Clone(); - options.Tools = null; - break; - - case ContinueMode.Terminate: - // Bail immediately. - yield break; + chatMessages.Clear(); } } - } - finally - { - RemoveMessagesAndContentFromList(messagesToRemove, contentToRemove: null, chatMessages); + else + { + // Otherwise, we need to add the response message to the history we're sending back. However, if the caller + // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. + if (chatMessages == originalChatMessages && !KeepFunctionCallingMessages) + { + chatMessages = [.. chatMessages]; + } + + // Add a manufactured response message containing the function call contents to the chat history. + chatMessages.Add(new(ChatRole.Assistant, [.. functionCallContents])); + } + + // Process all of the functions, adding their results into the history. 
+ var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId)) + { + // Terminate + yield break; + } } } @@ -439,42 +432,53 @@ private static void ThrowForMultipleChoices() throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received."); } - /// - /// Removes all of the messages in from - /// and all of the content in from the messages in . - /// - private static void RemoveMessagesAndContentFromList( - HashSet? messagesToRemove, - HashSet? contentToRemove, - IList messages) + /// Updates for the response. + /// true if the function calling loop should terminate; otherwise, false. + private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId) { - Debug.Assert( - contentToRemove is null || messagesToRemove is not null, - "We should only be tracking content to remove if we're also tracking messages to remove."); - - if (messagesToRemove is not null) + switch (mode) { - for (int m = messages.Count - 1; m >= 0; m--) - { - ChatMessage message = messages[m]; + case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: + // We have to reset the tool mode to be non-required after the first iteration, + // as otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = null; + if (chatThreadId is not null) + { + options.ChatThreadId = chatThreadId; + } + + break; - if (contentToRemove is not null) + case ContinueMode.AllowOneMoreRoundtrip: + // The LLM gets one further chance to answer, but cannot use tools. 
+ options = options.Clone(); + options.Tools = null; + options.ToolMode = null; + if (chatThreadId is not null) { - for (int c = message.Contents.Count - 1; c >= 0; c--) - { - if (contentToRemove.Contains(message.Contents[c])) - { - message.Contents.RemoveAt(c); - } - } + options.ChatThreadId = chatThreadId; } - if (messages.Count == 0 || messagesToRemove.Contains(messages[m])) + break; + + case ContinueMode.Terminate: + // Bail immediately. + return true; + + default: + // As with the other modes, ensure we've propagated the chat thread ID to the options. + // We only need to clone the options if we're actually mutating it. + if (chatThreadId is not null && options.ChatThreadId != chatThreadId) { - messages.RemoveAt(m); + options = options.Clone(); + options.ChatThreadId = chatThreadId; } - } + + break; } + + return false; } /// @@ -630,7 +634,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul { string message = result.Status switch { - FunctionStatus.NotFound => "Error: Requested function not found.", + FunctionStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", FunctionStatus.Failed => "Error: Function failed.", _ => "Error: Unknown error.", }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs index 4e3ceadd793..498be7ecb1e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs @@ -13,6 +13,7 @@ public class ChatOptionsTests public void Constructor_Parameterless_PropsDefaulted() { ChatOptions options = new(); + Assert.Null(options.ChatThreadId); Assert.Null(options.Temperature); Assert.Null(options.MaxOutputTokens); Assert.Null(options.TopP); @@ -28,6 +29,7 @@ public void 
Constructor_Parameterless_PropsDefaulted() Assert.Null(options.AdditionalProperties); ChatOptions clone = options.Clone(); + Assert.Null(options.ChatThreadId); Assert.Null(clone.Temperature); Assert.Null(clone.MaxOutputTokens); Assert.Null(clone.TopP); @@ -65,6 +67,7 @@ public void Properties_Roundtrip() ["key"] = "value", }; + options.ChatThreadId = "12345"; options.Temperature = 0.1f; options.MaxOutputTokens = 2; options.TopP = 0.3f; @@ -79,6 +82,7 @@ public void Properties_Roundtrip() options.Tools = tools; options.AdditionalProperties = additionalProps; + Assert.Equal("12345", options.ChatThreadId); Assert.Equal(0.1f, options.Temperature); Assert.Equal(2, options.MaxOutputTokens); Assert.Equal(0.3f, options.TopP); @@ -94,6 +98,7 @@ public void Properties_Roundtrip() Assert.Same(additionalProps, options.AdditionalProperties); ChatOptions clone = options.Clone(); + Assert.Equal("12345", options.ChatThreadId); Assert.Equal(0.1f, clone.Temperature); Assert.Equal(2, clone.MaxOutputTokens); Assert.Equal(0.3f, clone.TopP); @@ -125,6 +130,7 @@ public void JsonSerialization_Roundtrips() ["key"] = "value", }; + options.ChatThreadId = "12345"; options.Temperature = 0.1f; options.MaxOutputTokens = 2; options.TopP = 0.3f; @@ -148,6 +154,7 @@ public void JsonSerialization_Roundtrips() ChatOptions? 
deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatOptions); Assert.NotNull(deserialized); + Assert.Equal("12345", deserialized.ChatThreadId); Assert.Equal(0.1f, deserialized.Temperature); Assert.Equal(2, deserialized.MaxOutputTokens); Assert.Equal(0.3f, deserialized.TopP); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs index 9af25dbd16a..5b5294d24f4 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs @@ -61,10 +61,13 @@ public async Task ToChatCompletion_SuccessfullyCreatesCompletion(bool useAsync, }; Assert.NotNull(completion); + Assert.NotNull(completion.Usage); + Assert.Equal(5, completion.Usage.InputTokenCount); + Assert.Equal(7, completion.Usage.OutputTokenCount); + Assert.Equal("12345", completion.CompletionId); Assert.Equal(new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), completion.CreatedAt); Assert.Equal("model123", completion.ModelId); - Assert.Same(Assert.IsType(updates[6].Contents[0]).Details, completion.Usage); Assert.Equal(3, completion.Choices.Count); @@ -89,7 +92,7 @@ public async Task ToChatCompletion_SuccessfullyCreatesCompletion(bool useAsync, Assert.Equal(ChatRole.Assistant, message.Role); Assert.Null(message.AuthorName); Assert.Null(message.AdditionalProperties); - Assert.Same(updates[7].Contents[0], Assert.Single(message.Contents)); + Assert.Empty(message.Contents); if (coalesceContent is null or true) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs 
b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 540ca7c8431..2690eb7181c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -19,6 +19,9 @@ namespace Microsoft.Extensions.AI; public class FunctionInvokingChatClientTests { + private readonly Func _keepMessagesConfigure = + b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = true }); + [Fact] public void InvalidArgs_Throws() { @@ -64,9 +67,9 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() new ChatMessage(ChatRole.Assistant, "world"), ]; - await InvokeAndAssertAsync(options, plan); + await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); - await InvokeAndAssertStreamingAsync(options, plan); + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure); } [Theory] @@ -111,7 +114,8 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn new ChatMessage(ChatRole.Assistant, "world"), ]; - Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation }); + Func configure = b => b.Use( + s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation, KeepFunctionCallingMessages = true }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -151,7 +155,8 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() new ChatMessage(ChatRole.Assistant, "done"), ]; - Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true }); + Func configure = b => b.Use( + s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true, KeepFunctionCallingMessages = true }); await InvokeAndAssertAsync(options, plan, 
configurePipeline: configure); @@ -194,9 +199,9 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() new ChatMessage(ChatRole.Assistant, "done"), ]; - await InvokeAndAssertAsync(options, plan); + await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); - await InvokeAndAssertStreamingAsync(options, plan); + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure); } [Theory] @@ -232,7 +237,8 @@ public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunc new ChatMessage(ChatRole.Assistant, "world") ]; - Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + Func configure = b => b.Use( + client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); Validate(await InvokeAndAssertAsync(options, plan, expected, configure)); Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure)); @@ -254,7 +260,7 @@ void Validate(List finalChat) [Theory] [InlineData(false)] [InlineData(true)] - public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) + public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) { var options = new ChatOptions { @@ -278,7 +284,8 @@ public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunct new ChatMessage(ChatRole.Assistant, "world"), ]; - Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + Func configure = b => b.Use( + client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); #pragma warning disable SA1005, S125 Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? 
null : @@ -340,7 +347,8 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr new ChatMessage(ChatRole.Assistant, "world"), ]; - Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors }); + Func configure = b => b.Use( + s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors, KeepFunctionCallingMessages = true }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -405,7 +413,10 @@ public async Task FunctionInvocationsLogged(LogLevel level) }; Func configure = b => - b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())); + b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>()) + { + KeepFunctionCallingMessages = true, + }); await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services)); @@ -461,8 +472,10 @@ public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry) }; Func configure = b => b.Use(c => - new FunctionInvokingChatClient( - new OpenTelemetryChatClient(c, sourceName: sourceName))); + new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName)) + { + KeepFunctionCallingMessages = true, + }); await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure)); @@ -529,7 +542,7 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() } }; - using var client = new FunctionInvokingChatClient(innerClient); + using var client = new FunctionInvokingChatClient(innerClient) { KeepFunctionCallingMessages = true }; var updates = new List(); await foreach (var update in client.CompleteStreamingAsync(messages, options, CancellationToken.None)) @@ -603,7 +616,7 @@ await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [ // The last message is the one returned by the chat client // This message's content should contain the last 
function call before the termination new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary { ["i"] = 42 })]), - ])); + ], configurePipeline: _keepMessagesConfigure)); await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ .. planBeforeTermination, @@ -611,7 +624,7 @@ await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ // The last message is the one returned by the chat client // When streaming, function call content is removed from this message new ChatMessage(ChatRole.Assistant, []), - ])); + ], configurePipeline: _keepMessagesConfigure)); // The current context should be null outside the async call stack for the function invocation Assert.Null(FunctionInvokingChatClient.CurrentContext); @@ -640,6 +653,56 @@ void AssertInvocationContext(FunctionInvokingChatClient.FunctionInvocationContex } } + [Fact] + public async Task PropagatesCompletionChatThreadIdToOptions() + { + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")], + }; + + int iteration = 0; + + Func, ChatOptions?, CancellationToken, ChatCompletion> callback = + (chatContents, chatOptions, cancellationToken) => + { + iteration++; + + if (iteration == 1) + { + Assert.Null(chatOptions?.ChatThreadId); + return new ChatCompletion(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId-abc", "Func1")])) + { + ChatThreadId = "12345", + }; + } + else if (iteration == 2) + { + Assert.Equal("12345", chatOptions?.ChatThreadId); + return new ChatCompletion(new ChatMessage(ChatRole.Assistant, "done!")); + } + else + { + throw new InvalidOperationException("Unexpected iteration"); + } + }; + + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = (chatContents, chatOptions, cancellationToken) => + Task.FromResult(callback(chatContents, chatOptions, cancellationToken)), + CompleteStreamingAsyncCallback = (chatContents, chatOptions, 
cancellationToken) => + YieldAsync(callback(chatContents, chatOptions, cancellationToken).ToStreamingChatCompletionUpdates()), + }; + + using IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build(); + + iteration = 0; + Assert.Equal("done!", (await service.CompleteAsync("hey", options)).ToString()); + iteration = 0; + Assert.Equal("done!", (await service.CompleteStreamingAsync("hey", options).ToChatCompletionAsync()).ToString()); + } + private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan, @@ -659,7 +722,6 @@ private static async Task> InvokeAndAssertAsync( { CompleteAsyncCallback = async (contents, actualOptions, actualCancellationToken) => { - Assert.Same(chat, contents); Assert.Equal(cts.Token, actualCancellationToken); await Task.Yield(); @@ -753,7 +815,6 @@ private static async Task> InvokeAndAssertStreamingAsync( { CompleteStreamingAsyncCallback = (contents, actualOptions, actualCancellationToken) => { - Assert.Same(chat, contents); Assert.Equal(cts.Token, actualCancellationToken); return YieldAsync(new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToStreamingChatCompletionUpdates());