From 0b759a7f03d69a045939de69510a58ca5c92ce44 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 11 Oct 2024 11:16:00 -0400 Subject: [PATCH] Make Microsoft.ML.OnnxRuntimeGenAI.Tokenizer a Microsoft.ML.Tokenizers.Tokenizer This enables an ONNX Runtime GenAI tokenizer instance to be used anywhere a Microsoft.ML.Tokenizers tokenizer is accepted. If we'd prefer, rather than having Tokenizer be a base class for the ONNX Runtime one, we could instead expose some sort of `public Microsoft.ML.Tokenizers.Tokenizer AsTokenizer()` conversion method that returns a wrapper object (though that's a bit confusing given the names of the types are the same, just in different namespaces). --- src/csharp/Exceptions.cs | 1 - src/csharp/GeneratorParams.cs | 9 +- src/csharp/Images.cs | 3 +- .../Microsoft.ML.OnnxRuntimeGenAI.csproj | 4 + src/csharp/Model.cs | 3 +- src/csharp/MultiModalProcessor.cs | 5 +- src/csharp/Result.cs | 13 +- src/csharp/Sequences.cs | 1 - src/csharp/Tensor.cs | 1 - src/csharp/Tokenizer.cs | 194 +++++++++++++++--- src/csharp/TokenizerStream.cs | 3 +- src/csharp/Utils.cs | 56 +++-- 12 files changed, 224 insertions(+), 69 deletions(-) diff --git a/src/csharp/Exceptions.cs b/src/csharp/Exceptions.cs index 3c023894b..8228bf794 100644 --- a/src/csharp/Exceptions.cs +++ b/src/csharp/Exceptions.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; namespace Microsoft.ML.OnnxRuntimeGenAI { diff --git a/src/csharp/GeneratorParams.cs b/src/csharp/GeneratorParams.cs index ac225e21c..644e2c586 100644 --- a/src/csharp/GeneratorParams.cs +++ b/src/csharp/GeneratorParams.cs @@ -2,9 +2,6 @@ // Licensed under the MIT License. 
using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { @@ -21,12 +18,12 @@ public GeneratorParams(Model model) public void SetSearchOption(string searchOption, double value) { - Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value)); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value)); } public void SetSearchOption(string searchOption, bool value) { - Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value)); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value)); } public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize) @@ -52,7 +49,7 @@ public void SetInputSequences(Sequences sequences) public void SetModelInput(string name, Tensor value) { - Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToUtf8(name), value.Handle)); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(name), value.Handle)); } public void SetInputs(NamedTensors namedTensors) diff --git a/src/csharp/Images.cs b/src/csharp/Images.cs index b160f869e..edf1a7c69 100644 --- a/src/csharp/Images.cs +++ b/src/csharp/Images.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { @@ -23,7 +22,7 @@ public static Images Load(string[] imagePaths) Result.VerifySuccess(NativeMethods.OgaCreateStringArray(out IntPtr stringArray)); foreach (string imagePath in imagePaths) { - Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToUtf8(imagePath))); + Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToNullTerminatedUtf8(imagePath))); } Result.VerifySuccess(NativeMethods.OgaLoadImages(stringArray, out IntPtr imagesHandle)); NativeMethods.OgaDestroyStringArray(stringArray); diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj index ee53c83fb..28328c250 100644 --- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj +++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj @@ -121,4 +121,8 @@ + + + + diff --git a/src/csharp/Model.cs b/src/csharp/Model.cs index 675bc2540..e33c8293b 100644 --- a/src/csharp/Model.cs +++ b/src/csharp/Model.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { @@ -13,7 +12,7 @@ public class Model : IDisposable public Model(string modelPath) { - Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToUtf8(modelPath), out _modelHandle)); + Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToNullTerminatedUtf8(modelPath), out _modelHandle)); } internal IntPtr Handle { get { return _modelHandle; } } diff --git a/src/csharp/MultiModalProcessor.cs b/src/csharp/MultiModalProcessor.cs index 9eae0eb20..3c06edf8b 100644 --- a/src/csharp/MultiModalProcessor.cs +++ b/src/csharp/MultiModalProcessor.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { @@ -21,7 +20,7 @@ public MultiModalProcessor(Model model) public NamedTensors ProcessImages(string prompt, Images images) { IntPtr imagesHandle = images == null ? IntPtr.Zero : images.Handle; - Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToUtf8(prompt), + Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToNullTerminatedUtf8(prompt), imagesHandle, out IntPtr namedTensorsHandle)); return new NamedTensors(namedTensorsHandle); } @@ -38,7 +37,7 @@ public string Decode(ReadOnlySpan sequence) } try { - return StringUtils.FromUtf8(outStr); + return StringUtils.FromNullTerminatedUtf8(outStr); } finally { diff --git a/src/csharp/Result.cs b/src/csharp/Result.cs index a6185b907..6d5e67f07 100644 --- a/src/csharp/Result.cs +++ b/src/csharp/Result.cs @@ -3,22 +3,25 @@ using System; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { - class Result + internal static class Result { - private static string GetErrorMessage(IntPtr nativeResult) + internal static string GetErrorMessage(IntPtr nativeResult) { - - return StringUtils.FromUtf8(NativeMethods.OgaResultGetError(nativeResult)); + return StringUtils.FromNullTerminatedUtf8(NativeMethods.OgaResultGetError(nativeResult)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void VerifySuccess(IntPtr nativeResult) { if (nativeResult != IntPtr.Zero) + { + Throw(nativeResult); + } + + static void Throw(IntPtr nativeResult) { try { diff --git a/src/csharp/Sequences.cs b/src/csharp/Sequences.cs index 0bbf8da83..efa4ccaf2 100644 --- a/src/csharp/Sequences.cs +++ b/src/csharp/Sequences.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { diff --git a/src/csharp/Tensor.cs b/src/csharp/Tensor.cs index ae20a1b8b..ed8a7976a 100644 --- a/src/csharp/Tensor.cs +++ b/src/csharp/Tensor.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { diff --git a/src/csharp/Tokenizer.cs b/src/csharp/Tokenizer.cs index 24fd85ceb..64d4ef31f 100644 --- a/src/csharp/Tokenizer.cs +++ b/src/csharp/Tokenizer.cs @@ -2,29 +2,42 @@ // Licensed under the MIT License. using System; -using System.Runtime.InteropServices; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Text; +using Microsoft.ML.Tokenizers; namespace Microsoft.ML.OnnxRuntimeGenAI { - public class Tokenizer : IDisposable + public sealed class Tokenizer : Tokenizers.Tokenizer, IDisposable { private IntPtr _tokenizerHandle; - private bool _disposed = false; public Tokenizer(Model model) { + if (model is null) + { + throw new ArgumentNullException(nameof(model)); + } + Result.VerifySuccess(NativeMethods.OgaCreateTokenizer(model.Handle, out _tokenizerHandle)); } public Sequences EncodeBatch(string[] strings) { + if (strings is null) + { + throw new ArgumentNullException(nameof(strings)); + } + Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences)); try { foreach (string str in strings) { - Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToUtf8(str), nativeSequences)); + Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToNullTerminatedUtf8(str), nativeSequences)); } return new Sequences(nativeSequences); @@ -38,6 +51,11 @@ public Sequences EncodeBatch(string[] strings) public string[] DecodeBatch(Sequences sequences) { + if (sequences is null) + { + throw new ArgumentNullException(nameof(sequences)); + } + string[] 
result = new string[sequences.NumSequences]; for (ulong i = 0; i < sequences.NumSequences; i++) { @@ -48,11 +66,21 @@ public string[] DecodeBatch(Sequences sequences) } public Sequences Encode(string str) + { + if (str is null) + { + throw new ArgumentNullException(nameof(str)); + } + + return Encode(str.AsSpan()); + } + + public Sequences Encode(ReadOnlySpan str) { Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences)); try { - Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToUtf8(str), nativeSequences)); + Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToNullTerminatedUtf8(str), nativeSequences)); return new Sequences(nativeSequences); } catch @@ -62,19 +90,18 @@ public Sequences Encode(string str) } } - public string Decode(ReadOnlySpan sequence) + public unsafe string Decode(ReadOnlySpan sequence) { - IntPtr outStr = IntPtr.Zero; - unsafe + IntPtr outStr; + + fixed (int* sequencePtr = sequence) { - fixed (int* sequencePtr = sequence) - { - Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)sequence.Length, out outStr)); - } + Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)sequence.Length, out outStr)); } + try { - return StringUtils.FromUtf8(outStr); + return StringUtils.FromNullTerminatedUtf8(outStr); } finally { @@ -84,32 +111,151 @@ public string Decode(ReadOnlySpan sequence) public TokenizerStream CreateStream() { - IntPtr tokenizerStreamHandle = IntPtr.Zero; - Result.VerifySuccess(NativeMethods.OgaCreateTokenizerStream(_tokenizerHandle, out tokenizerStreamHandle)); + Result.VerifySuccess(NativeMethods.OgaCreateTokenizerStream(_tokenizerHandle, out nint tokenizerStreamHandle)); return new TokenizerStream(tokenizerStreamHandle); } - ~Tokenizer() { - Dispose(false); + Dispose(); } public void Dispose() { - Dispose(true); + if (_tokenizerHandle != IntPtr.Zero) + { + 
NativeMethods.OgaDestroyTokenizer(_tokenizerHandle); + _tokenizerHandle = IntPtr.Zero; + } + GC.SuppressFinalize(this); } - protected virtual void Dispose(bool disposing) + #region Base Tokenizer Overrides + private static int GetMaxTokenCount(EncodeSettings settings) + { + if (settings.MaxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + return settings.MaxTokenCount; + } + + /// + protected override int CountTokens(string text, ReadOnlySpan textSpan, EncodeSettings settings) + { + Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan())); + + using Sequences sequences = Encode(textSpan); + Debug.Assert(sequences.NumSequences == 1); + + return Math.Min(GetMaxTokenCount(settings), sequences[0].Length); + } + + /// + protected override EncodeResults EncodeToTokens(string text, ReadOnlySpan textSpan, EncodeSettings settings) + { + Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan())); + + int maxTokenCount = GetMaxTokenCount(settings); + + using Sequences sequences = Encode(textSpan); + if (sequences.NumSequences != 1) + { + throw new InvalidOperationException("Expected exactly one sequence."); + } + + ReadOnlySpan sequence = sequences[0]; + if (sequence.Length > maxTokenCount) + { + sequence = sequence.Slice(0, maxTokenCount); + } + + // Only the token IDs are returned. The Sequences doesn't contain offset information about each token. 
+ EncodedToken[] tokens = new EncodedToken[sequence.Length]; + for (int i = 0; i < sequence.Length; i++) + { + tokens[i] = new EncodedToken(sequence[i], string.Empty, default); + } + + return new EncodeResults() { Tokens = tokens }; + } + + /// + protected override EncodeResults EncodeToIds(string text, ReadOnlySpan textSpan, EncodeSettings settings) + { + Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan())); + + int maxTokenCount = GetMaxTokenCount(settings); + + using Sequences sequences = Encode(textSpan); + Debug.Assert(sequences.NumSequences == 1); + + ReadOnlySpan sequence = sequences[0]; + if (sequence.Length > maxTokenCount) + { + sequence = sequence.Slice(0, maxTokenCount); + } + + return new EncodeResults() { Tokens = sequence.ToArray() }; + } + + /// + public override string Decode(IEnumerable ids) + { + if (ids is null) + { + throw new ArgumentNullException(nameof(ids)); + } + + return Decode(ids as int[] ?? ids.ToArray()); + } + + /// + public override unsafe OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, out int charsWritten) { - if (_disposed) + if (ids is null) { - return; + throw new ArgumentNullException(nameof(ids)); } - NativeMethods.OgaDestroyTokenizer(_tokenizerHandle); - _tokenizerHandle = IntPtr.Zero; - _disposed = true; + + IntPtr outStr; + + int[] idsArray = ids as int[] ?? 
ids.ToArray(); + fixed (int* sequencePtr = idsArray) + { + try + { + Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)idsArray.Length, out outStr)); + } + catch + { + idsConsumed = charsWritten = 0; + return OperationStatus.InvalidData; + } + } + + try + { + fixed (char* pDest = destination) + { + charsWritten = Encoding.UTF8.GetChars((byte*)outStr, StringUtils.GetNullTerminatedUtf8Length(outStr), pDest, destination.Length); + idsConsumed = idsArray.Length; + } + } + catch (ArgumentException) + { + idsConsumed = charsWritten = 0; + return OperationStatus.DestinationTooSmall; + } + finally + { + NativeMethods.OgaDestroyString(outStr); + } + + return OperationStatus.Done; } +#endregion } } diff --git a/src/csharp/TokenizerStream.cs b/src/csharp/TokenizerStream.cs index 51cac5a28..8d892e6f4 100644 --- a/src/csharp/TokenizerStream.cs +++ b/src/csharp/TokenizerStream.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntimeGenAI { @@ -22,7 +21,7 @@ public string Decode(int token) { IntPtr decodedStr = IntPtr.Zero; Result.VerifySuccess(NativeMethods.OgaTokenizerStreamDecode(_tokenizerStreamHandle, token, out decodedStr)); - return StringUtils.FromUtf8(decodedStr); + return StringUtils.FromNullTerminatedUtf8(decodedStr); } ~TokenizerStream() diff --git a/src/csharp/Utils.cs b/src/csharp/Utils.cs index 90d007bc7..5fa5e54fe 100644 --- a/src/csharp/Utils.cs +++ b/src/csharp/Utils.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; @@ -53,45 +54,56 @@ public static int GetCurrentGpuDeviceId() public static void SetLogBool(string name, bool value) { - Result.VerifySuccess(NativeMethods.OgaSetLogBool(StringUtils.ToUtf8(name), value)); + Result.VerifySuccess(NativeMethods.OgaSetLogBool(StringUtils.ToNullTerminatedUtf8(name), value)); } public static void SetLogString(string name, string value) { - Result.VerifySuccess(NativeMethods.OgaSetLogString(StringUtils.ToUtf8(name), StringUtils.ToUtf8(value))); + Result.VerifySuccess(NativeMethods.OgaSetLogString(StringUtils.ToNullTerminatedUtf8(name), StringUtils.ToNullTerminatedUtf8(value))); } } internal class StringUtils { - internal static byte[] EmptyByteArray = new byte[] { 0 }; + internal static readonly byte[] EmptyByteArray = [0]; - internal static byte[] ToUtf8(string str) + internal static byte[] ToNullTerminatedUtf8(string str) => ToNullTerminatedUtf8(str.AsSpan()); + + internal static unsafe byte[] ToNullTerminatedUtf8(ReadOnlySpan str) { - if (string.IsNullOrEmpty(str)) + if (str.IsEmpty) return EmptyByteArray; - int arraySize = UTF8Encoding.UTF8.GetByteCount(str); - byte[] utf8Bytes = new byte[arraySize + 1]; - UTF8Encoding.UTF8.GetBytes(str, 0, str.Length, utf8Bytes, 0); - utf8Bytes[utf8Bytes.Length - 1] = 0; - return utf8Bytes; - } - - internal static string FromUtf8(IntPtr nativeUtf8) - { - unsafe + fixed (char* pStr = str) { - int len = 0; - while (*(byte*)(nativeUtf8 + len) != 0) ++len; - - if (len == 0) + int byteCount = Encoding.UTF8.GetByteCount(pStr, str.Length); + + byte[] utf8Bytes = new byte[byteCount + 1]; + fixed (byte* pBytes = utf8Bytes) { - return string.Empty; + Encoding.UTF8.GetBytes(pStr, str.Length, pBytes, byteCount); + pBytes[byteCount] = 0; } - var nativeBytes = (byte*)nativeUtf8; - return Encoding.UTF8.GetString(nativeBytes, len); + + return utf8Bytes; } } + + internal static unsafe string 
FromNullTerminatedUtf8(IntPtr nativeUtf8) + { + int len = GetNullTerminatedUtf8Length(nativeUtf8); + return len > 0 ? Encoding.UTF8.GetString((byte*)nativeUtf8, len) : string.Empty; + } + + internal static unsafe int GetNullTerminatedUtf8Length(IntPtr nativeUtf8) + { +#if NETCOREAPP + return MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)nativeUtf8).Length; +#else + int len = 0; + while (*(byte*)(nativeUtf8 + len) != 0) ++len; + return len; +#endif + } } }