diff --git a/src/csharp/Exceptions.cs b/src/csharp/Exceptions.cs
index 3c023894b..8228bf794 100644
--- a/src/csharp/Exceptions.cs
+++ b/src/csharp/Exceptions.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Collections.Generic;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
diff --git a/src/csharp/GeneratorParams.cs b/src/csharp/GeneratorParams.cs
index ac225e21c..644e2c586 100644
--- a/src/csharp/GeneratorParams.cs
+++ b/src/csharp/GeneratorParams.cs
@@ -2,9 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
@@ -21,12 +18,12 @@ public GeneratorParams(Model model)
public void SetSearchOption(string searchOption, double value)
{
- Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value));
}
public void SetSearchOption(string searchOption, bool value)
{
- Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value));
}
public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize)
@@ -52,7 +49,7 @@ public void SetInputSequences(Sequences sequences)
public void SetModelInput(string name, Tensor value)
{
- Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToUtf8(name), value.Handle));
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(name), value.Handle));
}
public void SetInputs(NamedTensors namedTensors)
diff --git a/src/csharp/Images.cs b/src/csharp/Images.cs
index b160f869e..edf1a7c69 100644
--- a/src/csharp/Images.cs
+++ b/src/csharp/Images.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
@@ -23,7 +22,7 @@ public static Images Load(string[] imagePaths)
Result.VerifySuccess(NativeMethods.OgaCreateStringArray(out IntPtr stringArray));
foreach (string imagePath in imagePaths)
{
- Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToUtf8(imagePath)));
+ Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToNullTerminatedUtf8(imagePath)));
}
Result.VerifySuccess(NativeMethods.OgaLoadImages(stringArray, out IntPtr imagesHandle));
NativeMethods.OgaDestroyStringArray(stringArray);
diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
index ee53c83fb..28328c250 100644
--- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
+++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
+
+
+
+
diff --git a/src/csharp/Model.cs b/src/csharp/Model.cs
index 675bc2540..e33c8293b 100644
--- a/src/csharp/Model.cs
+++ b/src/csharp/Model.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
@@ -13,7 +12,7 @@ public class Model : IDisposable
public Model(string modelPath)
{
- Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToUtf8(modelPath), out _modelHandle));
+ Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToNullTerminatedUtf8(modelPath), out _modelHandle));
}
internal IntPtr Handle { get { return _modelHandle; } }
diff --git a/src/csharp/MultiModalProcessor.cs b/src/csharp/MultiModalProcessor.cs
index 9eae0eb20..3c06edf8b 100644
--- a/src/csharp/MultiModalProcessor.cs
+++ b/src/csharp/MultiModalProcessor.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
@@ -21,7 +20,7 @@ public MultiModalProcessor(Model model)
public NamedTensors ProcessImages(string prompt, Images images)
{
IntPtr imagesHandle = images == null ? IntPtr.Zero : images.Handle;
- Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToUtf8(prompt),
+ Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToNullTerminatedUtf8(prompt),
imagesHandle, out IntPtr namedTensorsHandle));
return new NamedTensors(namedTensorsHandle);
}
@@ -38,7 +37,7 @@ public string Decode(ReadOnlySpan sequence)
}
try
{
- return StringUtils.FromUtf8(outStr);
+ return StringUtils.FromNullTerminatedUtf8(outStr);
}
finally
{
diff --git a/src/csharp/Result.cs b/src/csharp/Result.cs
index a6185b907..6d5e67f07 100644
--- a/src/csharp/Result.cs
+++ b/src/csharp/Result.cs
@@ -3,22 +3,25 @@
using System;
using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
- class Result
+ internal static class Result
{
- private static string GetErrorMessage(IntPtr nativeResult)
+ internal static string GetErrorMessage(IntPtr nativeResult)
{
-
- return StringUtils.FromUtf8(NativeMethods.OgaResultGetError(nativeResult));
+ return StringUtils.FromNullTerminatedUtf8(NativeMethods.OgaResultGetError(nativeResult));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void VerifySuccess(IntPtr nativeResult)
{
if (nativeResult != IntPtr.Zero)
+ {
+ Throw(nativeResult);
+ }
+
+ static void Throw(IntPtr nativeResult)
{
try
{
diff --git a/src/csharp/Sequences.cs b/src/csharp/Sequences.cs
index 0bbf8da83..efa4ccaf2 100644
--- a/src/csharp/Sequences.cs
+++ b/src/csharp/Sequences.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
diff --git a/src/csharp/Tensor.cs b/src/csharp/Tensor.cs
index ae20a1b8b..ed8a7976a 100644
--- a/src/csharp/Tensor.cs
+++ b/src/csharp/Tensor.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
diff --git a/src/csharp/Tokenizer.cs b/src/csharp/Tokenizer.cs
index 24fd85ceb..64d4ef31f 100644
--- a/src/csharp/Tokenizer.cs
+++ b/src/csharp/Tokenizer.cs
@@ -2,29 +2,42 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
using System.Text;
+using Microsoft.ML.Tokenizers;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
- public class Tokenizer : IDisposable
+ public sealed class Tokenizer : Tokenizers.Tokenizer, IDisposable
{
private IntPtr _tokenizerHandle;
- private bool _disposed = false;
public Tokenizer(Model model)
{
+ if (model is null)
+ {
+ throw new ArgumentNullException(nameof(model));
+ }
+
Result.VerifySuccess(NativeMethods.OgaCreateTokenizer(model.Handle, out _tokenizerHandle));
}
public Sequences EncodeBatch(string[] strings)
{
+ if (strings is null)
+ {
+ throw new ArgumentNullException(nameof(strings));
+ }
+
Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
foreach (string str in strings)
{
- Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToUtf8(str), nativeSequences));
+ Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToNullTerminatedUtf8(str), nativeSequences));
}
return new Sequences(nativeSequences);
@@ -38,6 +51,11 @@ public Sequences EncodeBatch(string[] strings)
public string[] DecodeBatch(Sequences sequences)
{
+ if (sequences is null)
+ {
+ throw new ArgumentNullException(nameof(sequences));
+ }
+
string[] result = new string[sequences.NumSequences];
for (ulong i = 0; i < sequences.NumSequences; i++)
{
@@ -48,11 +66,21 @@ public string[] DecodeBatch(Sequences sequences)
}
public Sequences Encode(string str)
+ {
+ if (str is null)
+ {
+ throw new ArgumentNullException(nameof(str));
+ }
+
+ return Encode(str.AsSpan());
+ }
+
+ public Sequences Encode(ReadOnlySpan str)
{
Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
- Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToUtf8(str), nativeSequences));
+ Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, StringUtils.ToNullTerminatedUtf8(str), nativeSequences));
return new Sequences(nativeSequences);
}
catch
@@ -62,19 +90,18 @@ public Sequences Encode(string str)
}
}
- public string Decode(ReadOnlySpan sequence)
+ public unsafe string Decode(ReadOnlySpan sequence)
{
- IntPtr outStr = IntPtr.Zero;
- unsafe
+ IntPtr outStr;
+
+ fixed (int* sequencePtr = sequence)
{
- fixed (int* sequencePtr = sequence)
- {
- Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)sequence.Length, out outStr));
- }
+ Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)sequence.Length, out outStr));
}
+
try
{
- return StringUtils.FromUtf8(outStr);
+ return StringUtils.FromNullTerminatedUtf8(outStr);
}
finally
{
@@ -84,32 +111,151 @@ public string Decode(ReadOnlySpan sequence)
public TokenizerStream CreateStream()
{
- IntPtr tokenizerStreamHandle = IntPtr.Zero;
- Result.VerifySuccess(NativeMethods.OgaCreateTokenizerStream(_tokenizerHandle, out tokenizerStreamHandle));
+ Result.VerifySuccess(NativeMethods.OgaCreateTokenizerStream(_tokenizerHandle, out nint tokenizerStreamHandle));
return new TokenizerStream(tokenizerStreamHandle);
}
-
~Tokenizer()
{
- Dispose(false);
+ Dispose();
}
public void Dispose()
{
- Dispose(true);
+ if (_tokenizerHandle != IntPtr.Zero)
+ {
+ NativeMethods.OgaDestroyTokenizer(_tokenizerHandle);
+ _tokenizerHandle = IntPtr.Zero;
+ }
+
GC.SuppressFinalize(this);
}
- protected virtual void Dispose(bool disposing)
+ #region Base Tokenizer Overrides
+ private static int GetMaxTokenCount(EncodeSettings settings)
+ {
+ if (settings.MaxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ return settings.MaxTokenCount;
+ }
+
+ ///
+ protected override int CountTokens(string text, ReadOnlySpan textSpan, EncodeSettings settings)
+ {
+ Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan()));
+
+ using Sequences sequences = Encode(textSpan);
+ Debug.Assert(sequences.NumSequences == 1);
+
+ return Math.Min(GetMaxTokenCount(settings), sequences[0].Length);
+ }
+
+ ///
+ protected override EncodeResults EncodeToTokens(string text, ReadOnlySpan textSpan, EncodeSettings settings)
+ {
+ Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan()));
+
+ int maxTokenCount = GetMaxTokenCount(settings);
+
+ using Sequences sequences = Encode(textSpan);
+ if (sequences.NumSequences != 1)
+ {
+ throw new InvalidOperationException("Expected exactly one sequence.");
+ }
+
+ ReadOnlySpan sequence = sequences[0];
+ if (sequence.Length > maxTokenCount)
+ {
+ sequence = sequence.Slice(0, maxTokenCount);
+ }
+
+ // Only the token IDs are returned. The Sequences doesn't contain offset information about each token.
+ EncodedToken[] tokens = new EncodedToken[sequence.Length];
+ for (int i = 0; i < sequence.Length; i++)
+ {
+ tokens[i] = new EncodedToken(sequence[i], string.Empty, default);
+ }
+
+ return new EncodeResults() { Tokens = tokens };
+ }
+
+ ///
+ protected override EncodeResults EncodeToIds(string text, ReadOnlySpan textSpan, EncodeSettings settings)
+ {
+ Debug.Assert(text is null || textSpan.SequenceEqual(text.AsSpan()));
+
+ int maxTokenCount = GetMaxTokenCount(settings);
+
+ using Sequences sequences = Encode(textSpan);
+ Debug.Assert(sequences.NumSequences == 1);
+
+ ReadOnlySpan sequence = sequences[0];
+ if (sequence.Length > maxTokenCount)
+ {
+ sequence = sequence.Slice(0, maxTokenCount);
+ }
+
+ return new EncodeResults() { Tokens = sequence.ToArray() };
+ }
+
+ ///
+ public override string Decode(IEnumerable ids)
+ {
+ if (ids is null)
+ {
+ throw new ArgumentNullException(nameof(ids));
+ }
+
+ return Decode(ids as int[] ?? ids.ToArray());
+ }
+
+ ///
+ public override unsafe OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, out int charsWritten)
{
- if (_disposed)
+ if (ids is null)
{
- return;
+ throw new ArgumentNullException(nameof(ids));
}
- NativeMethods.OgaDestroyTokenizer(_tokenizerHandle);
- _tokenizerHandle = IntPtr.Zero;
- _disposed = true;
+
+ IntPtr outStr;
+
+ int[] idsArray = ids as int[] ?? ids.ToArray();
+ fixed (int* sequencePtr = idsArray)
+ {
+ try
+ {
+ Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)idsArray.Length, out outStr));
+ }
+ catch
+ {
+ idsConsumed = charsWritten = 0;
+ return OperationStatus.InvalidData;
+ }
+ }
+
+ try
+ {
+ fixed (char* pDest = destination)
+ {
+ charsWritten = Encoding.UTF8.GetChars((byte*)outStr, StringUtils.GetNullTerminatedUtf8Length(outStr), pDest, destination.Length);
+ idsConsumed = idsArray.Length;
+ }
+ }
+ catch (ArgumentException)
+ {
+ idsConsumed = charsWritten = 0;
+ return OperationStatus.DestinationTooSmall;
+ }
+ finally
+ {
+ NativeMethods.OgaDestroyString(outStr);
+ }
+
+ return OperationStatus.Done;
}
+#endregion
}
}
diff --git a/src/csharp/TokenizerStream.cs b/src/csharp/TokenizerStream.cs
index 51cac5a28..8d892e6f4 100644
--- a/src/csharp/TokenizerStream.cs
+++ b/src/csharp/TokenizerStream.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System;
-using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntimeGenAI
{
@@ -22,7 +21,7 @@ public string Decode(int token)
{
IntPtr decodedStr = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaTokenizerStreamDecode(_tokenizerStreamHandle, token, out decodedStr));
- return StringUtils.FromUtf8(decodedStr);
+ return StringUtils.FromNullTerminatedUtf8(decodedStr);
}
~TokenizerStream()
diff --git a/src/csharp/Utils.cs b/src/csharp/Utils.cs
index 90d007bc7..5fa5e54fe 100644
--- a/src/csharp/Utils.cs
+++ b/src/csharp/Utils.cs
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
using System;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
@@ -53,45 +54,56 @@ public static int GetCurrentGpuDeviceId()
public static void SetLogBool(string name, bool value)
{
- Result.VerifySuccess(NativeMethods.OgaSetLogBool(StringUtils.ToUtf8(name), value));
+ Result.VerifySuccess(NativeMethods.OgaSetLogBool(StringUtils.ToNullTerminatedUtf8(name), value));
}
public static void SetLogString(string name, string value)
{
- Result.VerifySuccess(NativeMethods.OgaSetLogString(StringUtils.ToUtf8(name), StringUtils.ToUtf8(value)));
+ Result.VerifySuccess(NativeMethods.OgaSetLogString(StringUtils.ToNullTerminatedUtf8(name), StringUtils.ToNullTerminatedUtf8(value)));
}
}
internal class StringUtils
{
- internal static byte[] EmptyByteArray = new byte[] { 0 };
+ internal static readonly byte[] EmptyByteArray = [0];
- internal static byte[] ToUtf8(string str)
+ internal static byte[] ToNullTerminatedUtf8(string str) => ToNullTerminatedUtf8(str.AsSpan());
+
+ internal static unsafe byte[] ToNullTerminatedUtf8(ReadOnlySpan str)
{
- if (string.IsNullOrEmpty(str))
+ if (str.IsEmpty)
return EmptyByteArray;
- int arraySize = UTF8Encoding.UTF8.GetByteCount(str);
- byte[] utf8Bytes = new byte[arraySize + 1];
- UTF8Encoding.UTF8.GetBytes(str, 0, str.Length, utf8Bytes, 0);
- utf8Bytes[utf8Bytes.Length - 1] = 0;
- return utf8Bytes;
- }
-
- internal static string FromUtf8(IntPtr nativeUtf8)
- {
- unsafe
+ fixed (char* pStr = str)
{
- int len = 0;
- while (*(byte*)(nativeUtf8 + len) != 0) ++len;
-
- if (len == 0)
+ int byteCount = Encoding.UTF8.GetByteCount(pStr, str.Length);
+
+ byte[] utf8Bytes = new byte[byteCount + 1];
+ fixed (byte* pBytes = utf8Bytes)
{
- return string.Empty;
+ Encoding.UTF8.GetBytes(pStr, str.Length, pBytes, byteCount);
+ pBytes[byteCount] = 0;
}
- var nativeBytes = (byte*)nativeUtf8;
- return Encoding.UTF8.GetString(nativeBytes, len);
+
+ return utf8Bytes;
}
}
+
+ internal static unsafe string FromNullTerminatedUtf8(IntPtr nativeUtf8)
+ {
+ int len = GetNullTerminatedUtf8Length(nativeUtf8);
+ return len > 0 ? Encoding.UTF8.GetString((byte*)nativeUtf8, len) : string.Empty;
+ }
+
+ internal static unsafe int GetNullTerminatedUtf8Length(IntPtr nativeUtf8)
+ {
+#if NETCOREAPP
+ return MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)nativeUtf8).Length;
+#else
+ int len = 0;
+ while (*(byte*)(nativeUtf8 + len) != 0) ++len;
+ return len;
+#endif
+ }
}
}