add cuda support to whisper models (#83)
ling0322 authored Aug 2, 2024
1 parent da716b3 commit 5c94971
Showing 31 changed files with 863 additions and 71 deletions.
10 changes: 10 additions & 0 deletions go/bin/transcribe.go
@@ -23,7 +23,9 @@ import (
"flag"
"fmt"
"log"
"log/slog"
"os"
"time"

"github.com/ling0322/libllm/go/skill"
)
@@ -67,6 +69,7 @@ func transcribeMain(args []string) {
}
defer fd.Close()

d0 := time.Now()
transcriber := skill.NewWhisperTranscriber(model, fd)
for transcriber.Transcribe() {
r := transcriber.Result()
@@ -76,4 +79,11 @@ func transcribeMain(args []string) {
if err = transcriber.Err(); err != nil {
log.Fatal(err)
}

processingTime := time.Since(d0)
slog.Info(
fmt.Sprintf("processed %s audio in %s, rtf=%.3f",
transcriber.Offset(),
processingTime.Round(time.Millisecond),
processingTime.Seconds()/transcriber.Offset().Seconds()))
}
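The real-time factor (rtf) logged above is processing time divided by the duration of audio processed, so values below 1.0 mean faster-than-real-time transcription. A minimal sketch of the same computation, using hypothetical timings in place of a real run:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Hypothetical figures: 60s of audio transcribed in 4.5s of wall time.
	audioDuration := 60 * time.Second
	processingTime := 4500 * time.Millisecond

	// Same formula as transcribe.go: rtf = processing time / audio duration.
	rtf := processingTime.Seconds() / audioDuration.Seconds()
	fmt.Printf("processed %s audio in %s, rtf=%.3f\n",
		audioDuration, processingTime.Round(time.Millisecond), rtf) // rtf=0.075
}
```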
1 change: 1 addition & 0 deletions go/skill/transcriber.go
@@ -35,6 +35,7 @@ type Transcriber interface {
Transcribe() bool
Result() TranscriptionResult
Err() error
Offset() time.Duration
}

func (r *TranscriptionResult) String() string {
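Since Offset() is now part of the Transcriber interface, a caller can report progress against any implementation, not just WhisperTranscriber. A minimal consumer sketch (PrintProgress is illustrative, not part of the package):

```go
// Package example shows a caller that depends only on the Transcriber
// interface from go/skill, not on a concrete transcriber type.
package example

import (
	"fmt"

	"github.com/ling0322/libllm/go/skill"
)

// PrintProgress drains any skill.Transcriber and reports how far into the
// audio it has advanced after each result; only the four interface methods
// are used.
func PrintProgress(t skill.Transcriber) error {
	for t.Transcribe() {
		r := t.Result()
		fmt.Printf("%s (at %s)\n", r.String(), t.Offset())
	}
	return t.Err()
}
```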
7 changes: 6 additions & 1 deletion go/skill/whisper.go
@@ -245,14 +245,14 @@ func (w *WhisperTranscriber) parseLanguageToken(token string) (lang string, ok b
// prefill audio and prompt at the beginning of decoding or after the last audio segment finished.
// If no transcription result or <|nospeech|> is obtained, return ErrNoMoreResults.
func (w *WhisperTranscriber) prefillNextAudioSegment() error {
slog.Debug("prefill segment", "offset", w.waveOffset)
nsPerSample := 1000000000 / SampleRate
sampleOffset := int(w.waveOffset.Nanoseconds() / int64(nsPerSample))
byteOffset := sampleOffset * 2
if len(w.wavePayload)-byteOffset < SampleRate/10 {
// ignore the last segment if it is shorter than 0.1s.
return ErrAudioEndOfStream
}
slog.Debug("prefill segment", "offset", w.waveOffset, "byteOffset", byteOffset)

nBytes := min(len(w.wavePayload)-byteOffset, 30*SampleRate*2)
audio := w.wavePayload[byteOffset : byteOffset+nBytes]
@@ -373,3 +373,8 @@ func (w *WhisperTranscriber) Result() TranscriptionResult {
func (w *WhisperTranscriber) Err() error {
return w.err
}

// implements interface Transcriber.
func (w *WhisperTranscriber) Offset() time.Duration {
return w.waveOffset
}
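The segment math in prefillNextAudioSegment assumes 16-bit mono PCM: the wave offset is first converted to a sample index via the sample rate, then doubled into a byte offset. A standalone sketch of that conversion, assuming Whisper's 16 kHz sample rate (byteOffsetAt is illustrative):

```go
package main

import (
	"fmt"
	"time"
)

const sampleRate = 16000 // whisper models expect 16 kHz input

// byteOffsetAt converts a wave offset into a byte offset in a 16-bit mono
// PCM buffer, mirroring the arithmetic in prefillNextAudioSegment.
func byteOffsetAt(offset time.Duration) int {
	nsPerSample := 1000000000 / sampleRate
	sampleOffset := int(offset.Nanoseconds() / int64(nsPerSample))
	return sampleOffset * 2 // 2 bytes per 16-bit sample
}

func main() {
	fmt.Println(byteOffsetAt(30 * time.Second)) // 960000 = 30 * 16000 * 2
}
```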
6 changes: 5 additions & 1 deletion src/libllm/CMakeLists.txt
@@ -120,6 +120,9 @@ if (WITH_CUDA)
"cuda/cuda_operators.cc"
"cuda/cuda_tensor_data.cc"
"cuda/dequant.cu"
"cuda/fill.cu"
"cuda/gelu.cu"
"cuda/layer_norm.cu"
"cuda/lookup.cu"
"cuda/matmul.cc"
"cuda/matvec.cu"
@@ -129,7 +132,8 @@ if (WITH_CUDA)
"cuda/softmax.cu"
"cuda/swiglu.cu"
"cuda/to_device.cc"
"cuda/transform.cu")
"cuda/transform.cu"
"cuda/unfold.cu")

set(unittest_SOURCES ${unittest_SOURCES} "cuda/test.cc")

23 changes: 15 additions & 8 deletions src/libllm/context.cc
@@ -7,7 +7,7 @@
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
@@ -19,10 +19,10 @@

#include "libllm/context.h"

#include "libllm/lut/error.h"
#include "libllm/lut/strings.h"
#include "libllm/device.h"
#include "libllm/lut/error.h"
#include "libllm/lut/log.h"
#include "libllm/lut/strings.h"

namespace libllm {

@@ -33,19 +33,26 @@ Context Context::getCpu() {
return ctx;
}

Context::Context() : _floatType(DType::kFloat) {}
Context::Context()
: _floatType(DType::kFloat),
_debug(false) {
}

Context Context::withName(const std::string &name) const {
CHECK(!name.empty());
Context ctx;
ctx._device = _device;
ctx._floatType = _floatType;
ctx._propertyBag = _propertyBag;
Context ctx = *this;
ctx._ns = this->name(name);

return ctx;
}

Context Context::withDebugMode(bool debugMode) const {
Context ctx = *this;
ctx._debug = debugMode;

return ctx;
}

const Device &Context::getDevice() const {
return _device;
}
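Both withName and withDebugMode now start from a full copy of the context (Context ctx = *this) and override a single field, so newly added members such as _debug are inherited automatically instead of having to be copied field by field. The same copy-then-override pattern in Go, as an illustrative sketch rather than libllm API:

```go
package main

import "fmt"

// ctx mirrors the copy-then-override pattern: each withX method receives the
// struct by value (an implicit full copy), tweaks one field, and returns it.
type ctx struct {
	ns    string
	debug bool
}

func (c ctx) withName(name string) ctx {
	c.ns = c.ns + "." + name
	return c
}

func (c ctx) withDebugMode(debug bool) ctx {
	c.debug = debug
	return c
}

func main() {
	root := ctx{ns: "model"}
	decoder := root.withName("decoder").withDebugMode(true)
	fmt.Println(decoder, root) // {model.decoder true} {model false}
}
```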
27 changes: 21 additions & 6 deletions src/libllm/context.h
@@ -7,7 +7,7 @@
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
@@ -21,6 +21,7 @@

#include <map>
#include <string>

#include "libllm/device.h"
#include "libllm/dtype.h"

@@ -43,25 +44,39 @@ class Context {
// get name under the namespace of this context. If no parameter is given, return the name of the
// context itself.
std::string name(const std::string &name) const;
std::string name() const { return _ns; }
std::string name() const {
return _ns;
}

// device.
const Device &getDevice() const;
void setDevice(const Device &device) { _device = device; }
const Device &getDevice() const;
void setDevice(const Device &device) {
_device = device;
}

// default float type.
DType getFloatDType() const { return _floatType; }
void setFloatDType(DType dtype) { _floatType = dtype; }
DType getFloatDType() const {
return _floatType;
}
void setFloatDType(DType dtype) {
_floatType = dtype;
}

/// Get or set value from the k-v store.
std::string get(const std::string &key) const;
void set(const std::string &key, const std::string &value);

Context withDebugMode(bool debugMode) const;
bool getDebugMode() const {
return _debug;
}

private:
std::string _ns;
std::map<std::string, std::string> _propertyBag;
Device _device;
DType _floatType;
bool _debug;
};

} // namespace libllm
34 changes: 17 additions & 17 deletions src/libllm/cpu/cpu_operators.h
@@ -42,33 +42,33 @@ class CPUOperators : public Operators {
static std::unique_ptr<Operators> createFp32Only();

// implement interface Operators
Tensor applyRotaryPosEmb(Tensor A, Tensor roPE) override;
Tensor add(Tensor a, Tensor b) override;
bool allClose(Tensor A, Tensor B, float rtol, float atol) override;
Tensor cast(Tensor tensor, DType dtype) override;
Tensor causalMask(int max_len) override;
void copy(Tensor src, Tensor dest) override;
void fill(Tensor input, float value) override;
Tensor gelu(Tensor input) override;
Tensor layerNorm(Tensor input, Tensor weight, Tensor bias, float eps) override;
Tensor logMelSpectrogram(Tensor wave) override;
Tensor lookup(Tensor table, Tensor indices) override;
Tensor matmul(Tensor a, Tensor b) override;
Tensor layerNorm(Tensor input, Tensor weight, Tensor bias, float eps) override;
Tensor max(Tensor inputs) override;
Tensor mul(Tensor input, float other) override;
Tensor mul(Tensor input, Tensor other) override;
void print(Tensor tensor) override;
Tensor rand(lut::Span<const int> shape, DType dtype, lut::Random *generator, float min, float max)
override;
Tensor rmsNorm(Tensor input, Tensor weight, float eps) override;
Tensor softmax(Tensor input) override;
Tensor gelu(Tensor input) override;
void fill(Tensor input, float value) override;
Tensor add(Tensor a, Tensor b) override;
Tensor sum(Tensor inputs) override;
Tensor max(Tensor inputs) override;
Tensor swiglu(Tensor A) override;
Tensor tensor(lut::Span<const int> shape, DType dtype) override;
Tensor tensorLike(Tensor input) override;
Tensor zeros(lut::Span<const int> shape, DType dtype) override;
bool allClose(Tensor A, Tensor B, float rtol, float atol) override;
void print(Tensor tensor) override;
Tensor rmsNorm(Tensor input, Tensor weight, float eps) override;
Tensor causalMask(int max_len) override;
Tensor applyRotaryPosEmb(Tensor A, Tensor roPE) override;
void copy(Tensor src, Tensor dest) override;
Tensor swiglu(Tensor A) override;
Tensor to(Device device, Tensor tensor) override;
Tensor cast(Tensor tensor, DType dtype) override;
Tensor logMelSpectrogram(Tensor wave) override;
Tensor unfold(Tensor input, int kernelSize, int stride) override;
Tensor rand(lut::Span<const int> shape, DType dtype, lut::Random *generator, float min, float max)
override;
Tensor zeros(lut::Span<const int> shape, DType dtype) override;

DType getDefaultFloatType() override;

5 changes: 5 additions & 0 deletions src/libllm/cpu/kernel/interface.cc
@@ -375,6 +375,11 @@ void convertFloatToHalf(int n, const float *x, Float16 *y, Mode mode, CpuMathBac
#if LUT_CPU_ARCH == LUT_AARCH64
} else if (backendType == CpuMathBackend::ASIMDHP && mode == Mode::OMP) {
cvt<float, Float16, CpuMathBackend::ASIMDHP, Mode::OMP>(n, x, 0, y, 0);
#elif LUT_CPU_ARCH == LUT_AMD64
} else if (backendType == CpuMathBackend::AVX2 && mode == Mode::OMP) {
cvt<float, Float16, CpuMathBackend::FALLBACK, Mode::OMP>(n, x, 0, y, 0);
} else if (backendType == CpuMathBackend::AVX512 && mode == Mode::OMP) {
cvt<float, Float16, CpuMathBackend::FALLBACK, Mode::OMP>(n, x, 0, y, 0);
#endif
} else {
NOT_IMPL();
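On AMD64, the AVX2 and AVX512 backends now route convertFloatToHalf to the FALLBACK kernel under OMP, since no dedicated vectorized float-to-half kernel exists for them yet. For reference, a simplified Go sketch of what such a scalar fallback must do per element (not libllm's actual kernel; denormal results are flushed to zero here):

```go
package main

import (
	"fmt"
	"math"
)

// float32ToHalfBits converts a float32 to IEEE 754 binary16 bits using
// round-to-nearest-even. Simplification: results too small for a normal
// half are flushed to signed zero instead of producing denormals.
func float32ToHalfBits(f float32) uint16 {
	bits := math.Float32bits(f)
	sign := uint16(bits>>16) & 0x8000
	exp := int32((bits>>23)&0xff) - 127 + 15
	mant := bits & 0x7fffff

	if exp >= 31 { // too large for half: Inf, or NaN passed through
		if bits&0x7fffffff > 0x7f800000 {
			return sign | 0x7e00 // quiet NaN
		}
		return sign | 0x7c00 // +/-Inf
	}
	if exp <= 0 { // too small for a normal half: flush to zero
		return sign
	}

	half := sign | uint16(exp)<<10 | uint16(mant>>13)
	// Round half to even on the 13 truncated mantissa bits; a mantissa
	// overflow here carries into the exponent, which is still correct.
	rem := mant & 0x1fff
	if rem > 0x1000 || (rem == 0x1000 && half&1 == 1) {
		half++
	}
	return half
}

func main() {
	fmt.Printf("%#04x\n", float32ToHalfBits(1.0)) // 0x3c00
}
```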
13 changes: 6 additions & 7 deletions src/libllm/cuda/cast.h
@@ -7,7 +7,7 @@
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
@@ -25,12 +25,11 @@ namespace libllm {
namespace op {
namespace cuda {

Tensor castFloatToHalf(const Tensor &tensor, DType dtype);
Tensor castHalfToFloat(const Tensor &tensor, DType dtype);
Tensor castFloatToHalf(const Tensor &tensor);
Tensor castHalfToFloat(const Tensor &tensor);

Tensor cast(const Tensor &tensor, DType dtype);

} // cuda
} // op
} // ly

} // namespace cuda
} // namespace op
} // namespace libllm
37 changes: 37 additions & 0 deletions src/libllm/cuda/cuda_operators.cc
@@ -26,14 +26,19 @@
#include "libllm/cuda/cast.h"
#include "libllm/cuda/causal_mask.h"
#include "libllm/cuda/copy.h"
#include "libllm/cuda/fill.h"
#include "libllm/cuda/gelu.h"
#include "libllm/cuda/layer_norm.h"
#include "libllm/cuda/lookup.h"
#include "libllm/cuda/matmul.h"
#include "libllm/cuda/print.h"
#include "libllm/cuda/reduce.h"
#include "libllm/cuda/rms_norm.h"
#include "libllm/cuda/softmax.h"
#include "libllm/cuda/swiglu.h"
#include "libllm/cuda/to_device.h"
#include "libllm/cuda/transform.h"
#include "libllm/cuda/unfold.h"
#include "libllm/functional.h"

namespace libllm {
@@ -57,6 +62,23 @@ Operators *CudaOperators::create() {
return op.release();
}

void CudaOperators::fill(Tensor input, float value) {
return op::cuda::fill(input, value);
}

Tensor CudaOperators::gelu(Tensor input) {
return op::cuda::gelu(input);
}

Tensor CudaOperators::max(Tensor inputs) {
return op::cuda::reduce(inputs, MapReduceType::MAX);
}

Tensor CudaOperators::sum(Tensor inputs) {
Tensor A = op::cuda::reduce(inputs, MapReduceType::SUM_FP16_FP32);
return castFloatToHalf(A);
}

Tensor CudaOperators::lookup(Tensor table, Tensor indices) {
return cuda::lookup(table, indices);
}
@@ -85,6 +107,10 @@ Tensor CudaOperators::rmsNorm(Tensor input, Tensor weight, float eps) {
return op::cuda::rmsNorm(input, weight, eps);
}

Tensor CudaOperators::layerNorm(Tensor input, Tensor weight, Tensor bias, float eps) {
return op::cuda::layerNorm(input, weight, bias, eps);
}

Tensor CudaOperators::causalMask(int max_len) {
return op::cuda::causalMask(max_len);
}
@@ -99,6 +125,10 @@ Tensor CudaOperators::tensor(lut::Span<const int> shape, DType dtype) {
NOT_IMPL();
}

Tensor CudaOperators::unfold(Tensor input, int kernelSize, int stride) {
return op::cuda::unfold(input, kernelSize, stride);
}

Tensor CudaOperators::tensorLike(Tensor input) {
CHECK(input.getDevice().getType() == Device::kCuda);

@@ -142,6 +172,13 @@ DType CudaOperators::getDefaultFloatType() {
return DType::kFloat16;
}

Tensor CudaOperators::zeros(lut::Span<const int> shape, DType dtype) {
Tensor tensor = createCudaTensorHalf(shape);
op::cuda::fill(tensor, 0.0);

return tensor;
}

} // namespace cuda
} // namespace op
} // namespace libllm
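The new unfold operator backs Whisper's convolutional front end: it slides a window of kernelSize samples over the time axis with step stride, emitting one frame per step. An illustrative CPU reference of the 1D case in Go, assuming no implicit padding (unfold1D and its layout are assumptions, not libllm's API):

```go
package main

import "fmt"

// unfold1D slices wave into frames of kernelSize samples taken every stride
// samples, dropping any tail shorter than kernelSize (no padding).
func unfold1D(wave []float32, kernelSize, stride int) [][]float32 {
	var frames [][]float32
	for start := 0; start+kernelSize <= len(wave); start += stride {
		frame := make([]float32, kernelSize)
		copy(frame, wave[start:start+kernelSize])
		frames = append(frames, frame)
	}
	return frames
}

func main() {
	wave := []float32{0, 1, 2, 3, 4, 5, 6}
	fmt.Println(unfold1D(wave, 3, 2)) // [[0 1 2] [2 3 4] [4 5 6]]
}
```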