diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs
index 689bca9c4da..296c089ce1f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs
@@ -61,6 +61,16 @@ public ChatMessage Message
/// Gets or sets the ID of the chat completion.
public string? CompletionId { get; set; }
+    /// <summary>Gets or sets the chat thread ID associated with this chat completion.</summary>
+    /// <remarks>
+    /// Some <see cref="IChatClient"/> implementations are capable of storing the state for a chat thread, such that
+    /// the input messages supplied to <see cref="IChatClient.CompleteAsync"/> need only be the additional messages beyond
+    /// what's already stored. If this property is non-<see langword="null"/>, it represents an identifier for that state,
+    /// and it should be used in a subsequent <see cref="ChatOptions.ChatThreadId"/> instead of supplying the same messages
+    /// (and this <see cref="ChatCompletion"/>'s message) as part of the chatMessages parameter.
+    /// </remarks>
+ public string? ChatThreadId { get; set; }
+
/// Gets or sets the model ID used in the creation of the chat completion.
public string? ModelId { get; set; }
@@ -133,6 +143,7 @@ public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates()
ChatMessage choice = Choices[choiceIndex];
updates[choiceIndex] = new StreamingChatCompletionUpdate
{
+ ChatThreadId = ChatThreadId,
ChoiceIndex = choiceIndex,
AdditionalProperties = choice.AdditionalProperties,
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs
index d52cc36cdbb..e7faeba0ee1 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs
@@ -39,6 +39,22 @@ public ChatMessage(
_contents = Throw.IfNull(contents);
}
+    /// <summary>Clones the <see cref="ChatMessage"/> to a new <see cref="ChatMessage"/> instance.</summary>
+    /// <returns>A shallow clone of the original message object.</returns>
+    /// <remarks>
+    /// This is a shallow clone. The returned instance is different from the original, but all properties
+    /// refer to the same objects as the original.
+    /// </remarks>
+ public ChatMessage Clone() =>
+ new()
+ {
+ AdditionalProperties = AdditionalProperties,
+ _authorName = _authorName,
+ _contents = _contents,
+ RawRepresentation = RawRepresentation,
+ Role = Role,
+ };
+
/// Gets or sets the name of the author of the message.
public string? AuthorName
{
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
index 69adc8392fd..12ce2f56860 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
@@ -9,6 +9,9 @@ namespace Microsoft.Extensions.AI;
/// Represents the options for a chat request.
public class ChatOptions
{
+    /// <summary>Gets or sets an optional identifier used to associate a request with an existing chat thread.</summary>
+ public string? ChatThreadId { get; set; }
+
/// Gets or sets the temperature for generating chat responses.
public float? Temperature { get; set; }
@@ -72,6 +75,7 @@ public virtual ChatOptions Clone()
{
ChatOptions options = new()
{
+ ChatThreadId = ChatThreadId,
Temperature = Temperature,
MaxOutputTokens = MaxOutputTokens,
TopP = TopP,
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs
index 36ae500e138..e50fd42169b 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs
@@ -102,6 +102,16 @@ public IList Contents
/// Gets or sets the ID of the completion of which this update is a part.
public string? CompletionId { get; set; }
+    /// <summary>Gets or sets the chat thread ID associated with the chat completion of which this update is a part.</summary>
+    /// <remarks>
+    /// Some <see cref="IChatClient"/> implementations are capable of storing the state for a chat thread, such that
+    /// the input messages supplied to <see cref="IChatClient.CompleteStreamingAsync"/> need only be the additional messages beyond
+    /// what's already stored. If this property is non-<see langword="null"/>, it represents an identifier for that state,
+    /// and it should be used in a subsequent <see cref="ChatOptions.ChatThreadId"/> instead of supplying the same messages
+    /// (and this streaming message) as part of the chatMessages parameter.
+    /// </remarks>
+ public string? ChatThreadId { get; set; }
+
/// Gets or sets a timestamp for the completion update.
public DateTimeOffset? CreatedAt { get; set; }
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs
index b70d7471b80..9694b0e4dc0 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs
@@ -13,6 +13,7 @@
#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable S127 // "for" loop stop conditions should be invariant
+#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
namespace Microsoft.Extensions.AI;
@@ -103,7 +104,21 @@ private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictiona
}
#endif
- ((List)message.Contents).AddRange(update.Contents);
+ // Incorporate all content from the update into the completion.
+ foreach (var content in update.Contents)
+ {
+ switch (content)
+ {
+ // Usage content is treated specially and propagated to the completion's Usage.
+ case UsageContent usage:
+ (completion.Usage ??= new()).Add(usage.Details);
+ break;
+
+ default:
+ message.Contents.Add(content);
+ break;
+ }
+ }
message.AuthorName ??= update.AuthorName;
if (update.Role is ChatRole role && message.Role == default)
@@ -178,20 +193,6 @@ static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValue
}
completion.Choices.Add(entry.Value);
-
- if (completion.Usage is null)
- {
- foreach (var content in entry.Value.Contents)
- {
- if (content is UsageContent c)
- {
- completion.Usage = c.Details;
- entry.Value.Contents = entry.Value.Contents.ToList();
- _ = entry.Value.Contents.Remove(c);
- break;
- }
- }
- }
}
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs
new file mode 100644
index 00000000000..b0f36e43313
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs
@@ -0,0 +1,351 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+using OpenAI;
+using OpenAI.Assistants;
+using OpenAI.Chat;
+
+#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
+#pragma warning disable CA1031 // Do not catch general exception types
+#pragma warning disable S1067 // Expressions should not be too complex
+#pragma warning disable S1751 // Loops with at most one iteration should be refactored
+#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
+#pragma warning disable SA1204 // Static elements should appear before instance elements
+#pragma warning disable SA1108 // Block statements should not contain embedded comments
+
+namespace Microsoft.Extensions.AI;
+
+/// <summary>Represents an <see cref="IChatClient"/> for an OpenAI <see cref="OpenAIClient"/> or <see cref="AssistantClient"/>.</summary>
+internal sealed class OpenAIAssistantClient : IChatClient
+{
+ /// Metadata for the client.
+ private readonly ChatClientMetadata _metadata;
+
+ /// The underlying .
+ private readonly AssistantClient _assistantClient;
+
+ /// The ID of the assistant to use.
+ private readonly string _assistantId;
+
+ /// The thread ID to use if none is supplied in .
+ private readonly string? _threadId;
+
+    /// <summary>Initializes a new instance of the <see cref="OpenAIAssistantClient"/> class for the specified <see cref="AssistantClient"/>.</summary>
+    /// <param name="assistantClient">The underlying client.</param>
+    /// <param name="assistantId">The ID of the assistant to use.</param>
+    /// <param name="threadId">
+    /// The ID of the thread to use. If not supplied here, it should be supplied per request in <see cref="ChatOptions.ChatThreadId"/>.
+    /// If none is supplied, a new thread will be created for a request.
+    /// </param>
+ public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId, string? threadId)
+ {
+ _assistantClient = Throw.IfNull(assistantClient);
+ _assistantId = Throw.IfNull(assistantId);
+ _threadId = threadId;
+
+ _metadata = new("openai");
+ }
+
+ ///
+ public object? GetService(Type serviceType, object? serviceKey = null)
+ {
+ _ = Throw.IfNull(serviceType);
+
+ return
+ serviceKey is not null ? null :
+ serviceType == typeof(ChatClientMetadata) ? _metadata :
+ serviceType == typeof(AssistantClient) ? _assistantClient :
+ serviceType.IsInstanceOfType(this) ? this :
+ null;
+ }
+
+ ///
+ public Task CompleteAsync(
+ IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
+ CompleteStreamingAsync(chatMessages, options, cancellationToken).ToChatCompletionAsync(coalesceContent: true, cancellationToken);
+
+ ///
+ public async IAsyncEnumerable CompleteStreamingAsync(
+ IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ // Extract necessary state from chatMessages and options.
+ (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(chatMessages, options);
+
+ // Get the thread ID.
+ string? threadId = options?.ChatThreadId ?? _threadId;
+ if (threadId is null && toolResults is not null)
+ {
+ Throw.ArgumentException(nameof(chatMessages), "No thread ID was provided, but chat messages includes tool results.");
+ }
+
+ // Get the updates to process from the assistant. If we have any tool results, this means submitting those and ignoring
+ // our runOptions. Otherwise, create a run, and a thread if we don't have one.
+ IAsyncEnumerable updates;
+ if (GetRunId(toolResults, out List? toolOutputs) is string existingRunId)
+ {
+ updates = _assistantClient.SubmitToolOutputsToRunStreamingAsync(threadId, existingRunId, toolOutputs, cancellationToken);
+ }
+ else if (threadId is null)
+ {
+ ThreadCreationOptions creationOptions = new();
+ foreach (var message in runOptions.AdditionalMessages)
+ {
+ creationOptions.InitialMessages.Add(message);
+ }
+
+ runOptions.AdditionalMessages.Clear();
+
+ updates = _assistantClient.CreateThreadAndRunStreamingAsync(_assistantId, creationOptions, runOptions, cancellationToken: cancellationToken);
+ }
+ else
+ {
+ updates = _assistantClient.CreateRunStreamingAsync(threadId, _assistantId, runOptions, cancellationToken);
+ }
+
+ // Process each update.
+ await foreach (var update in updates.ConfigureAwait(false))
+ {
+ switch (update)
+ {
+ case MessageContentUpdate mcu:
+ yield return new()
+ {
+ ChatThreadId = threadId,
+ RawRepresentation = mcu,
+ Role = mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant,
+ Text = mcu.Text,
+ };
+ break;
+
+ case ThreadUpdate tu when options is not null:
+ threadId ??= tu.Value.Id;
+ break;
+
+ case RunUpdate ru:
+ threadId ??= ru.Value.ThreadId;
+
+ StreamingChatCompletionUpdate ruUpdate = new()
+ {
+ AuthorName = ru.Value.AssistantId,
+ ChatThreadId = threadId,
+ CompletionId = ru.Value.Id,
+ CreatedAt = ru.Value.CreatedAt,
+ ModelId = ru.Value.Model,
+ RawRepresentation = ru,
+ Role = ChatRole.Assistant,
+ };
+
+ if (ru.Value.Usage is { } usage)
+ {
+ ruUpdate.Contents.Add(new UsageContent(new()
+ {
+ InputTokenCount = usage.InputTokenCount,
+ OutputTokenCount = usage.OutputTokenCount,
+ TotalTokenCount = usage.TotalTokenCount,
+ }));
+ }
+
+ if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName)
+ {
+ ruUpdate.Contents.Add(
+ new FunctionCallContent(
+ JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!),
+ functionName,
+ JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!));
+ }
+
+ yield return ruUpdate;
+ break;
+ }
+ }
+ }
+
+ ///
+ void IDisposable.Dispose()
+ {
+ // Nothing to dispose. Implementation required for the IChatClient interface.
+ }
+
+ /// Adds the provided messages to the thread and returns the options to use for the request.
+ private static (RunCreationOptions RunOptions, List? ToolResults) CreateRunOptions(IList chatMessages, ChatOptions? options)
+ {
+ _ = Throw.IfNull(chatMessages);
+
+ RunCreationOptions runOptions = new();
+
+ // Handle ChatOptions.
+ if (options is not null)
+ {
+ // Propagate the simple properties that have a 1:1 correspondence.
+ runOptions.MaxOutputTokenCount = options.MaxOutputTokens;
+ runOptions.ModelOverride = options.ModelId;
+ runOptions.NucleusSamplingFactor = options.TopP;
+ runOptions.Temperature = options.Temperature;
+
+ // Propagate additional properties from AdditionalProperties.
+ if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.AllowParallelToolCalls), out bool allowParallelToolCalls) is true)
+ {
+ runOptions.AllowParallelToolCalls = allowParallelToolCalls;
+ }
+
+ if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.MaxInputTokenCount), out int maxInputTokenCount) is true)
+ {
+ runOptions.MaxInputTokenCount = maxInputTokenCount;
+ }
+
+ if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.TruncationStrategy), out RunTruncationStrategy? truncationStrategy) is true)
+ {
+ runOptions.TruncationStrategy = truncationStrategy;
+ }
+
+ // Store all the tools to use.
+ if (options.Tools is { Count: > 0 } tools)
+ {
+ foreach (AITool tool in tools)
+ {
+ if (tool is AIFunction aiFunction)
+ {
+ bool? strict =
+ aiFunction.Metadata.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
+ strictObj is bool strictValue ?
+ strictValue : null;
+
+ var functionParameters = BinaryData.FromBytes(
+ JsonSerializer.SerializeToUtf8Bytes(
+ JsonSerializer.Deserialize(aiFunction.Metadata.Schema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
+ OpenAIJsonContext.Default.OpenAIChatToolJson));
+
+ runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Metadata.Name, aiFunction.Metadata.Description, functionParameters, strict));
+ }
+ }
+ }
+
+ // Store the tool mode.
+ switch (options.ToolMode)
+ {
+ case AutoChatToolMode:
+ runOptions.ToolConstraint = ToolConstraint.Auto;
+ break;
+
+ case RequiredChatToolMode required:
+ runOptions.ToolConstraint = required.RequiredFunctionName is null ?
+ new ToolConstraint(ToolDefinition.CreateFunction(required.RequiredFunctionName)) :
+ ToolConstraint.Required;
+ break;
+ }
+
+ // Store the response format.
+ if (options.ResponseFormat is ChatResponseFormatText)
+ {
+ runOptions.ResponseFormat = AssistantResponseFormat.Text;
+ }
+ else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat)
+ {
+ runOptions.ResponseFormat = jsonFormat.Schema is { } jsonSchema ?
+ AssistantResponseFormat.CreateJsonSchemaFormat(
+ jsonFormat.SchemaName ?? "json_schema",
+ BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(jsonSchema, OpenAIJsonContext.Default.JsonElement)),
+ jsonFormat.SchemaDescription) :
+ AssistantResponseFormat.JsonObject;
+ }
+ }
+
+ // Handle ChatMessages. System messages are turned into additional instructions.
+ StringBuilder? instructions = null;
+ List? functionResults = null;
+ foreach (var chatMessage in chatMessages)
+ {
+ List messageContents = [];
+
+ if (chatMessage.Role == ChatRole.System)
+ {
+ instructions ??= new();
+ foreach (var textContent in chatMessage.Contents.OfType())
+ {
+ _ = instructions.Append(textContent);
+ }
+
+ continue;
+ }
+
+ foreach (AIContent content in chatMessage.Contents)
+ {
+ switch (content)
+ {
+ case TextContent tc:
+ messageContents.Add(MessageContent.FromText(tc.Text));
+ break;
+
+ case DataContent dc when dc.MediaTypeStartsWith("image/"):
+ messageContents.Add(MessageContent.FromImageUri(new(dc.Uri)));
+ break;
+
+ case FunctionResultContent frc:
+ (functionResults ??= []).Add(frc);
+ break;
+ }
+ }
+
+ if (messageContents.Count > 0)
+ {
+ runOptions.AdditionalMessages.Add(new(
+ chatMessage.Role == ChatRole.Assistant ? MessageRole.Assistant : MessageRole.User,
+ messageContents));
+ }
+ }
+
+ if (instructions is not null)
+ {
+ runOptions.AdditionalInstructions = instructions.ToString();
+ }
+
+ return (runOptions, functionResults);
+ }
+
+ private static string? GetRunId(List? toolResults, out List? toolOutputs)
+ {
+ string? runId = null;
+ toolOutputs = null;
+ if (toolResults?.Count > 0)
+ {
+ foreach (var frc in toolResults)
+ {
+ // When creating the FunctionCallContext, we created it with a CallId == [runId, callId].
+ // We need to extract the run ID and ensure that the ToolOutput we send back to OpenAI
+ // is only the call ID.
+ string[]? runAndCallIDs;
+ try
+ {
+ runAndCallIDs = JsonSerializer.Deserialize(frc.CallId, OpenAIJsonContext.Default.StringArray);
+ }
+ catch
+ {
+ continue;
+ }
+
+ if (runAndCallIDs is null ||
+ runAndCallIDs.Length != 2 ||
+ string.IsNullOrWhiteSpace(runAndCallIDs[0]) || // run ID
+ string.IsNullOrWhiteSpace(runAndCallIDs[1]) || // call ID
+ (runId is not null && runId != runAndCallIDs[0]))
+ {
+ continue;
+ }
+
+ runId = runAndCallIDs[0];
+ (toolOutputs ??= []).Add(new(runAndCallIDs[1], frc.Result?.ToString() ?? string.Empty));
+ }
+ }
+
+ return runId;
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
index 61f4fedcb46..6e3aa019c77 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
@@ -21,8 +21,8 @@ namespace Microsoft.Extensions.AI;
/// <summary>Represents an <see cref="IChatClient"/> for an OpenAI <see cref="OpenAIClient"/> or <see cref="ChatClient"/>.</summary>
public sealed class OpenAIChatClient : IChatClient
{
- /// Default OpenAI endpoint.
- private static readonly Uri _defaultOpenAIEndpoint = new("https://api.openai.com/v1");
+ /// Gets the default OpenAI endpoint.
+ internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
/// Metadata about the client.
private readonly ChatClientMetadata _metadata;
@@ -52,7 +52,7 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId)
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
- ?.GetValue(openAIClient) as Uri ?? _defaultOpenAIEndpoint;
+ ?.GetValue(openAIClient) as Uri ?? DefaultOpenAIEndpoint;
_metadata = new("openai", providerUrl, modelId);
}
@@ -70,7 +70,7 @@ public OpenAIChatClient(ChatClient chatClient)
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
- ?.GetValue(chatClient) as Uri ?? _defaultOpenAIEndpoint;
+ ?.GetValue(chatClient) as Uri ?? DefaultOpenAIEndpoint;
string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatClient) as string;
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs
index 2bea9264730..49f78518015 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using OpenAI;
+using OpenAI.Assistants;
using OpenAI.Chat;
using OpenAI.Embeddings;
@@ -23,6 +24,19 @@ public static IChatClient AsChatClient(this OpenAIClient openAIClient, string mo
public static IChatClient AsChatClient(this ChatClient chatClient) =>
new OpenAIChatClient(chatClient);
+#pragma warning disable OPENAI001 // Type is for evaluation purposes only
+    /// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="AssistantClient"/>.</summary>
+    /// <param name="assistantClient">The client.</param>
+    /// <param name="assistantId">The ID of the assistant to use.</param>
+    /// <param name="threadId">
+    /// The ID of the thread to use. If not supplied here, it should be supplied per request in <see cref="ChatOptions.ChatThreadId"/>.
+    /// If none is supplied, a new thread will be created for a request.
+    /// </param>
+    /// <returns>An <see cref="IChatClient"/> that can be used to converse via the <see cref="AssistantClient"/>.</returns>
+ public static IChatClient AsChatClient(this AssistantClient assistantClient, string assistantId, string? threadId = null) =>
+ new OpenAIAssistantClient(assistantClient, assistantId, threadId);
+#pragma warning restore OPENAI001
+
/// Gets an for use with this .
/// The client.
/// The model to use.
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs
index 69f610b4818..c75b2a4c644 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs
@@ -15,4 +15,5 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(OpenAIRealtimeExtensions.ConversationFunctionToolParametersSchema))]
[JsonSerializable(typeof(OpenAIModelMappers.OpenAIChatToolJson))]
[JsonSerializable(typeof(IDictionary))]
+[JsonSerializable(typeof(string[]))]
internal sealed partial class OpenAIJsonContext : JsonSerializerContext;
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
index b5671232f8d..2612980b34f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
@@ -23,7 +23,7 @@ namespace Microsoft.Extensions.AI;
internal static partial class OpenAIModelMappers
{
- private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;
+ internal static JsonElement DefaultParameterSchema { get; } = JsonDocument.Parse("{}").RootElement;
public static OpenAI.Chat.ChatCompletion ToOpenAIChatCompletion(ChatCompletion chatCompletion, JsonSerializerOptions options)
{
@@ -382,7 +382,7 @@ private static AITool FromOpenAIChatTool(ChatTool chatTool)
ReturnParameter = new()
{
Description = "Return parameter",
- Schema = _defaultParameterSchema,
+ Schema = DefaultParameterSchema,
}
};
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index 59716092b7a..6e3923ed1ad 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -140,7 +140,8 @@ public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = nul
public bool ConcurrentInvocation { get; set; }
///
- /// Gets or sets a value indicating whether to keep intermediate messages in the chat history.
+ /// Gets or sets a value indicating whether to keep intermediate function calling request
+ /// and response messages in the chat history.
///
///
/// if intermediate messages persist in the list provided
@@ -155,14 +156,20 @@ public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = nul
/// those messages to the list of messages, along with instances
/// it creates with the results of invoking the requested functions. The resulting augmented
/// list of messages is then passed to the inner client in order to send the results back.
- /// By default, those messages persist in the list provided to
- /// and by the caller. Set
- /// to to remove those messages prior to completing the operation.
+ /// By default, those messages persist in the list provided to
+ /// and by the caller, such that those
+ /// messages are available to the caller. Set to avoid including
+ /// those messages in the caller-provided .
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
/// as to whether function calling messages are kept during an in-flight request.
///
+ ///
+ /// If the underlying responds with
+ /// set to a non- value, this property may be ignored and behave as if it is
+ /// , with any such intermediate messages not stored in the messages list.
+ ///
///
public bool KeepFunctionCallingMessages { get; set; } = true;
@@ -211,10 +218,8 @@ public override async Task CompleteAsync(IList chat
using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
ChatCompletion? response = null;
- HashSet? messagesToRemove = null;
- HashSet? contentsToRemove = null;
UsageDetails? totalUsage = null;
-
+ IList originalChatMessages = chatMessages;
try
{
for (int iteration = 0; ; iteration++)
@@ -256,54 +261,52 @@ public override async Task CompleteAsync(IList chat
break;
}
- // Track all added messages in order to remove them, if requested.
- if (!KeepFunctionCallingMessages)
- {
- messagesToRemove ??= [];
- }
-
- // Add the original response message into the history and track the message for removal.
- chatMessages.Add(response.Message);
- if (messagesToRemove is not null)
+ // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending
+ // what we already sent as well as this response message, so create a new list to store the response message(s).
+ if (response.ChatThreadId is not null)
{
- if (functionCallContents.Length == response.Message.Contents.Count)
+ if (chatMessages == originalChatMessages)
{
- // The most common case is that the response message contains only function calling content.
- // In that case, we can just track the whole message for removal.
- _ = messagesToRemove.Add(response.Message);
+ chatMessages = [];
}
else
{
- // In the less likely case where some content is function calling and some isn't, we don't want to remove
- // the non-function calling content by removing the whole message. So we track the content directly.
- (contentsToRemove ??= []).UnionWith(functionCallContents);
+ chatMessages.Clear();
}
}
-
- // Add the responses from the function calls into the history.
- var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
- if (modeAndMessages.MessagesAdded is not null)
+ else
{
- messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded);
+ // Otherwise, we need to add the response message to the history we're sending back. However, if the caller
+ // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original.
+ if (!KeepFunctionCallingMessages)
+ {
+ // Create a new list that will include the message with the function call contents.
+ if (chatMessages == originalChatMessages)
+ {
+ chatMessages = [.. chatMessages];
+ }
+
+ // We want to include any non-functional calling content, if there is any,
+ // in the caller's list so that they don't lose out on actual content.
+ // This can happen but is relatively rare.
+ if (response.Message.Contents.Any(c => c is not FunctionCallContent))
+ {
+ var clone = response.Message.Clone();
+ clone.Contents = clone.Contents.Where(c => c is not FunctionCallContent).ToList();
+ originalChatMessages.Add(clone);
+ }
+ }
+
+ // Add the original response message into the history.
+ chatMessages.Add(response.Message);
}
- switch (modeAndMessages.Mode)
+ // Add the responses from the function calls into the history.
+ var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
+ if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId))
{
- case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode:
- // We have to reset this after the first iteration, otherwise we'll be in an infinite loop.
- options = options.Clone();
- options.ToolMode = null;
- break;
-
- case ContinueMode.AllowOneMoreRoundtrip:
- // The LLM gets one further chance to answer, but cannot use tools.
- options = options.Clone();
- options.Tools = null;
- break;
-
- case ContinueMode.Terminate:
- // Bail immediately.
- return response;
+ // Terminate
+ return response;
}
}
@@ -311,8 +314,6 @@ public override async Task CompleteAsync(IList chat
}
finally
{
- RemoveMessagesAndContentFromList(messagesToRemove, contentsToRemove, chatMessages);
-
if (response is not null)
{
response.Usage = totalUsage;
@@ -330,102 +331,94 @@ public override async IAsyncEnumerable CompleteSt
// Create an activity to group them together for better observability.
using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
- HashSet? messagesToRemove = null;
List functionCallContents = [];
int? choice;
- try
+ IList originalChatMessages = chatMessages;
+ for (int iteration = 0; ; iteration++)
{
- for (int iteration = 0; ; iteration++)
+ choice = null;
+ string? chatThreadId = null;
+ functionCallContents.Clear();
+ await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
- choice = null;
- functionCallContents.Clear();
- await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
+ // We're going to emit all StreamingChatMessage items upstream, even ones that represent
+ // function calls, because a given StreamingChatMessage can contain other content, too.
+ // And if we yield the function calls, and the consumer adds all the content into a message
+ // that's then added into history, they'll end up with function call contents that aren't
+ // directly paired with function result contents, which may cause issues for some models
+ // when the history is later sent again.
+
+ // Find all the FCCs. We need to track these separately in order to be able to process them later.
+ int preFccCount = functionCallContents.Count;
+ functionCallContents.AddRange(update.Contents.OfType());
+
+ // If there were any, remove them from the update. We do this before yielding the update so
+ // that we're not modifying an instance already provided back to the caller.
+ int addedFccs = functionCallContents.Count - preFccCount;
+ if (addedFccs > 0)
{
- // We're going to emit all StreamingChatMessage items upstream, even ones that represent
- // function calls, because a given StreamingChatMessage can contain other content, too.
- // And if we yield the function calls, and the consumer adds all the content into a message
- // that's then added into history, they'll end up with function call contents that aren't
- // directly paired with function result contents, which may cause issues for some models
- // when the history is later sent again.
-
- // Find all the FCCs. We need to track these separately in order to be able to process them later.
- int preFccCount = functionCallContents.Count;
- functionCallContents.AddRange(update.Contents.OfType());
-
- // If there were any, remove them from the update. We do this before yielding the update so
- // that we're not modifying an instance already provided back to the caller.
- int addedFccs = functionCallContents.Count - preFccCount;
- if (addedFccs > 0)
- {
- update.Contents = addedFccs == update.Contents.Count ?
- [] : update.Contents.Where(c => c is not FunctionCallContent).ToList();
- }
-
- // Only one choice is allowed with automatic function calling.
- if (choice is null)
- {
- choice = update.ChoiceIndex;
- }
- else if (choice != update.ChoiceIndex)
- {
- ThrowForMultipleChoices();
- }
-
- yield return update;
- Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
+ update.Contents = addedFccs == update.Contents.Count ?
+ [] : update.Contents.Where(c => c is not FunctionCallContent).ToList();
}
- // If there are no tools to call, or for any other reason we should stop, return the response.
- if (options is null
- || options.Tools is not { Count: > 0 }
- || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)
- || functionCallContents is not { Count: > 0 })
+ // Only one choice is allowed with automatic function calling.
+ if (choice is null)
{
- break;
+ choice = update.ChoiceIndex;
}
-
- // Track all added messages in order to remove them, if requested.
- if (!KeepFunctionCallingMessages)
+ else if (choice != update.ChoiceIndex)
{
- messagesToRemove ??= [];
+ ThrowForMultipleChoices();
}
- // Add a manufactured response message containing the function call contents to the chat history.
- ChatMessage functionCallMessage = new(ChatRole.Assistant, [.. functionCallContents]);
- chatMessages.Add(functionCallMessage);
- _ = messagesToRemove?.Add(functionCallMessage);
+ chatThreadId ??= update.ChatThreadId;
- // Process all of the functions, adding their results into the history.
- var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
- if (modeAndMessages.MessagesAdded is not null)
+ yield return update;
+ Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
+ }
+
+ // If there are no tools to call, or for any other reason we should stop, return the response.
+ if (options is null
+ || options.Tools is not { Count: > 0 }
+ || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)
+ || functionCallContents is not { Count: > 0 })
+ {
+ break;
+ }
+
+ // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending
+ // what we already sent as well as this response message, so create a new list to store the response message(s).
+ if (chatThreadId is not null)
+ {
+ if (chatMessages == originalChatMessages)
{
- messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded);
+ chatMessages = [];
}
-
- // Decide how to proceed based on the result of the function calls.
- switch (modeAndMessages.Mode)
+ else
{
- case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode:
- // We have to reset this after the first iteration, otherwise we'll be in an infinite loop.
- options = options.Clone();
- options.ToolMode = null;
- break;
-
- case ContinueMode.AllowOneMoreRoundtrip:
- // The LLM gets one further chance to answer, but cannot use tools.
- options = options.Clone();
- options.Tools = null;
- break;
-
- case ContinueMode.Terminate:
- // Bail immediately.
- yield break;
+ chatMessages.Clear();
}
}
- }
- finally
- {
- RemoveMessagesAndContentFromList(messagesToRemove, contentToRemove: null, chatMessages);
+ else
+ {
+ // Otherwise, we need to add the response message to the history we're sending back. However, if the caller
+ // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original.
+ if (chatMessages == originalChatMessages && !KeepFunctionCallingMessages)
+ {
+ chatMessages = [.. chatMessages];
+ }
+
+ // Add a manufactured response message containing the function call contents to the chat history.
+ chatMessages.Add(new(ChatRole.Assistant, [.. functionCallContents]));
+ }
+
+ // Process all of the functions, adding their results into the history.
+ var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
+ if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId))
+ {
+ // Terminate
+ yield break;
+ }
}
}
@@ -439,42 +432,53 @@ private static void ThrowForMultipleChoices()
throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received.");
}
- ///
- /// Removes all of the messages in from
- /// and all of the content in from the messages in .
- ///
- private static void RemoveMessagesAndContentFromList(
- HashSet? messagesToRemove,
- HashSet? contentToRemove,
- IList messages)
+ /// Updates the supplied chat options based on the continue mode and chat thread ID for the response.
+ /// true if the function calling loop should terminate; otherwise, false.
+ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId)
{
- Debug.Assert(
- contentToRemove is null || messagesToRemove is not null,
- "We should only be tracking content to remove if we're also tracking messages to remove.");
-
- if (messagesToRemove is not null)
+ switch (mode)
{
- for (int m = messages.Count - 1; m >= 0; m--)
- {
- ChatMessage message = messages[m];
+ case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode:
+ // We have to reset the tool mode to be non-required after the first iteration,
+ // as otherwise we'll be in an infinite loop.
+ options = options.Clone();
+ options.ToolMode = null;
+ if (chatThreadId is not null)
+ {
+ options.ChatThreadId = chatThreadId;
+ }
+
+ break;
- if (contentToRemove is not null)
+ case ContinueMode.AllowOneMoreRoundtrip:
+ // The LLM gets one further chance to answer, but cannot use tools.
+ options = options.Clone();
+ options.Tools = null;
+ options.ToolMode = null;
+ if (chatThreadId is not null)
{
- for (int c = message.Contents.Count - 1; c >= 0; c--)
- {
- if (contentToRemove.Contains(message.Contents[c]))
- {
- message.Contents.RemoveAt(c);
- }
- }
+ options.ChatThreadId = chatThreadId;
}
- if (messages.Count == 0 || messagesToRemove.Contains(messages[m]))
+ break;
+
+ case ContinueMode.Terminate:
+ // Bail immediately.
+ return true;
+
+ default:
+ // As with the other modes, ensure we've propagated the chat thread ID to the options.
+ // We only need to clone the options if we're actually mutating it.
+ if (chatThreadId is not null && options.ChatThreadId != chatThreadId)
{
- messages.RemoveAt(m);
+ options = options.Clone();
+ options.ChatThreadId = chatThreadId;
}
- }
+
+ break;
}
+
+ return false;
}
///
@@ -630,7 +634,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
{
string message = result.Status switch
{
- FunctionStatus.NotFound => "Error: Requested function not found.",
+ FunctionStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
FunctionStatus.Failed => "Error: Function failed.",
_ => "Error: Unknown error.",
};
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
index 4e3ceadd793..498be7ecb1e 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
@@ -13,6 +13,7 @@ public class ChatOptionsTests
public void Constructor_Parameterless_PropsDefaulted()
{
ChatOptions options = new();
+ Assert.Null(options.ChatThreadId);
Assert.Null(options.Temperature);
Assert.Null(options.MaxOutputTokens);
Assert.Null(options.TopP);
@@ -28,6 +29,7 @@ public void Constructor_Parameterless_PropsDefaulted()
Assert.Null(options.AdditionalProperties);
ChatOptions clone = options.Clone();
+ Assert.Null(clone.ChatThreadId);
Assert.Null(clone.Temperature);
Assert.Null(clone.MaxOutputTokens);
Assert.Null(clone.TopP);
@@ -65,6 +67,7 @@ public void Properties_Roundtrip()
["key"] = "value",
};
+ options.ChatThreadId = "12345";
options.Temperature = 0.1f;
options.MaxOutputTokens = 2;
options.TopP = 0.3f;
@@ -79,6 +82,7 @@ public void Properties_Roundtrip()
options.Tools = tools;
options.AdditionalProperties = additionalProps;
+ Assert.Equal("12345", options.ChatThreadId);
Assert.Equal(0.1f, options.Temperature);
Assert.Equal(2, options.MaxOutputTokens);
Assert.Equal(0.3f, options.TopP);
@@ -94,6 +98,7 @@ public void Properties_Roundtrip()
Assert.Same(additionalProps, options.AdditionalProperties);
ChatOptions clone = options.Clone();
+ Assert.Equal("12345", clone.ChatThreadId);
Assert.Equal(0.1f, clone.Temperature);
Assert.Equal(2, clone.MaxOutputTokens);
Assert.Equal(0.3f, clone.TopP);
@@ -125,6 +130,7 @@ public void JsonSerialization_Roundtrips()
["key"] = "value",
};
+ options.ChatThreadId = "12345";
options.Temperature = 0.1f;
options.MaxOutputTokens = 2;
options.TopP = 0.3f;
@@ -148,6 +154,7 @@ public void JsonSerialization_Roundtrips()
ChatOptions? deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatOptions);
Assert.NotNull(deserialized);
+ Assert.Equal("12345", deserialized.ChatThreadId);
Assert.Equal(0.1f, deserialized.Temperature);
Assert.Equal(2, deserialized.MaxOutputTokens);
Assert.Equal(0.3f, deserialized.TopP);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs
index 9af25dbd16a..5b5294d24f4 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs
@@ -61,10 +61,13 @@ public async Task ToChatCompletion_SuccessfullyCreatesCompletion(bool useAsync,
};
Assert.NotNull(completion);
+ Assert.NotNull(completion.Usage);
+ Assert.Equal(5, completion.Usage.InputTokenCount);
+ Assert.Equal(7, completion.Usage.OutputTokenCount);
+
Assert.Equal("12345", completion.CompletionId);
Assert.Equal(new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), completion.CreatedAt);
Assert.Equal("model123", completion.ModelId);
- Assert.Same(Assert.IsType(updates[6].Contents[0]).Details, completion.Usage);
Assert.Equal(3, completion.Choices.Count);
@@ -89,7 +92,7 @@ public async Task ToChatCompletion_SuccessfullyCreatesCompletion(bool useAsync,
Assert.Equal(ChatRole.Assistant, message.Role);
Assert.Null(message.AuthorName);
Assert.Null(message.AdditionalProperties);
- Assert.Same(updates[7].Contents[0], Assert.Single(message.Contents));
+ Assert.Empty(message.Contents);
if (coalesceContent is null or true)
{
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
index 540ca7c8431..2690eb7181c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
@@ -19,6 +19,9 @@ namespace Microsoft.Extensions.AI;
public class FunctionInvokingChatClientTests
{
+ private readonly Func _keepMessagesConfigure =
+ b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = true });
+
[Fact]
public void InvalidArgs_Throws()
{
@@ -64,9 +67,9 @@ public async Task SupportsSingleFunctionCallPerRequestAsync()
new ChatMessage(ChatRole.Assistant, "world"),
];
- await InvokeAndAssertAsync(options, plan);
+ await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure);
- await InvokeAndAssertStreamingAsync(options, plan);
+ await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure);
}
[Theory]
@@ -111,7 +114,8 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn
new ChatMessage(ChatRole.Assistant, "world"),
];
- Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation });
+ Func configure = b => b.Use(
+ s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation, KeepFunctionCallingMessages = true });
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
@@ -151,7 +155,8 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync()
new ChatMessage(ChatRole.Assistant, "done"),
];
- Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true });
+ Func configure = b => b.Use(
+ s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true, KeepFunctionCallingMessages = true });
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
@@ -194,9 +199,9 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync()
new ChatMessage(ChatRole.Assistant, "done"),
];
- await InvokeAndAssertAsync(options, plan);
+ await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure);
- await InvokeAndAssertStreamingAsync(options, plan);
+ await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure);
}
[Theory]
@@ -232,7 +237,8 @@ public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunc
new ChatMessage(ChatRole.Assistant, "world")
];
- Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });
+ Func configure = b => b.Use(
+ client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });
Validate(await InvokeAndAssertAsync(options, plan, expected, configure));
Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure));
@@ -254,7 +260,7 @@ void Validate(List finalChat)
[Theory]
[InlineData(false)]
[InlineData(true)]
- public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages)
+ public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages)
{
var options = new ChatOptions
{
@@ -278,7 +284,8 @@ public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunct
new ChatMessage(ChatRole.Assistant, "world"),
];
- Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });
+ Func configure = b => b.Use(
+ client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });
#pragma warning disable SA1005, S125
Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? null :
@@ -340,7 +347,8 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr
new ChatMessage(ChatRole.Assistant, "world"),
];
- Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors });
+ Func configure = b => b.Use(
+ s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors, KeepFunctionCallingMessages = true });
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
@@ -405,7 +413,10 @@ public async Task FunctionInvocationsLogged(LogLevel level)
};
Func configure = b =>
- b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>()));
+ b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())
+ {
+ KeepFunctionCallingMessages = true,
+ });
await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services));
@@ -461,8 +472,10 @@ public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry)
};
Func configure = b => b.Use(c =>
- new FunctionInvokingChatClient(
- new OpenTelemetryChatClient(c, sourceName: sourceName)));
+ new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName))
+ {
+ KeepFunctionCallingMessages = true,
+ });
await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure));
@@ -529,7 +542,7 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls()
}
};
- using var client = new FunctionInvokingChatClient(innerClient);
+ using var client = new FunctionInvokingChatClient(innerClient) { KeepFunctionCallingMessages = true };
var updates = new List();
await foreach (var update in client.CompleteStreamingAsync(messages, options, CancellationToken.None))
@@ -603,7 +616,7 @@ await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [
// The last message is the one returned by the chat client
// This message's content should contain the last function call before the termination
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary { ["i"] = 42 })]),
- ]));
+ ], configurePipeline: _keepMessagesConfigure));
await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [
.. planBeforeTermination,
@@ -611,7 +624,7 @@ await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [
// The last message is the one returned by the chat client
// When streaming, function call content is removed from this message
new ChatMessage(ChatRole.Assistant, []),
- ]));
+ ], configurePipeline: _keepMessagesConfigure));
// The current context should be null outside the async call stack for the function invocation
Assert.Null(FunctionInvokingChatClient.CurrentContext);
@@ -640,6 +653,56 @@ void AssertInvocationContext(FunctionInvokingChatClient.FunctionInvocationContex
}
}
+ [Fact]
+ public async Task PropagatesCompletionChatThreadIdToOptions()
+ {
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")],
+ };
+
+ int iteration = 0;
+
+ Func, ChatOptions?, CancellationToken, ChatCompletion> callback =
+ (chatContents, chatOptions, cancellationToken) =>
+ {
+ iteration++;
+
+ if (iteration == 1)
+ {
+ Assert.Null(chatOptions?.ChatThreadId);
+ return new ChatCompletion(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId-abc", "Func1")]))
+ {
+ ChatThreadId = "12345",
+ };
+ }
+ else if (iteration == 2)
+ {
+ Assert.Equal("12345", chatOptions?.ChatThreadId);
+ return new ChatCompletion(new ChatMessage(ChatRole.Assistant, "done!"));
+ }
+ else
+ {
+ throw new InvalidOperationException("Unexpected iteration");
+ }
+ };
+
+ using var innerClient = new TestChatClient
+ {
+ CompleteAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
+ Task.FromResult(callback(chatContents, chatOptions, cancellationToken)),
+ CompleteStreamingAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
+ YieldAsync(callback(chatContents, chatOptions, cancellationToken).ToStreamingChatCompletionUpdates()),
+ };
+
+ using IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build();
+
+ iteration = 0;
+ Assert.Equal("done!", (await service.CompleteAsync("hey", options)).ToString());
+ iteration = 0;
+ Assert.Equal("done!", (await service.CompleteStreamingAsync("hey", options).ToChatCompletionAsync()).ToString());
+ }
+
private static async Task> InvokeAndAssertAsync(
ChatOptions options,
List plan,
@@ -659,7 +722,6 @@ private static async Task> InvokeAndAssertAsync(
{
CompleteAsyncCallback = async (contents, actualOptions, actualCancellationToken) =>
{
- Assert.Same(chat, contents);
Assert.Equal(cts.Token, actualCancellationToken);
await Task.Yield();
@@ -753,7 +815,6 @@ private static async Task> InvokeAndAssertStreamingAsync(
{
CompleteStreamingAsyncCallback = (contents, actualOptions, actualCancellationToken) =>
{
- Assert.Same(chat, contents);
Assert.Equal(cts.Token, actualCancellationToken);
return YieldAsync(new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToStreamingChatCompletionUpdates());