diff --git a/src/csharp/ChatClient.cs b/src/csharp/ChatClient.cs
new file mode 100644
index 000000000..753ec8d5d
--- /dev/null
+++ b/src/csharp/ChatClient.cs
@@ -0,0 +1,278 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
+public sealed class ChatClient : IChatClient, IDisposable
+{
+    /// <summary>The wrapped <see cref="Model"/>.</summary>
+    private readonly Model _model;
+    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
+    private readonly Tokenizer _tokenizer;
+    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
+    private readonly bool _ownsModel;
+
+    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
+    /// <param name="modelPath">The file path to the model to load.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
+    public ChatClient(string modelPath)
+    {
+        if (modelPath is null)
+        {
+            throw new ArgumentNullException(nameof(modelPath));
+        }
+
+        _ownsModel = true;
+        _model = new Model(modelPath);
+        _tokenizer = new Tokenizer(_model);
+
+        Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
+    }
+
+    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
+    /// <param name="model">The model to employ.</param>
+    /// <param name="ownsModel">
+    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
+    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
+    /// The default is <see langword="true"/>.
+    /// </param>
+    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
+    public ChatClient(Model model, bool ownsModel = true)
+    {
+        if (model is null)
+        {
+            throw new ArgumentNullException(nameof(model));
+        }
+
+        _ownsModel = ownsModel;
+        _model = model;
+        _tokenizer = new Tokenizer(_model);
+
+        Metadata = new("Microsoft.ML.OnnxRuntimeGenAI");
+    }
+
+    /// <inheritdoc/>
+    public ChatClientMetadata Metadata { get; }
+
+    /// <summary>
+    /// Gets or sets stop sequences to use during generation.
+    /// </summary>
+    /// <remarks>
+    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions"/>.
+    /// </remarks>
+    public IList<string> StopSequences { get; set; } =
+    [
+        // Default stop sequences based on Phi3
+        "<|system|>",
+        "<|user|>",
+        "<|assistant|>",
+        "<|end|>"
+    ];
+
+    /// <summary>
+    /// Gets or sets a function that creates a prompt string from the chat history.
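+    /// The function receives the full message list and returns the exact prompt text to tokenize;
+    /// if it is null, a default format based on the Phi-3 chat template is used.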
+    /// </summary>
+    public Func<IEnumerable<ChatMessage>, string> PromptFormatter { get; set; }
+
+    /// <inheritdoc/>
+    public void Dispose()
+    {
+        _tokenizer.Dispose();
+
+        if (_ownsModel)
+        {
+            _model.Dispose();
+        }
+    }
+
+    /// <inheritdoc/>
+    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
+    {
+        if (chatMessages is null)
+        {
+            throw new ArgumentNullException(nameof(chatMessages));
+        }
+
+        StringBuilder text = new();
+        await Task.Run(() =>
+        {
+            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+            using GeneratorParams generatorParams = new(_model);
+            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
+            generatorParams.SetInputSequences(tokens);
+
+            using Generator generator = new(_model, generatorParams);
+            using var tokenizerStream = _tokenizer.CreateStream();
+
+            while (!generator.IsDone())
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+
+                generator.ComputeLogits();
+                generator.GenerateNextToken();
+
+                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
+                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
+
+                if (IsStop(next, options))
+                {
+                    break;
+                }
+
+                text.Append(next);
+            }
+        }, cancellationToken);
+
+        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
+        {
+            CompletionId = Guid.NewGuid().ToString(),
+            CreatedAt = DateTimeOffset.UtcNow,
+            ModelId = Metadata.ModelId,
+        };
+    }
+
+    /// <inheritdoc/>
+    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        if (chatMessages is null)
+        {
+            throw new ArgumentNullException(nameof(chatMessages));
+        }
+
+        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+        using GeneratorParams generatorParams = new(_model);
+        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
+        generatorParams.SetInputSequences(tokens);
+
+        using Generator generator = new(_model, generatorParams);
+        using var tokenizerStream = _tokenizer.CreateStream();
+
+        var completionId = Guid.NewGuid().ToString();
+        while (!generator.IsDone())
+        {
+            string next = await Task.Run(() =>
+            {
+                generator.ComputeLogits();
+                generator.GenerateNextToken();
+
+                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
+                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
+            }, cancellationToken);
+
+            if (IsStop(next, options))
+            {
+                break;
+            }
+
+            yield return new StreamingChatCompletionUpdate
+            {
+                CompletionId = completionId,
+                CreatedAt = DateTimeOffset.UtcNow,
+                Role = ChatRole.Assistant,
+                Text = next,
+            };
+        }
+    }
+
+    /// <inheritdoc/>
+    public TService GetService<TService>(object key = null) where TService : class =>
+        typeof(TService) == typeof(Model) ? (TService)(object)_model :
+        typeof(TService) == typeof(Tokenizer) ? (TService)(object)_tokenizer :
+        this as TService;
+
+    /// <summary>Gets whether the specified token is a stop sequence.</summary>
+    private bool IsStop(string token, ChatOptions options) =>
+        options?.StopSequences?.Contains(token) is true ||
+        StopSequences?.Contains(token) is true;
+
+    /// <summary>Creates a prompt string from the supplied chat history.</summary>
+    private string CreatePrompt(IEnumerable<ChatMessage> messages)
+    {
+        if (messages is null)
+        {
+            throw new ArgumentNullException(nameof(messages));
+        }
+
+        if (PromptFormatter is not null)
+        {
+            return PromptFormatter(messages) ?? string.Empty;
+        }
+
+        // Default formatting based on Phi3.
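+        // Each text content item is rendered as "<|role|>\n{text}<|end|>\n", and the
+        // prompt ends with "<|assistant|>" so generation continues as the assistant turn.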
+        StringBuilder prompt = new();
+
+        foreach (var message in messages)
+        {
+            foreach (var content in message.Contents)
+            {
+                switch (content)
+                {
+                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
+                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
+                              .Append(tc.Text.Replace("<|end|>\n", ""))
+                              .Append("<|end|>\n");
+                        break;
+                }
+            }
+        }
+
+        prompt.Append("<|assistant|>");
+
+        return prompt.ToString();
+    }
+
+    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
+    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
+    {
+        if (options is null)
+        {
+            return;
+        }
+
+        if (options.MaxOutputTokens.HasValue)
+        {
+            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
+        }
+
+        if (options.Temperature.HasValue)
+        {
+            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
+        }
+
+        if (options.TopP.HasValue)
+        {
+            generatorParams.SetSearchOption("top_p", options.TopP.Value);
+        }
+
+        if (options.TopK.HasValue)
+        {
+            generatorParams.SetSearchOption("top_k", options.TopK.Value);
+        }
+
+        if (options.AdditionalProperties is { } props)
+        {
+            foreach (var entry in props)
+            {
+                switch (entry.Value)
+                {
+                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
+                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
+                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
+                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
+                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
index ee53c83fb..b6e838adb 100644
--- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
+++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
+
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
+  </ItemGroup>
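For reference, a minimal sketch of consuming the new client through the Microsoft.Extensions.AI abstractions. The model path and prompt are illustrative placeholders, not part of this change:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical local model directory; any ONNX Runtime GenAI model folder works here.
using IChatClient client = new ChatClient(@"C:\models\phi3-mini-4k-instruct");

List<ChatMessage> messages = [new(ChatRole.User, "What is ONNX Runtime?")];

// Non-streaming: the full response arrives as one ChatCompletion.
ChatCompletion completion = await client.CompleteAsync(
    messages, new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });
Console.WriteLine(completion.Message.Text);

// Streaming: updates are surfaced token by token as they are decoded.
await foreach (StreamingChatCompletionUpdate update in client.CompleteStreamingAsync(messages))
{
    Console.Write(update.Text);
}
```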