From 15713445095fc8bedef7087079021b1eb31ed08b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 23:25:39 +0800 Subject: [PATCH] Swift API for speaker diarization (#1404) --- .github/scripts/test-swift.sh | 5 + swift-api-examples/SherpaOnnx.swift | 113 ++++++++++++++++++ swift-api-examples/run-speaker-diarization.sh | 35 ++++++ swift-api-examples/speaker-diarization.swift | 56 +++++++++ 4 files changed, 209 insertions(+) create mode 100755 swift-api-examples/run-speaker-diarization.sh create mode 100644 swift-api-examples/speaker-diarization.swift diff --git a/.github/scripts/test-swift.sh b/.github/scripts/test-swift.sh index 18c9bed41..0da23eb24 100755 --- a/.github/scripts/test-swift.sh +++ b/.github/scripts/test-swift.sh @@ -7,6 +7,11 @@ echo "pwd: $PWD" cd swift-api-examples ls -lh +./run-speaker-diarization.sh +rm -rf *.onnx +rm -rf sherpa-onnx-pyannote-segmentation-3-0 +rm -fv *.wav + ./run-add-punctuations.sh rm ./add-punctuations rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 778bccb9b..881291fd6 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -1078,3 +1078,116 @@ class SherpaOnnxOfflinePunctuationWrapper { return ans } } + +func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String) + -> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig +{ + return SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: toCPointer(model)) +} + +func sherpaOnnxOfflineSpeakerSegmentationModelConfig( + pyannote: SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig, + numThreads: Int = 1, + debug: Int = 0, + provider: String = "cpu" +) -> SherpaOnnxOfflineSpeakerSegmentationModelConfig { + return SherpaOnnxOfflineSpeakerSegmentationModelConfig( + pyannote: pyannote, + num_threads: Int32(numThreads), + debug: Int32(debug), + provider: toCPointer(provider) + ) +} + +func 
sherpaOnnxFastClusteringConfig(numClusters: Int = -1, threshold: Float = 0.5) + -> SherpaOnnxFastClusteringConfig +{ + return SherpaOnnxFastClusteringConfig(num_clusters: Int32(numClusters), threshold: threshold) +} + +func sherpaOnnxSpeakerEmbeddingExtractorConfig( + model: String, + numThreads: Int = 1, + debug: Int = 0, + provider: String = "cpu" +) -> SherpaOnnxSpeakerEmbeddingExtractorConfig { + return SherpaOnnxSpeakerEmbeddingExtractorConfig( + model: toCPointer(model), + num_threads: Int32(numThreads), + debug: Int32(debug), + provider: toCPointer(provider) + ) +} + +func sherpaOnnxOfflineSpeakerDiarizationConfig( + segmentation: SherpaOnnxOfflineSpeakerSegmentationModelConfig, + embedding: SherpaOnnxSpeakerEmbeddingExtractorConfig, + clustering: SherpaOnnxFastClusteringConfig, + minDurationOn: Float = 0.3, + minDurationOff: Float = 0.5 +) -> SherpaOnnxOfflineSpeakerDiarizationConfig { + return SherpaOnnxOfflineSpeakerDiarizationConfig( + segmentation: segmentation, + embedding: embedding, + clustering: clustering, + min_duration_on: minDurationOn, + min_duration_off: minDurationOff + ) +} + +struct SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper { + var start: Float = 0 + var end: Float = 0 + var speaker: Int = 0 +} + +class SherpaOnnxOfflineSpeakerDiarizationWrapper { + /// A pointer to the underlying counterpart in C + let impl: OpaquePointer! + + init( + config: UnsafePointer<SherpaOnnxOfflineSpeakerDiarizationConfig>!
+ ) { + impl = SherpaOnnxCreateOfflineSpeakerDiarization(config) + } + + deinit { + if let impl { + SherpaOnnxDestroyOfflineSpeakerDiarization(impl) + } + } + + var sampleRate: Int { + return Int(SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(impl)) + } + + func process(samples: [Float]) -> [SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper] { + let result = SherpaOnnxOfflineSpeakerDiarizationProcess( + impl, samples, Int32(samples.count)) + + if result == nil { + return [] + } + + let numSegments = Int(SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result)) + + let p: UnsafePointer? = + SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result) + + if p == nil { + return [] + } + + var ans: [SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper] = [] + for i in 0.. [Float] { + return Array(UnsafeBufferPointer(self)) + } +} + +extension AVAudioPCMBuffer { + func array() -> [Float] { + return self.audioBufferList.pointee.mBuffers.array() + } +} + +func run() { + let segmentationModel = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx" + let embeddingExtractorModel = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + let waveFilename = "./0-four-speakers-zh.wav" + + // There are 4 speakers in ./0-four-speakers-zh.wav, so we use 4 here + let numSpeakers = 4 + var config = sherpaOnnxOfflineSpeakerDiarizationConfig( + segmentation: sherpaOnnxOfflineSpeakerSegmentationModelConfig( + pyannote: sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: segmentationModel)), + embedding: sherpaOnnxSpeakerEmbeddingExtractorConfig(model: embeddingExtractorModel), + clustering: sherpaOnnxFastClusteringConfig(numClusters: numSpeakers) + ) + + let sd = SherpaOnnxOfflineSpeakerDiarizationWrapper(config: &config) + + let fileURL: NSURL = NSURL(fileURLWithPath: waveFilename) + let audioFile = try! 
AVAudioFile(forReading: fileURL as URL) + + let audioFormat = audioFile.processingFormat + assert(Int(audioFormat.sampleRate) == sd.sampleRate) + assert(audioFormat.channelCount == 1) + assert(audioFormat.commonFormat == AVAudioCommonFormat.pcmFormatFloat32) + + let audioFrameCount = UInt32(audioFile.length) + let audioFileBuffer = AVAudioPCMBuffer(pcmFormat: audioFormat, frameCapacity: audioFrameCount) + + try! audioFile.read(into: audioFileBuffer!) + let array: [Float]! = audioFileBuffer?.array() + print("Started!") + let segments = sd.process(samples: array) + for i in 0..