Skip to content

Commit

Permalink
Update to {Azure.AI.}OpenAI 2.2.0-beta.1 (#5869)
Browse files (browse the repository at this point in the history)
* Update versions and react to breaking changes

* Use newly exposed OpenAI APIs
  • Loading branch information
stephentoub authored Feb 11, 2025
1 parent 038cc27 commit 468fe7d
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 150 deletions.
2 changes: 1 addition & 1 deletion eng/packages/General.props
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<PackageVersion Include="Microsoft.CodeAnalysis" Version="$(MicrosoftCodeAnalysisVersion)" />
<PackageVersion Include="Microsoft.IO.RecyclableMemoryStream" Version="3.0.0" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.3" />
<PackageVersion Include="OpenAI" Version="2.1.0" />
<PackageVersion Include="OpenAI" Version="2.2.0-beta.1" />
<PackageVersion Include="Polly" Version="8.4.2" />
<PackageVersion Include="Polly.Core" Version="8.4.2" />
<PackageVersion Include="Polly.Extensions" Version="8.4.2" />
Expand Down
2 changes: 1 addition & 1 deletion eng/packages/TestOnly.props
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<Project xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<PackageVersion Include="AutoFixture.AutoMoq" Version="4.17.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.1.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.2.0-beta.1" />
<PackageVersion Include="autofixture" Version="4.17.0" />
<PackageVersion Include="BenchmarkDotNet" Version="0.13.5" />
<PackageVersion Include="FluentAssertions" Version="6.11.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ strictObj is bool strictValue ?
{
List<MessageContent> messageContents = [];

if (chatMessage.Role == ChatRole.System)
if (chatMessage.Role == ChatRole.System ||
chatMessage.Role == OpenAIModelMappers.ChatRoleDeveloper)
{
instructions ??= new();
foreach (var textContent in chatMessage.Contents.OfType<TextContent>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public async Task<ChatCompletion> CompleteAsync(
// Make the call to OpenAI.
var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false);

return OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options);
return OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions);
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
using Microsoft.Shared.Diagnostics;
using OpenAI.Chat;

#pragma warning disable CA1308 // Normalize strings to uppercase
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable S103 // Lines should not be too long
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S2178 // Short-circuit logic should be used in boolean contexts
#pragma warning disable S3440 // Variables should not be checked against the values they're about to be assigned
#pragma warning disable EA0011 // Consider removing unnecessary conditional access operator (?)

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -70,7 +73,7 @@ public static OpenAI.Chat.ChatCompletion ToOpenAIChatCompletion(ChatCompletion c
usage: chatTokenUsage);
}

public static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion openAICompletion, ChatOptions? options)
public static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion openAICompletion, ChatOptions? options, ChatCompletionOptions chatCompletionOptions)
{
_ = Throw.IfNull(openAICompletion);

Expand All @@ -90,6 +93,37 @@ public static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion
}
}

// Output audio is handled separately from message content parts.
if (openAICompletion.OutputAudio is ChatOutputAudio audio)
{
string mimeType = chatCompletionOptions?.AudioOptions?.OutputAudioFormat.ToString()?.ToLowerInvariant() switch
{
"opus" => "audio/opus",
"aac" => "audio/aac",
"flac" => "audio/flac",
"wav" => "audio/wav",
"pcm" => "audio/pcm",
"mp3" or _ => "audio/mpeg",
};

var dc = new DataContent(audio.AudioBytes.ToMemory(), mimeType)
{
AdditionalProperties = new() { [nameof(audio.ExpiresAt)] = audio.ExpiresAt },
};

if (audio.Id is string id)
{
dc.AdditionalProperties[nameof(audio.Id)] = id;
}

if (audio.Transcript is string transcript)
{
dc.AdditionalProperties[nameof(audio.Transcript)] = transcript;
}

returnMessage.Contents.Add(dc);
}

// Also manufacture function calling content items from any tool calls in the response.
if (options?.Tools is { Count: > 0 })
{
Expand All @@ -108,11 +142,11 @@ public static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion
// Wrap the content in a ChatCompletion to return.
var completion = new ChatCompletion([returnMessage])
{
RawRepresentation = openAICompletion,
CompletionId = openAICompletion.Id,
CreatedAt = openAICompletion.CreatedAt,
ModelId = openAICompletion.Model,
FinishReason = FromOpenAIFinishReason(openAICompletion.FinishReason),
ModelId = openAICompletion.Model,
RawRepresentation = openAICompletion,
};

if (openAICompletion.Usage is ChatTokenUsage tokenUsage)
Expand Down Expand Up @@ -265,6 +299,16 @@ public static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)

if (options.AdditionalProperties is { Count: > 0 } additionalProperties)
{
if (additionalProperties.TryGetValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls))
{
result.AllowParallelToolCalls = allowParallelToolCalls;
}

if (additionalProperties.TryGetValue(nameof(result.AudioOptions), out ChatAudioOptions? audioOptions))
{
result.AudioOptions = audioOptions;
}

if (additionalProperties.TryGetValue(nameof(result.EndUserId), out string? endUserId))
{
result.EndUserId = endUserId;
Expand All @@ -283,28 +327,38 @@ public static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
}
}

if (additionalProperties.TryGetValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls))
if (additionalProperties.TryGetValue(nameof(result.Metadata), out IDictionary<string, string>? metadata))
{
result.AllowParallelToolCalls = allowParallelToolCalls;
foreach (KeyValuePair<string, string> kvp in metadata)
{
result.Metadata[kvp.Key] = kvp.Value;
}
}

if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
if (additionalProperties.TryGetValue(nameof(result.OutputPrediction), out ChatOutputPrediction? outputPrediction))
{
result.TopLogProbabilityCount = topLogProbabilityCountInt;
result.OutputPrediction = outputPrediction;
}

if (additionalProperties.TryGetValue(nameof(result.Metadata), out IDictionary<string, string>? metadata))
if (additionalProperties.TryGetValue(nameof(result.ReasoningEffortLevel), out ChatReasoningEffortLevel reasoningEffortLevel))
{
foreach (KeyValuePair<string, string> kvp in metadata)
{
result.Metadata[kvp.Key] = kvp.Value;
}
result.ReasoningEffortLevel = reasoningEffortLevel;
}

if (additionalProperties.TryGetValue(nameof(result.ResponseModalities), out ChatResponseModalities responseModalities))
{
result.ResponseModalities = responseModalities;
}

if (additionalProperties.TryGetValue(nameof(result.StoredOutputEnabled), out bool storeOutputEnabled))
{
result.StoredOutputEnabled = storeOutputEnabled;
}

if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
{
result.TopLogProbabilityCount = topLogProbabilityCountInt;
}
}

if (options.Tools is { Count: > 0 } tools)
Expand Down Expand Up @@ -420,26 +474,22 @@ private static UsageDetails FromOpenAIUsage(ChatTokenUsage tokenUsage)
AdditionalCounts = [],
};

var counts = destination.AdditionalCounts;

if (tokenUsage.InputTokenDetails is ChatInputTokenUsageDetails inputDetails)
{
destination.AdditionalCounts.Add(
$"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}",
inputDetails.AudioTokenCount);

destination.AdditionalCounts.Add(
$"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}",
inputDetails.CachedTokenCount);
const string InputDetails = nameof(ChatTokenUsage.InputTokenDetails);
counts.Add($"{InputDetails}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", inputDetails.AudioTokenCount);
counts.Add($"{InputDetails}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", inputDetails.CachedTokenCount);
}

if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails outputDetails)
{
destination.AdditionalCounts.Add(
$"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}",
outputDetails.AudioTokenCount);

destination.AdditionalCounts.Add(
$"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}",
outputDetails.ReasoningTokenCount);
const string OutputDetails = nameof(ChatTokenUsage.OutputTokenDetails);
counts.Add($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", outputDetails.ReasoningTokenCount);
counts.Add($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", outputDetails.AudioTokenCount);
counts.Add($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.AcceptedPredictionTokenCount)}", outputDetails.AcceptedPredictionTokenCount);
counts.Add($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.RejectedPredictionTokenCount)}", outputDetails.RejectedPredictionTokenCount);
}

return destination;
Expand All @@ -452,34 +502,26 @@ private static ChatTokenUsage ToOpenAIUsage(UsageDetails usageDetails)

if (usageDetails.AdditionalCounts is { Count: > 0 } additionalCounts)
{
int? inputAudioTokenCount = additionalCounts.TryGetValue(
$"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}",
out int value) ? value : null;

int? inputCachedTokenCount = additionalCounts.TryGetValue(
$"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}",
out value) ? value : null;

int? outputAudioTokenCount = additionalCounts.TryGetValue(
$"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}",
out value) ? value : null;

int? outputReasoningTokenCount = additionalCounts.TryGetValue(
$"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}",
out value) ? value : null;

if (inputAudioTokenCount is not null || inputCachedTokenCount is not null)
const string InputDetails = nameof(ChatTokenUsage.InputTokenDetails);
if (additionalCounts.TryGetValue($"{InputDetails}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", out int inputAudioTokenCount) |
additionalCounts.TryGetValue($"{InputDetails}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", out int inputCachedTokenCount))
{
inputTokenUsageDetails = OpenAIChatModelFactory.ChatInputTokenUsageDetails(
audioTokenCount: inputAudioTokenCount ?? 0,
cachedTokenCount: inputCachedTokenCount ?? 0);
audioTokenCount: inputAudioTokenCount,
cachedTokenCount: inputCachedTokenCount);
}

if (outputAudioTokenCount is not null || outputReasoningTokenCount is not null)
const string OutputDetails = nameof(ChatTokenUsage.OutputTokenDetails);
if (additionalCounts.TryGetValue($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", out int outputReasoningTokenCount) |
additionalCounts.TryGetValue($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", out int outputAudioTokenCount) |
additionalCounts.TryGetValue($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.AcceptedPredictionTokenCount)}", out int outputAcceptedPredictionCount) |
additionalCounts.TryGetValue($"{OutputDetails}.{nameof(ChatOutputTokenUsageDetails.RejectedPredictionTokenCount)}", out int outputRejectedPredictionCount))
{
outputTokenUsageDetails = OpenAIChatModelFactory.ChatOutputTokenUsageDetails(
audioTokenCount: outputAudioTokenCount ?? 0,
reasoningTokenCount: outputReasoningTokenCount ?? 0);
reasoningTokenCount: outputReasoningTokenCount,
audioTokenCount: outputAudioTokenCount,
acceptedPredictionTokenCount: outputAcceptedPredictionCount,
rejectedPredictionTokenCount: outputRejectedPredictionCount);
}
}

Expand All @@ -505,6 +547,7 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) =>
ChatMessageRole.User => ChatRole.User,
ChatMessageRole.Assistant => ChatRole.Assistant,
ChatMessageRole.Tool => ChatRole.Tool,
ChatMessageRole.Developer => ChatRoleDeveloper,
_ => new ChatRole(role.ToString()),
};

Expand All @@ -515,7 +558,9 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) =>
role == ChatRole.System ? ChatMessageRole.System :
role == ChatRole.User ? ChatMessageRole.User :
role == ChatRole.Assistant ? ChatMessageRole.Assistant :
role == ChatRole.Tool ? ChatMessageRole.Tool : ChatMessageRole.User;
role == ChatRole.Tool ? ChatMessageRole.Tool :
role == OpenAIModelMappers.ChatRoleDeveloper ? ChatMessageRole.Developer :
ChatMessageRole.User;

/// <summary>Creates an <see cref="AIContent"/> from a <see cref="ChatMessageContentPart"/>.</summary>
/// <param name="contentPart">The content part to convert into a content.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.Extensions.AI;

internal static partial class OpenAIModelMappers
{
/// <summary>Gets the <see cref="ChatRole"/> whose value is "developer"; the mappers in this class pair it with the OpenAI developer message role.</summary>
public static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");

public static OpenAIChatCompletionRequest FromOpenAIChatCompletionRequest(OpenAI.Chat.ChatCompletionOptions chatCompletionOptions)
{
ChatOptions chatOptions = FromOpenAIOptions(chatCompletionOptions);
Expand Down Expand Up @@ -45,6 +47,15 @@ public static IEnumerable<ChatMessage> FromOpenAIChatMessages(IEnumerable<OpenAI
};
break;

case DeveloperChatMessage developerMessage:
yield return new ChatMessage
{
Role = ChatRoleDeveloper,
AuthorName = developerMessage.ParticipantName,
Contents = FromOpenAIChatContent(developerMessage.Content),
};
break;

case UserChatMessage userMessage:
yield return new ChatMessage
{
Expand Down Expand Up @@ -118,11 +129,14 @@ public static IEnumerable<ChatMessage> FromOpenAIChatMessages(IEnumerable<OpenAI

foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System || input.Role == ChatRole.User)
if (input.Role == ChatRole.System ||
input.Role == ChatRole.User ||
input.Role == ChatRoleDeveloper)
{
var parts = ToOpenAIChatContent(input.Contents);
yield return input.Role == ChatRole.System ?
new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
yield return
input.Role == ChatRole.System ? new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
input.Role == OpenAIModelMappers.ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Tool)
Expand Down Expand Up @@ -225,6 +239,19 @@ private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent>
parts.Add(ChatMessageContentPart.CreateImagePart(new Uri(uri)));
}

break;

case DataContent dataContent when dataContent.MediaTypeStartsWith("audio/") && dataContent.Data.HasValue:
var audioData = BinaryData.FromBytes(dataContent.Data.Value);
if (dataContent.MediaTypeStartsWith("audio/mpeg"))
{
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3));
}
else if (dataContent.MediaTypeStartsWith("audio/wav"))
{
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav));
}

break;
}
}
Expand Down
Loading

0 comments on commit 468fe7d

Please sign in to comment.