Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support setting custom alignment heads for dtw #301

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Whisper.net/Ggml/GgmlType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,16 @@ public enum WhisperAlignmentHeadsPreset
LargeV2,
LargeV3,
LargeV3Turbo
}
}

/// <summary>
/// A single custom alignment head for DTW, described as a (text layer, head) pair.
/// </summary>
/// <remarks>
/// NOTE(review): the values are presumably indices into the model's text layers and attention heads;
/// confirm the expected base and range against the native library.
/// </remarks>
public class WhisperAlignmentHead
{
    /// <summary>The text layer component of the pair.</summary>
    public int TextLayer;

    /// <summary>The head component of the pair.</summary>
    public int Head;

    /// <summary>
    /// Creates an alignment head from its two components.
    /// </summary>
    /// <param name="textLayer">Value stored in <see cref="TextLayer"/>.</param>
    /// <param name="head">Value stored in <see cref="Head"/>.</param>
    public WhisperAlignmentHead(int textLayer, int head)
        => (TextLayer, Head) = (textLayer, head);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,23 @@ namespace Whisper.net.Internals.ModelLoader;
internal class WhisperProcessorModelBufferLoader(byte[] buffer) : IWhisperProcessorModelLoader
{
// Pins the managed model buffer so its address can be handed to native code.
// Deliberately NOT readonly: GCHandle is a mutable struct, and calling Free() through a readonly
// field operates on a temporary copy, so the field would keep reporting IsAllocated == true and a
// second Dispose() would release the same handle again.
private GCHandle pinnedBuffer = GCHandle.Alloc(buffer, GCHandleType.Pinned);

// Pins the flattened custom alignment heads array; only allocated when custom heads were supplied.
private GCHandle aheadsHandle;

/// <summary>
/// Releases the pinned model buffer and, if present, the pinned alignment heads array.
/// </summary>
/// <remarks>
/// Safe to call more than once: a handle that was already released is skipped.
/// </remarks>
public void Dispose()
{
    if (pinnedBuffer.IsAllocated)
    {
        pinnedBuffer.Free();
    }

    if (aheadsHandle.IsAllocated)
    {
        aheadsHandle.Free();
    }
}

public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
{
var bufferLength = new UIntPtr((uint)buffer.Length);

var aHeads = WhisperProcessorModelFileLoader.GetWhisperAlignmentHeads(RuntimeOptions.Instance.CustomAlignmentHeads, ref aheadsHandle);

return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State(pinnedBuffer.AddrOfPinnedObject(), bufferLength,
new WhisperContextParams()
{
Expand All @@ -28,11 +36,7 @@ public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
DtwTokenLevelTimestamp = RuntimeOptions.Instance.UseDtwTimeStamps ? (byte)1 : (byte)0,
HeadsPreset = (WhisperAlignmentHeadsPreset)RuntimeOptions.Instance.HeadsPreset,
DtwNTop = -1,
WhisperAheads = new WhisperAheads()
{
NHeads = 0,
Heads = IntPtr.Zero
},
WhisperAheads = aHeads,
Dtw_mem_size = 1024 * 1024 * 128,
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Runtime.InteropServices;
using Whisper.net.Internals.Native;
using Whisper.net.LibraryLoader;
using Whisper.net.Native;
Expand All @@ -8,13 +9,49 @@ namespace Whisper.net.Internals.ModelLoader;

internal sealed class WhisperProcessorModelFileLoader(string pathModel) : IWhisperProcessorModelLoader
{
private GCHandle aheadsHandle;

/// <summary>
/// Releases the pinned alignment heads array, if one was allocated.
/// </summary>
public void Dispose()
{
    // Nothing was pinned (no custom alignment heads), or it was already released.
    if (!aheadsHandle.IsAllocated)
    {
        return;
    }

    aheadsHandle.Free();
}

/// <summary>
/// Builds the native <c>WhisperAheads</c> value for the given custom alignment heads.
/// </summary>
/// <param name="alignmentHeads">Custom alignment heads; may be <c>null</c> or empty.</param>
/// <param name="aHeadsHandle">
/// Handle that pins the flattened heads array. When at least one head is supplied, any handle
/// already allocated here is freed and replaced by a new pinned handle; otherwise it is left untouched.
/// The caller owns the handle and is responsible for freeing it.
/// </param>
/// <returns>
/// A value with <c>NHeads</c> set to the number of heads and <c>Heads</c> pointing at the pinned
/// array laid out as [textLayer0, head0, textLayer1, head1, ...]; zero heads and
/// <see cref="IntPtr.Zero"/> when no heads were supplied.
/// </returns>
public static WhisperAheads GetWhisperAlignmentHeads(Ggml.WhisperAlignmentHead[]? alignmentHeads, ref GCHandle aHeadsHandle)
{
    // No custom heads: report an empty set and do not touch the caller's handle.
    if (alignmentHeads is not { Length: > 0 } heads)
    {
        return new WhisperAheads()
        {
            NHeads = 0,
            Heads = IntPtr.Zero
        };
    }

    var flattened = new int[heads.Length * 2];

    // Drop whatever was pinned by an earlier call before pinning the new array.
    if (aHeadsHandle.IsAllocated)
    {
        aHeadsHandle.Free();
    }

    aHeadsHandle = GCHandle.Alloc(flattened, GCHandleType.Pinned);
    var pinnedAddress = aHeadsHandle.AddrOfPinnedObject();

    // Each head occupies two consecutive ints: text layer first, then head.
    var cursor = 0;
    foreach (var head in heads)
    {
        flattened[cursor++] = head.TextLayer;
        flattened[cursor++] = head.Head;
    }

    return new WhisperAheads()
    {
        NHeads = (nuint)heads.Length,
        Heads = pinnedAddress
    };
}

public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
{
var aHeads = GetWhisperAlignmentHeads(RuntimeOptions.Instance.CustomAlignmentHeads, ref aheadsHandle);

return nativeWhisper.Whisper_Init_From_File_With_Params_No_State(pathModel,
new WhisperContextParams()
{
Expand All @@ -24,11 +61,7 @@ public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
DtwTokenLevelTimestamp = RuntimeOptions.Instance.UseDtwTimeStamps ? (byte)1 : (byte)0,
HeadsPreset = (WhisperAlignmentHeadsPreset)RuntimeOptions.Instance.HeadsPreset,
DtwNTop = -1,
WhisperAheads = new WhisperAheads()
{
NHeads = 0,
Heads = IntPtr.Zero
},
WhisperAheads = aHeads,
Dtw_mem_size = 1024 * 1024 * 128,
});
}
Expand Down
14 changes: 14 additions & 0 deletions Whisper.net/LibraryLoader/RuntimeOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class RuntimeOptions
// Runtime settings; every property below is readable inside the assembly and writable only from this class.
internal bool UseFlashAttention { get; private set; }
// Read by the model loaders to fill WhisperContextParams.DtwTokenLevelTimestamp (1 when true, 0 otherwise).
internal bool UseDtwTimeStamps { get; private set; }
// Read by the model loaders to fill WhisperContextParams.HeadsPreset.
internal WhisperAlignmentHeadsPreset HeadsPreset { get; private set; }
// Custom alignment heads supplied through SetAlignmentHeads; null when none were set.
// Read by the model loaders to build WhisperContextParams.WhisperAheads.
internal WhisperAlignmentHead[]? CustomAlignmentHeads { get; private set; }
internal int GpuDevice { get; private set; }
internal List<RuntimeLibrary> RuntimeLibraryOrder { get; private set; }
internal RuntimeLibrary? LoadedLibrary { get; private set; }
Expand All @@ -27,6 +28,7 @@ private RuntimeOptions()
UseFlashAttention = false;
UseDtwTimeStamps = false;
HeadsPreset = WhisperAlignmentHeadsPreset.None;
CustomAlignmentHeads = null;
RuntimeLibraryOrder = defaultRuntimeOrder;
GpuDevice = 0;
}
Expand Down Expand Up @@ -127,6 +129,17 @@ public void SetHeadsPreset(WhisperAlignmentHeadsPreset headsPreset)
HeadsPreset = headsPreset;
}

/// <summary>
/// Sets the custom alignment heads array for DTW.
/// </summary>
/// <param name="alignmentHeads">
/// The alignment heads to use, or <c>null</c> to clear a previously set array.
/// The array reference is stored as given; it is not copied.
/// </param>
/// <remarks>
/// By default, it is null. Required when using DTW with models which don't have a matching WhisperAlignmentHeadsPreset.
/// </remarks>
public void SetAlignmentHeads(WhisperAlignmentHead[]? alignmentHeads)
    => CustomAlignmentHeads = alignmentHeads;

/// <summary>
/// Resets the runtime options to their default values.
/// </summary>
Expand All @@ -138,6 +151,7 @@ public void Reset()
UseFlashAttention = false;
UseDtwTimeStamps = false;
HeadsPreset = WhisperAlignmentHeadsPreset.None;
CustomAlignmentHeads = null;
RuntimeLibraryOrder = defaultRuntimeOrder;
GpuDevice = 0;
}
Expand Down