diff --git a/src/csharp/ChatClient.cs b/src/csharp/ChatClient.cs
index d47c3275b..f4033d1aa 100644
--- a/src/csharp/ChatClient.cs
+++ b/src/csharp/ChatClient.cs
@@ -8,9 +8,11 @@ namespace Microsoft.ML.OnnxRuntimeGenAI;
 
-/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
-public sealed class ChatClient : IChatClient, IDisposable
+/// <summary>Provides an <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
+public sealed partial class ChatClient : IChatClient
 {
+    /// <summary>The options used to configure the instance.</summary>
+    private readonly ChatClientConfiguration _config;
     /// <summary>The wrapped <see cref="Model"/>.</summary>
     private readonly Model _model;
     /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
     private readonly Tokenizer _tokenizer;
 
@@ -20,8 +22,9 @@
     /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
     /// <param name="modelPath">The file path to the model to load.</param>
+    /// <param name="configuration">Options used to configure the client instance.</param>
     /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
-    public ChatClient(string modelPath)
+    public ChatClient(string modelPath, ChatClientConfiguration configuration)
     {
         if (modelPath is null)
         {
@@ -54,32 +57,12 @@ public ChatClient(Model model, bool ownsModel = true)
 
         _model = model;
         _tokenizer = new Tokenizer(_model);
-        Metadata = new("Microsoft.ML.OnnxRuntimeGenAI");
+        Metadata = new("onnxruntime-genai");
     }
 
     /// <inheritdoc/>
     public ChatClientMetadata Metadata { get; }
 
-    /// <summary>
-    /// Gets or sets stop sequences to use during generation.
-    /// </summary>
-    /// <remarks>
-    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions"/>.
-    /// </remarks>
-    public IList<string> StopSequences { get; set; } =
-    [
-        // Default stop sequences based on Phi3
-        "<|system|>",
-        "<|user|>",
-        "<|assistant|>",
-        "<|end|>"
-    ];
-
-    /// <summary>
-    /// Gets or sets a function that creates a prompt string from the chat history.
-    /// </summary>
-    public Func<IEnumerable<ChatMessage>, string> PromptFormatter { get; set; }
-
     /// <inheritdoc/>
     public void Dispose()
     {
@@ -102,12 +85,13 @@ public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages,
 
         StringBuilder text = new();
         await Task.Run(() =>
         {
-            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
             using GeneratorParams generatorParams = new(_model);
             UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
-            generatorParams.SetInputSequences(tokens);
             using Generator generator = new(_model, generatorParams);
+            generator.AppendTokenSequences(tokens);
+
             using var tokenizerStream = _tokenizer.CreateStream();
 
             var completionId = Guid.NewGuid().ToString();
@@ -115,7 +99,6 @@ await Task.Run(() =>
             {
                 cancellationToken.ThrowIfCancellationRequested();
 
-                generator.ComputeLogits();
                 generator.GenerateNextToken();
 
                 ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
@@ -147,12 +130,13 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
             throw new ArgumentNullException(nameof(chatMessages));
         }
 
-        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
         using GeneratorParams generatorParams = new(_model);
         UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
-        generatorParams.SetInputSequences(tokens);
         using Generator generator = new(_model, generatorParams);
+        generator.AppendTokenSequences(tokens);
+
         using var tokenizerStream = _tokenizer.CreateStream();
 
         var completionId = Guid.NewGuid().ToString();
 
@@ -160,7 +144,6 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
         {
             string next = await Task.Run(() =>
             {
-                generator.ComputeLogits();
                 generator.GenerateNextToken();
 
                 ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
@@ -193,43 +176,7 @@ public object GetService(Type serviceType, object key = null) =>
     /// <summary>Gets whether the specified token is a stop sequence.</summary>
     private bool IsStop(string token, ChatOptions options) =>
         options?.StopSequences?.Contains(token) is true ||
-        StopSequences?.Contains(token) is true;
-
-    /// <summary>Creates a prompt string from the supplied chat history.</summary>
-    private string CreatePrompt(IEnumerable<ChatMessage> messages)
-    {
-        if (messages is null)
-        {
-            throw new ArgumentNullException(nameof(messages));
-        }
-
-        if (PromptFormatter is not null)
-        {
-            return PromptFormatter(messages) ?? string.Empty;
-        }
-
-        // Default formatting based on Phi3.
-        StringBuilder prompt = new();
-
-        foreach (var message in messages)
-        {
-            foreach (var content in message.Contents)
-            {
-                switch (content)
-                {
-                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
-                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
-                              .Append(tc.Text.Replace("<|end|>\n", ""))
-                              .Append("<|end|>\n");
-                        break;
-                }
-            }
-        }
-
-        prompt.Append("<|assistant|>");
-
-        return prompt.ToString();
-    }
+        Array.IndexOf(_config.StopSequences, token) >= 0;
 
     /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
     private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
@@ -262,6 +209,11 @@ private static void UpdateGeneratorParamsFromOptions(int numInputTokens, Generat
             }
         }
 
+        if (options.Seed.HasValue)
+        {
+            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
+        }
+
         if (options.AdditionalProperties is { } props)
        {
             foreach (var entry in props)
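The hunks above track two API shifts: prompt construction moves from the removed `CreatePrompt`/`StopSequences` members to the required `ChatClientConfiguration`, and the generation loop moves from `SetInputSequences` plus a per-token `ComputeLogits` to `AppendTokenSequences` plus `GenerateNextToken` alone. A minimal sketch of the resulting raw GenAI loop, assuming a locally available model (the path and prompt below are placeholders, not part of this change):

```csharp
using System;
using Microsoft.ML.OnnxRuntimeGenAI;

// Sketch only: mirrors the loop shape inside CompleteAsync after this change.
using Model model = new("path/to/model");                 // placeholder model directory
using Tokenizer tokenizer = new(model);
using Sequences tokens = tokenizer.Encode("<|user|>\nWhat is 2 + 3?<|end|>\n<|assistant|>\n");

using GeneratorParams generatorParams = new(model);
generatorParams.SetSearchOption("max_length", 128);

using Generator generator = new(model, generatorParams);
generator.AppendTokenSequences(tokens);                   // replaces generatorParams.SetInputSequences(tokens)

using TokenizerStream stream = tokenizer.CreateStream();
while (!generator.IsDone())
{
    generator.GenerateNextToken();                        // no separate ComputeLogits() call anymore
    ReadOnlySpan<int> sequence = generator.GetSequence(0);
    Console.Write(stream.Decode(sequence[^1]));           // decode and print the newest token
}
```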
diff --git a/src/csharp/ChatClientConfiguration.cs b/src/csharp/ChatClientConfiguration.cs
new file mode 100644
index 000000000..282ae9362
--- /dev/null
+++ b/src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
+/// <remarks>
+/// Every model has different requirements for stop sequences and prompt formatting. For best results,
+/// the configuration should be tailored to the exact nature of the model being used. For example,
+/// when using a Phi3 model, a configuration like the following may be used:
+/// <code>
+/// static ChatClientConfiguration CreateForPhi3() =>
+///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
+///         (IEnumerable&lt;ChatMessage&gt; messages) =>
+///         {
+///             StringBuilder prompt = new();
+///
+///             foreach (var message in messages)
+///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
+///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
+///
+///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
+///         });
+/// </code>
+/// </remarks>
+public sealed class ChatClientConfiguration
+{
+    private string[] _stopSequences;
+    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;
+
+    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
+    /// <param name="stopSequences">The stop sequences used by the model.</param>
+    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
+    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
+    public ChatClientConfiguration(
+        string[] stopSequences,
+        Func<IEnumerable<ChatMessage>, string> promptFormatter)
+    {
+        if (stopSequences is null)
+        {
+            throw new ArgumentNullException(nameof(stopSequences));
+        }
+
+        if (promptFormatter is null)
+        {
+            throw new ArgumentNullException(nameof(promptFormatter));
+        }
+
+        StopSequences = stopSequences;
+        PromptFormatter = promptFormatter;
+    }
+
+    /// <summary>
+    /// Gets or sets stop sequences to use during generation.
+    /// </summary>
+    /// <remarks>
+    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions"/>.
+    /// </remarks>
+    public string[] StopSequences
+    {
+        get => _stopSequences;
+        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
+    }
+
+    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
+    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
+    {
+        get => _promptFormatter;
+        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
+    }
+}
\ No newline at end of file
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index ddced9e42..9480d0d84 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -5,9 +5,11 @@
 using System.IO;
 using System.Linq;
 using System.Runtime.InteropServices;
-using System.Runtime.CompilerServices;
 using Xunit;
 using Xunit.Abstractions;
+using System.Collections.Generic;
+using Microsoft.Extensions.AI;
+using System.Text;
 
 namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
 {
@@ -349,6 +351,32 @@ public void TestTopKTopPSearch()
             }
         }
 
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
+        public async Task TestChatClient()
+        {
+            using var client = new ChatClient(
+                _phi2Path,
+                new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+                    (IEnumerable<ChatMessage> messages) =>
+                    {
+                        StringBuilder prompt = new();
+
+                        foreach (var message in messages)
+                            foreach (var content in message.Contents.OfType<TextContent>())
+                                prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
+
+                        return prompt.Append("<|assistant|>\n").ToString();
+                    }));
+
+            var completion = await client.CompleteAsync("What is 2 + 3?", new()
+            {
+                MaxOutputTokens = 20,
+                Temperature = 0f,
+            });
+
+            Assert.Contains("5", completion.ToString());
+        }
+
         [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
         public void TestTokenizerBatchEncodeDecode()
         {
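End to end, the new constructor shape composes as in the following sketch (the model path is a placeholder, and the prompt formatter assumes a Phi-3-style chat template, mirroring the doc example in ChatClientConfiguration.cs):

```csharp
using System;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

class ChatClientExample
{
    // Phi-3-style configuration; adjust stop sequences and formatting per model.
    static ChatClientConfiguration CreateForPhi3() =>
        new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
            messages =>
            {
                StringBuilder prompt = new();

                foreach (var message in messages)
                    foreach (var content in message.Contents.OfType<TextContent>())
                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
                              .Append(content.Text).Append("<|end|>\n");

                return prompt.Append("<|assistant|>\n").ToString();
            });

    static async Task Main()
    {
        // "path/to/phi3-model" is a placeholder for a local model directory.
        using var client = new ChatClient("path/to/phi3-model", CreateForPhi3());

        // Exercises the streaming path updated in the diff above.
        await foreach (var update in client.CompleteStreamingAsync(
            [new ChatMessage(ChatRole.User, "What is 2 + 3?")],
            new() { MaxOutputTokens = 20, Temperature = 0f }))
        {
            Console.Write(update);
        }
    }
}
```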