diff --git a/src/csharp/ChatClient.cs b/src/csharp/ChatClient.cs
index d47c3275b..f4033d1aa 100644
--- a/src/csharp/ChatClient.cs
+++ b/src/csharp/ChatClient.cs
@@ -8,9 +8,11 @@
namespace Microsoft.ML.OnnxRuntimeGenAI;
-/// An implementation based on ONNX Runtime GenAI.
-public sealed class ChatClient : IChatClient, IDisposable
+/// Provides an implementation based on ONNX Runtime GenAI.
+public sealed partial class ChatClient : IChatClient
{
+ /// The options used to configure the instance.
+ private readonly ChatClientConfiguration _config;
 /// The wrapped <see cref="Model"/>.
private readonly Model _model;
 /// The wrapped <see cref="Tokenizer"/>.
@@ -20,8 +22,9 @@ public sealed class ChatClient : IChatClient, IDisposable
/// Initializes an instance of the class.
/// The file path to the model to load.
+ /// Options used to configure the client instance.
/// is null.
- public ChatClient(string modelPath)
+ public ChatClient(string modelPath, ChatClientConfiguration configuration)
{
if (modelPath is null)
{
@@ -54,32 +57,12 @@ public ChatClient(Model model, bool ownsModel = true)
_model = model;
_tokenizer = new Tokenizer(_model);
- Metadata = new("Microsoft.ML.OnnxRuntimeGenAI");
+ Metadata = new("onnxruntime-genai");
}
///
public ChatClientMetadata Metadata { get; }
- ///
- /// Gets or sets stop sequences to use during generation.
- ///
- ///
- /// These will apply in addition to any stop sequences that are a part of the .
- ///
- public IList StopSequences { get; set; } =
- [
- // Default stop sequences based on Phi3
- "<|system|>",
- "<|user|>",
- "<|assistant|>",
- "<|end|>"
- ];
-
- ///
- /// Gets or sets a function that creates a prompt string from the chat history.
- ///
- public Func, string> PromptFormatter { get; set; }
-
///
public void Dispose()
{
@@ -102,12 +85,13 @@ public async Task CompleteAsync(IList chatMessages,
StringBuilder text = new();
await Task.Run(() =>
{
- using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+ using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
using GeneratorParams generatorParams = new(_model);
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
- generatorParams.SetInputSequences(tokens);
using Generator generator = new(_model, generatorParams);
+ generator.AppendTokenSequences(tokens);
+
using var tokenizerStream = _tokenizer.CreateStream();
var completionId = Guid.NewGuid().ToString();
@@ -115,7 +99,6 @@ await Task.Run(() =>
{
cancellationToken.ThrowIfCancellationRequested();
- generator.ComputeLogits();
generator.GenerateNextToken();
ReadOnlySpan outputSequence = generator.GetSequence(0);
@@ -147,12 +130,13 @@ public async IAsyncEnumerable CompleteStreamingAs
throw new ArgumentNullException(nameof(chatMessages));
}
- using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+ using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
using GeneratorParams generatorParams = new(_model);
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
- generatorParams.SetInputSequences(tokens);
using Generator generator = new(_model, generatorParams);
+ generator.AppendTokenSequences(tokens);
+
using var tokenizerStream = _tokenizer.CreateStream();
var completionId = Guid.NewGuid().ToString();
@@ -160,7 +144,6 @@ public async IAsyncEnumerable CompleteStreamingAs
{
string next = await Task.Run(() =>
{
- generator.ComputeLogits();
generator.GenerateNextToken();
ReadOnlySpan outputSequence = generator.GetSequence(0);
@@ -193,43 +176,7 @@ public object GetService(Type serviceType, object key = null) =>
/// Gets whether the specified token is a stop sequence.
private bool IsStop(string token, ChatOptions options) =>
options?.StopSequences?.Contains(token) is true ||
- StopSequences?.Contains(token) is true;
-
- /// Creates a prompt string from the supplied chat history.
- private string CreatePrompt(IEnumerable messages)
- {
- if (messages is null)
- {
- throw new ArgumentNullException(nameof(messages));
- }
-
- if (PromptFormatter is not null)
- {
- return PromptFormatter(messages) ?? string.Empty;
- }
-
- // Default formatting based on Phi3.
- StringBuilder prompt = new();
-
- foreach (var message in messages)
- {
- foreach (var content in message.Contents)
- {
- switch (content)
- {
- case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
- prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
- .Append(tc.Text.Replace("<|end|>\n", ""))
- .Append("<|end|>\n");
- break;
- }
- }
- }
-
- prompt.Append("<|assistant|>");
-
- return prompt.ToString();
- }
+ Array.IndexOf(_config.StopSequences, token) >= 0;
/// Updates the based on the supplied .
private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
@@ -262,6 +209,11 @@ private static void UpdateGeneratorParamsFromOptions(int numInputTokens, Generat
}
}
+ if (options.Seed.HasValue)
+ {
+ generatorParams.SetSearchOption("random_seed", options.Seed.Value);
+ }
+
if (options.AdditionalProperties is { } props)
{
foreach (var entry in props)
diff --git a/src/csharp/ChatClientConfiguration.cs b/src/csharp/ChatClientConfiguration.cs
new file mode 100644
index 000000000..282ae9362
--- /dev/null
+++ b/src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// Provides configuration options used when constructing a <see cref="ChatClient"/>.
+///
+/// Every model has different requirements for stop sequences and prompt formatting. For best results,
+/// the configuration should be tailored to the exact nature of the model being used. For example,
+/// when using a Phi3 model, a configuration like the following may be used:
+///
+/// static ChatClientConfiguration CreateForPhi3() =>
+/// new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+/// (IEnumerable<ChatMessage> messages) =>
+/// {
+/// StringBuilder prompt = new();
+///
+/// foreach (var message in messages)
+/// foreach (var content in message.Contents.OfType<TextContent>())
+/// prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(tc.Text).Append("<|end|>\n");
+///
+/// return prompt.Append("<|assistant|>\n").ToString();
+/// });
+///
+///
+public sealed class ChatClientConfiguration
+{
+ private string[] _stopSequences;
+ private Func<IEnumerable<ChatMessage>, string> _promptFormatter;
+
+ /// Initializes a new instance of the class.
+ /// The stop sequences used by the model.
+ /// The function to use to format a list of messages for input into the model.
+ /// is null.
+ /// is null.
+ public ChatClientConfiguration(
+ string[] stopSequences,
+ Func<IEnumerable<ChatMessage>, string> promptFormatter)
+ {
+ if (stopSequences is null)
+ {
+ throw new ArgumentNullException(nameof(stopSequences));
+ }
+
+ if (promptFormatter is null)
+ {
+ throw new ArgumentNullException(nameof(promptFormatter));
+ }
+
+ StopSequences = stopSequences;
+ PromptFormatter = promptFormatter;
+ }
+
+ ///
+ /// Gets or sets stop sequences to use during generation.
+ ///
+ ///
+ /// These will apply in addition to any stop sequences that are a part of the .
+ ///
+ public string[] StopSequences
+ {
+ get => _stopSequences;
+ set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
+ }
+
+ /// Gets the function that creates a prompt string from the chat history.
+ public Func<IEnumerable<ChatMessage>, string> PromptFormatter
+ {
+ get => _promptFormatter;
+ set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
+ }
+}
\ No newline at end of file
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index ddced9e42..9480d0d84 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -5,9 +5,11 @@
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
-using System.Runtime.CompilerServices;
using Xunit;
using Xunit.Abstractions;
+using System.Collections.Generic;
+using Microsoft.Extensions.AI;
+using System.Text;
namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
@@ -349,6 +351,32 @@ public void TestTopKTopPSearch()
}
}
+ [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
+ public async Task TestChatClient()
+ {
+ using var client = new ChatClient(
+ _phi2Path,
+ new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+ (IEnumerable<ChatMessage> messages) =>
+ {
+ StringBuilder prompt = new();
+
+ foreach (var message in messages)
+ foreach (var content in message.Contents.OfType<TextContent>())
+ prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
+
+ return prompt.Append("<|assistant|>\n").ToString();
+ }));
+
+ var completion = await client.CompleteAsync("What is 2 + 3?", new()
+ {
+ MaxOutputTokens = 20,
+ Temperature = 0f,
+ });
+
+ Assert.Contains("5", completion.ToString());
+ }
+
[IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
public void TestTokenizerBatchEncodeDecode()
{