diff --git a/cmake/global_variables.cmake b/cmake/global_variables.cmake index 9145bdb3c..b05731ee5 100644 --- a/cmake/global_variables.cmake +++ b/cmake/global_variables.cmake @@ -66,6 +66,10 @@ file(GLOB generator_srcs CONFIGURE_DEPENDS "${GENERATORS_ROOT}/*.cpp" "${GENERATORS_ROOT}/cpu/*.h" "${GENERATORS_ROOT}/cpu/*.cpp" + "${GENERATORS_ROOT}/qnn/*.h" + "${GENERATORS_ROOT}/qnn/*.cpp" + "${GENERATORS_ROOT}/webgpu/*.h" + "${GENERATORS_ROOT}/webgpu/*.cpp" "${MODELS_ROOT}/*.h" "${MODELS_ROOT}/*.cpp" )
diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp index d2f056038..b9cbceffe 100644 --- a/src/beam_search_scorer.cpp +++ b/src/beam_search_scorer.cpp @@ -67,7 +67,7 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) // Space to store intermediate sequence size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2; - hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam, true); + hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam); memset(next_beam_scores_.Span().data(), 0, next_beam_scores_.Span().size_bytes());
diff --git a/src/cpu/interface.cpp b/src/cpu/interface.cpp index 420ba8f74..93bbcc6f1 100644 --- a/src/cpu/interface.cpp +++ b/src/cpu/interface.cpp @@ -3,16 +3,18 @@ #include "../generators.h" #include "../search.h" +#include "../models/utils.h" #include "interface.h" namespace Generators { +static Ort::Allocator* ort_allocator_{}; const char* label_cpu = "cpu"; struct CpuMemory final : DeviceBuffer { CpuMemory(size_t size) : owned_{true} { size_in_bytes_ = size; - p_cpu_ = p_device_ = new uint8_t[size_in_bytes_]; + p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_)); } CpuMemory(void* p, size_t size) : owned_{false} { @@ -22,7 +24,7 @@ struct CpuMemory final : DeviceBuffer { ~CpuMemory() override { if (owned_) - delete[] p_device_; + ort_allocator_->Free(p_device_); } const char* GetType() const override { return label_cpu; } @@ -30,18 +32,32 @@ struct CpuMemory final : DeviceBuffer { void CopyDeviceToCpu() override {} // Nothing to do, device is also CPU void CopyCpuToDevice() override {} // Nothing to do, device is also CPU void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { - if (GetType() == label_cpu) - memcpy(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes); - else - throw std::runtime_error("CpuMemory::CopyFromDevice not implemented for " + std::string(source.GetType())); + CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes); + } + + void Zero() override { + memset(p_device_, 0, size_in_bytes_); } bool owned_; }; struct CpuInterface : DeviceInterface { - std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) override { - // cpu_accessible is ignored, as with the cpu, the device is also the cpu + CpuInterface() { + } + + DeviceType GetType() const override { return DeviceType::CPU; } + + void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override { + assert(!ort_allocator_); + ort_allocator_ = &allocator; + } + + Ort::Allocator& GetAllocator() override { + return *ort_allocator_; + } + + std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override { return std::make_shared<CpuMemory>(size); } @@ -49,12 +65,48 @@ struct CpuInterface : DeviceInterface { return std::make_shared<CpuMemory>(p, size); } + bool Cast(OrtValue& input, OrtValue& output) override { + auto input_info = input.GetTensorTypeAndShapeInfo(); + auto output_info = output.GetTensorTypeAndShapeInfo(); + + auto input_type = input_info->GetElementType(); + auto output_type = output_info->GetElementType(); + + auto element_count = input_info->GetElementCount(); + if (element_count != output_info->GetElementCount()) + throw std::runtime_error("Cast - input and output element counts do not match"); + if (input_type == output_type) + throw std::runtime_error("Cast - input and output types are the same"); + + if (input_type == Ort::TypeToTensorType<float> && output_type == Ort::TypeToTensorType<Ort::Float16_t>) { + auto* fp32 = input.GetTensorData<float>(); + auto* fp16 = output.GetTensorMutableData<uint16_t>(); + for (size_t i = 0; i < element_count; i++) + fp16[i] = FastFloat32ToFloat16(fp32[i]); + } else if (input_type == Ort::TypeToTensorType<Ort::Float16_t> && output_type == Ort::TypeToTensorType<float>) { + auto* fp16 = input.GetTensorData<uint16_t>(); + auto* fp32 = output.GetTensorMutableData<float>(); + for (size_t i = 0; i < element_count; i++) + fp32[i] = FastFloat16ToFloat32(fp16[i]); + } else if (input_type == Ort::TypeToTensorType<int32_t> && output_type == Ort::TypeToTensorType<int64_t>) { + auto* input_data = input.GetTensorData<int32_t>(); + auto* output_data = output.GetTensorMutableData<int64_t>(); + for (size_t i = 0; i < element_count; i++) + output_data[i] = input_data[i]; + } else + throw std::runtime_error("Cast - Unimplemented cast"); + return true; + } + std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); } std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); } void Synchronize() override {} // Nothing to do as CPU is always in sync with itself -} g_cpu; +}; -DeviceInterface* GetCpuInterface() { return &g_cpu; } +DeviceInterface* GetCpuInterface() { + static std::unique_ptr<CpuInterface> g_cpu = std::make_unique<CpuInterface>(); + return g_cpu.get(); +} } // namespace Generators
diff --git a/src/cuda/beam_search_scorer_cuda.cpp b/src/cuda/beam_search_scorer_cuda.cpp index a321af4d7..a59bc80bc 100644 --- a/src/cuda/beam_search_scorer_cuda.cpp +++ b/src/cuda/beam_search_scorer_cuda.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "generators.h" #include "search.h" #include "search_cuda.h" @@ -8,7 +11,7 @@ namespace Generators { BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters) - : stream_{parameters.cuda_stream} { + : stream_{GetStream()} { state_cpu_ = CudaMallocHostArray(1); state_cpu_->batch_size_ = static_cast(parameters.search.batch_size); state_cpu_->num_beams_ = static_cast(parameters.search.num_beams);
diff --git a/src/cuda/beam_search_scorer_cuda.cu b/src/cuda/beam_search_scorer_cuda.cu index fa9148dcc..3add21bc1 100644 --- a/src/cuda/beam_search_scorer_cuda.cu +++ b/src/cuda/beam_search_scorer_cuda.cu @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include
diff --git a/src/cuda/beam_search_scorer_cuda.cuh b/src/cuda/beam_search_scorer_cuda.cuh index 68be19fee..7a8834b69 100644 --- a/src/cuda/beam_search_scorer_cuda.cuh +++ b/src/cuda/beam_search_scorer_cuda.cuh @@ -1,3 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "models/onnxruntime_api.h" #include "smartptrs.h" namespace Generators {
diff --git a/src/cuda/beam_search_scorer_cuda.h b/src/cuda/beam_search_scorer_cuda.h index 8b23a4225..7ec208485 100644 --- a/src/cuda/beam_search_scorer_cuda.h +++ b/src/cuda/beam_search_scorer_cuda.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. + namespace Generators { struct BeamSearchScorer_Cuda { diff --git a/src/cuda/beam_search_topk.cu b/src/cuda/beam_search_topk.cu index 222561ce8..32da76fa2 100644 --- a/src/cuda/beam_search_topk.cu +++ b/src/cuda/beam_search_topk.cu @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include diff --git a/src/cuda/cuda_sampling.cu b/src/cuda/cuda_sampling.cu index 8b5720479..f6be0b623 100644 --- a/src/cuda/cuda_sampling.cu +++ b/src/cuda/cuda_sampling.cu @@ -8,6 +8,7 @@ #include "span.h" #include "beam_search_topk.h" #include "cuda_sampling.cuh" +#include "models/onnxruntime_api.h" #include "smartptrs.h" #include #include @@ -297,22 +298,22 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl } template -void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, +void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature) { dim3 grid(batch_count); constexpr int ILP = sizeof(float4) / sizeof(float); dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements); if (is_log_softmax) { SoftmaxBlockForward - <<>>(output, const_cast(input), + <<>>(output, const_cast(input), softmax_elements, input_stride, output_stride, temperature); } else { SoftmaxBlockForward - <<>>(output, const_cast(input), + <<>>(output, const_cast(input), softmax_elements, input_stride, output_stride, temperature); } } -template void DispatchBlockwiseSoftmaxForward(cudaStream_t*, float*, const float*, int, int, int, int, float); +template void DispatchBlockwiseSoftmaxForward(cudaStream_t, float*, const float*, int, int, int, int, float); // Populate Kernels and Launchers @@ -521,7 +522,7 @@ void LaunchSampleKernel(SamplingData* data, cudaStream_t stream, float* scores, void SoftmaxAndSort(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, float temperature) { // Softmax scores std::span scores{data->scores_softmaxed.get(), static_cast(vocab_size * batch_size)}; - DispatchBlockwiseSoftmaxForward(&stream, scores.data(), const_cast(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature); + DispatchBlockwiseSoftmaxForward(stream, scores.data(), const_cast(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature); // Sort indices by scores std::span offsets_gpu{data->offsets.get(), static_cast(batch_size + 1)}; LaunchPopulateOffsets(offsets_gpu.data(), vocab_size, batch_size, stream); @@ -550,7 +551,7 @@ void LaunchGetTopKSubsetFullSort(SamplingData* data, cudaStream_t stream, float* void GetTopKSubset(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, int k, float temperature) { // Softmax scores std::span scores_softmaxed{data->scores_softmaxed.get(), static_cast(vocab_size * batch_size)}; - DispatchBlockwiseSoftmaxForward(&stream, scores_softmaxed.data(), const_cast(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature); + DispatchBlockwiseSoftmaxForward(stream, scores_softmaxed.data(), const_cast(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature); // Get top k subset #define GetTopK(max_k) \ LaunchGetTopKSubset(stream, \ diff --git a/src/cuda/cuda_sampling.cuh 
b/src/cuda/cuda_sampling.cuh index e6e0f184f..529d2e65a 100644 --- a/src/cuda/cuda_sampling.cuh +++ b/src/cuda/cuda_sampling.cuh @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include #include "cuda_common.h" #include @@ -25,7 +26,7 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t void GetSample(SamplingData* data, cudaStream_t stream, int32_t* d_next_token, float* d_scores, int vocab_size, int batch_size, int k, float p, float temperature); template -void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0); +void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0); } // namespace cuda } // namespace Generators \ No newline at end of file diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp index b225604f1..bf21d48c3 100644 --- a/src/cuda/interface.cpp +++ b/src/cuda/interface.cpp @@ -6,43 +6,22 @@ #include "interface.h" #include "../search.h" #include "search_cuda.h" -#include "../models/kernels.h" +#include "kernels.h" #include namespace Generators { -const char* label_cuda = "cuda"; -const char* label_cuda_cpu = "cuda_cpu"; - -struct HostMemory final : DeviceBuffer { - HostMemory(size_t size) { - size_in_bytes_ = size; - ::cudaHostAlloc(&p_device_, size, 0); - p_cpu_ = p_device_; // CPU & GPU both access the same memory here - } - - ~HostMemory() override { - ::cudaFreeHost(p_device_); - } +GenaiInterface* gp_genai{}; +Ort::Allocator* ort_allocator_{}; +const char* device_label = "cuda"; - const char* GetType() const override { return label_cuda_cpu; } - void AllocateCpu() override {} // Nothing to do, device is also CPU - void CopyDeviceToCpu() override {} // Nothing to do, device is also CPU - void CopyCpuToDevice() override {} // Nothing to do, device is also CPU - void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { - if (source.GetType() == label_cuda_cpu) - ::memcpy(p_cpu_ + begin_dest, source.p_cpu_ + begin_source, size_in_bytes); - else if (source.GetType() == label_cuda) - ::cudaMemcpyAsync(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes, ::cudaMemcpyDeviceToHost, GetStream()); - else - throw std::runtime_error("Cuda HostMemory::CopyFromDevice not implemented for " + std::string(source.GetType())); - } -}; +cuda_stream_holder g_stream; +cudaStream_t GetStream() { return g_stream.get(); } struct GpuMemory final : DeviceBuffer { GpuMemory(size_t size) : owned_{true} { size_in_bytes_ = size; - ::cudaMalloc(&p_device_, size); + p_device_ = static_cast(ort_allocator_->Alloc(size)); } GpuMemory(void* p, size_t size) : owned_{false} { @@ -52,12 +31,12 @@ struct GpuMemory final : DeviceBuffer { ~GpuMemory() override { if (owned_) - ::cudaFree(p_device_); + ort_allocator_->Free(p_device_); if (p_cpu_) ::cudaFreeHost(p_cpu_); } - const char* GetType() const override { return label_cuda; } + const char* GetType() const override { return device_label; } void AllocateCpu() override { if (!p_cpu_) @@ -66,37 +45,49 @@ struct GpuMemory final : DeviceBuffer { void CopyDeviceToCpu() override { AllocateCpu(); - ::cudaMemcpy(p_cpu_, p_device_, size_in_bytes_, ::cudaMemcpyDeviceToHost); + ::cudaMemcpyAsync(p_cpu_, p_device_, size_in_bytes_, 
::cudaMemcpyDeviceToHost, GetStream()); + ::cudaStreamSynchronize(GetStream()); } void CopyCpuToDevice() override { assert(p_cpu_); - ::cudaMemcpy(p_device_, p_cpu_, size_in_bytes_, ::cudaMemcpyHostToDevice); + ::cudaMemcpyAsync(p_device_, p_cpu_, size_in_bytes_, ::cudaMemcpyHostToDevice, GetStream()); } - void CopyFrom(size_t begin_source, DeviceBuffer& source, size_t begin_dest, size_t size_in_bytes) override { - if (source.GetType() == label_cuda_cpu) - ::cudaMemcpyAsync(p_device_ + begin_source, source.p_device_ + begin_dest, size_in_bytes, ::cudaMemcpyHostToDevice, GetStream()); - else if (source.GetType() == label_cuda) - ::cudaMemcpyAsync(p_device_ + begin_source, source.p_device_ + begin_dest, size_in_bytes, ::cudaMemcpyDeviceToDevice, GetStream()); + void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { + if (source.GetType() == device_label) + ::cudaMemcpyAsync(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes, ::cudaMemcpyDeviceToDevice, GetStream()); else - throw std::runtime_error("Cuda GpuMemory::CopyFromDevice not implemented for " + std::string(source.GetType())); + gp_genai->CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes); + } + + void Zero() override { + ::cudaMemsetAsync(p_device_, 0, size_in_bytes_, GetStream()); } bool owned_; // If we own the memory, we delete it on destruction }; -struct CudaInterfaceImpl : CudaInterface { +struct CudaInterfaceImpl final : DeviceInterface { CudaInterfaceImpl() { - cuda_stream_.Create(); } ~CudaInterfaceImpl() { } - std::shared_ptr AllocateBase(size_t size, bool cpu_accessible) override { - if (cpu_accessible) - return std::make_shared(size); + DeviceType GetType() const override { return DeviceType::CUDA; } + + void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override { + Ort::api = &api; + assert(!ort_allocator_); + ort_allocator_ = &allocator; + } + + Ort::Allocator& GetAllocator() override { + return *ort_allocator_; + } + + std::shared_ptr AllocateBase(size_t size) override { return std::make_shared(size); } @@ -113,95 +104,79 @@ struct CudaInterfaceImpl : CudaInterface { } void Synchronize() override { - ::cudaStreamSynchronize(cuda_stream_.get()); - } - - cudaStream_t GetCudaStream() override { - return cuda_stream_.get(); - } - - void Int32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) override { - cuda::LaunchInt32ToInt64(input, output, count, stream); - } - - void Fp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) override { - cuda::LaunchFp16ToFp32(input, output, count, stream); + ::cudaStreamSynchronize(GetStream()); } - void Fp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) override { - cuda::LaunchFp32ToFp16(input, output, count, stream); + void* GetCudaStream() override { + return GetStream(); } - void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) override { - cuda::LaunchExpandAndInt32ToInt64(src, dst, num_beams, batch_size, sequence_length, stream); - } + bool Cast(OrtValue& input, OrtValue& output) override { + auto input_info = input.GetTensorTypeAndShapeInfo(); + auto output_info = output.GetTensorTypeAndShapeInfo(); - void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) override { - cuda::LaunchExpand(src, dst, num_beams, batch_size, sequence_length, 
stream); - } + auto input_type = input_info->GetElementType(); + auto output_type = output_info->GetElementType(); - void Launch_UpdatePositionIds(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) override { - cuda::Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); - } + auto input_data = input.GetTensorRawData(); + auto output_data = output.GetTensorMutableRawData(); - void Launch_UpdatePositionIds(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) override { - cuda::Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); - } + auto element_count = input_info->GetElementCount(); + if (element_count != output_info->GetElementCount()) + throw std::runtime_error("Cast - input and output element counts do not match"); + if (input_type == output_type) + throw std::runtime_error("Cast - input and output types are the same"); - void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) override { - cuda::Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); + if (input_type == Ort::TypeToTensorType && output_type == Ort::TypeToTensorType) { + cuda::LaunchFp32ToFp16(reinterpret_cast(input_data), reinterpret_cast(output_data), static_cast(element_count), GetStream()); + } else if (input_type == Ort::TypeToTensorType && output_type == Ort::TypeToTensorType) { + cuda::LaunchFp16ToFp32(reinterpret_cast(input_data), reinterpret_cast(output_data), static_cast(element_count), GetStream()); + } else if (input_type == Ort::TypeToTensorType && output_type == Ort::TypeToTensorType) { + cuda::LaunchInt32ToInt64(reinterpret_cast(input_data), reinterpret_cast(output_data), static_cast(element_count), GetStream()); + } else + return false; + return true; } - void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) override { - cuda::Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); - } - - void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) override { - cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream); - } - - void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) override { - cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream); - } - - void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) override { - cuda::ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, stream); + void UpdatePositionIds(void* position_ids, int batch_beam_size, int total_length, int new_kv_length, 
ONNXTensorElementDataType type) override { + if (type == Ort::TypeToTensorType) + cuda::Launch_UpdatePositionIds(static_cast(position_ids), batch_beam_size, total_length, new_kv_length, GetStream()); + else + cuda::Launch_UpdatePositionIds(static_cast(position_ids), batch_beam_size, total_length, new_kv_length, GetStream()); } - void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) override { - cuda::LaunchCopyCrossQKSingleDecodeStep(stream, cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length); + void UpdateAttentionMask(void* mask_data, const void* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, ONNXTensorElementDataType type) override { + if (type == Ort::TypeToTensorType) + cuda::Launch_UpdateAttentionMask(static_cast(mask_data), static_cast(old_data), batch_beam_size, new_kv_length, total_length, max_length, update_only, GetStream()); + else + cuda::Launch_UpdateAttentionMask(static_cast(mask_data), static_cast(old_data), batch_beam_size, new_kv_length, total_length, max_length, update_only, GetStream()); } - void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override { - cuda::LaunchFinalizeCrossQK(stream, iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data); + void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) override { + cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, GetStream()); } - cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) override { - return ::cudaMemcpyAsync(dst, src, count, kind, stream); + void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length) override { + cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, GetStream()); } - cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) override { - return ::cudaMemcpy(dst, src, count, kind); + void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size) override { + cuda::ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, GetStream()); } - cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) override { - return ::cudaMemsetAsync(ptr, value, count, stream); + void LaunchCopyCrossQKSingleDecodeStep(float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int 
num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) override { + cuda::LaunchCopyCrossQKSingleDecodeStep(GetStream(), cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length); } - cudaError_t cudaMemset(void* ptr, int value, size_t count) override { - return ::cudaMemset(ptr, value, count); + void LaunchFinalizeCrossQK(int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override { + cuda::LaunchFinalizeCrossQK(GetStream(), iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data); } - - private: - cuda_stream_holder cuda_stream_; }; -std::unique_ptr g_cuda_device; +std::unique_ptr g_cuda_device; DeviceInterface& GetCudaDeviceInterface() { return *g_cuda_device; } -cudaStream_t GetStream() { return g_cuda_device->GetCudaStream(); } -GenaiInterface* gp_genai{}; LogItems& GetLogItems() { return gp_genai->GetLogItems(); } std::ostream& operator<<(std::ostream& stream, SGR sgr_code) { return gp_genai->operator_leftshift(stream, sgr_code); } std::ostream& Log(std::string_view label, std::string_view text) { return gp_genai->Log(label, text); } @@ -239,7 +214,7 @@ void operator delete(void* p, size_t /*size*/) noexcept { Generators::gp_genai-> #endif extern "C" { -Generators::CudaInterface* GetInterface(GenaiInterface* p_genai) { +Generators::DeviceInterface* GetInterface(GenaiInterface* p_genai) { Generators::gp_genai = p_genai; Generators::g_cuda_device = std::make_unique(); return Generators::g_cuda_device.get(); diff --git a/src/cuda/interface.h b/src/cuda/interface.h index d664277cc..2236e7d31 100644 --- a/src/cuda/interface.h +++ b/src/cuda/interface.h @@ -7,6 +7,8 @@ struct GenaiInterface { virtual void HeapFree(void*) = 0; #endif + virtual void CopyThroughCpu(Generators::DeviceBuffer& dest, size_t begin_dest, Generators::DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) = 0; + virtual Generators::LogItems& GetLogItems() = 0; virtual std::ostream& operator_leftshift(std::ostream& stream, Generators::SGR sgr_code) = 0; virtual std::ostream& Log(std::string_view label, std::string_view text = {}) = 0; @@ -20,31 +22,5 @@ struct GenaiInterface { namespace Generators { LogItems& GetLogItems(); - -#if USE_CUDA DeviceInterface& GetCudaDeviceInterface(); - -struct CudaInterface : DeviceInterface { - virtual void Int32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) = 0; - virtual void Fp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) = 0; - virtual void Fp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) = 0; - // TODO: This can be collapsed into a single function with a template parameter - virtual void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) = 0; - virtual void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) = 0; - virtual void Launch_UpdatePositionIds(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) = 0; 
- virtual void Launch_UpdatePositionIds(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) = 0; - virtual void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0; - virtual void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0; - virtual void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) = 0; - virtual void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) = 0; - virtual void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) = 0; - virtual void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) = 0; - virtual void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) = 0; - - virtual cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) = 0; - virtual cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) = 0; - virtual cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) = 0; - virtual cudaError_t cudaMemset(void* ptr, int value, size_t count) = 0; -}; -#endif } // namespace Generators diff --git a/src/models/kernels.h b/src/cuda/kernels.h similarity index 92% rename from src/models/kernels.h rename to src/cuda/kernels.h index fece7dadf..860af48c8 100644 --- a/src/models/kernels.h +++ b/src/cuda/kernels.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+ #pragma once namespace Generators { @@ -15,8 +16,6 @@ void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_si void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream); void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream); void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream); -void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream); -void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream); template void BufferExpansionKernelLauncher(const T* input, T* output, int batch_size, int beam_width, int chunk_size, cudaStream_t stream); diff --git a/src/cuda/model_kernels.cu b/src/cuda/model_kernels.cu index 0eb316383..23d9037e0 100644 --- a/src/cuda/model_kernels.cu +++ b/src/cuda/model_kernels.cu @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include #include #include @@ -137,36 +138,6 @@ void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_ ConvertInt32ToInt64<<>>(src, dst, count); } -__global__ void ExpandAndConvertInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < num_beams * batch_size * sequence_length) { - int batch_id = idx / (num_beams * sequence_length); - int seq_id = idx % sequence_length; - dst[idx] = (int64_t)src[batch_id * sequence_length + seq_id]; - } -} - -void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { - int block_size = 256; - int num_blocks = (num_beams * batch_size * sequence_length + block_size - 1) / block_size; - ExpandAndConvertInt32ToInt64<<>>(src, dst, num_beams, batch_size, sequence_length); -} - -__global__ void Expand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < num_beams * batch_size * sequence_length) { - int batch_id = idx / (num_beams * sequence_length); - int seq_id = idx % sequence_length; - dst[idx] = src[batch_id * sequence_length + seq_id]; - } -} - -void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { - int block_size = 256; - int num_blocks = (num_beams * batch_size * sequence_length + block_size - 1) / block_size; - Expand<<>>(src, dst, num_beams, batch_size, sequence_length); -} - namespace { struct ReorderPastStateParams { diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index fa80f67dc..5bfbc6ac8 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ #include "generators.h" #include "interface.h" #include "search.h" @@ -22,7 +25,7 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) sequence_lengths_ = params.p_device->Allocate(batch_beam_size); eos_meet_buffer_ = CudaMallocArray(batch_beam_size, &eos_meet_); - cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream()); done_cpu_ = CudaMallocHostArray(1); *done_cpu_ = false; @@ -31,15 +34,15 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) : Search_Cuda{params} { next_tokens_buffer_ = params.p_device->Allocate(params.search.batch_size); + next_tokens_buffer_.Zero(); next_tokens_ = gpu_span(next_tokens_buffer_.Span()); - cudaMemsetAsync(next_tokens_.data(), 0, next_tokens_.size_bytes(), params_->cuda_stream); unsigned long long random_seed; if (params_->search.random_seed != -1) random_seed = params_->search.random_seed; else random_seed = std::random_device{}(); - samplingdata_ = std::make_unique(random_seed, params_->search.batch_size, params_->config.model.vocab_size, params_->cuda_stream); + samplingdata_ = std::make_unique(random_seed, params_->search.batch_size, params_->config.model.vocab_size, GetStream()); } BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) @@ -58,7 +61,7 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) topk_buffer_ = CudaMallocArray(topk_buffer_size); static_assert(sizeof(float) == sizeof(int32_t)); // The topk_buffer assumes these match, fix for float16 - cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), params_->cuda_stream); + cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), GetStream()); } BeamSearch_Cuda::~BeamSearch_Cuda() = default; @@ -84,20 +87,20 @@ DeviceSpan BeamSearch_Cuda::GetNextIndices() { } void BeamSearch_Cuda::SelectTop() { - cuda::DispatchBlockwiseSoftmaxForward(const_cast(¶ms_->cuda_stream), softmax_buffer_.get(), next_token_scores_.Span().data(), params_->config.model.vocab_size, + cuda::DispatchBlockwiseSoftmaxForward(GetStream(), softmax_buffer_.get(), next_token_scores_.Span().data(), params_->config.model.vocab_size, params_->config.model.vocab_size, params_->config.model.vocab_size, params_->BatchBeamSize()); // Copy next_token_scores to CPU auto next_token_scores_cpu = CudaMallocHostArray(params_->BatchBeamSize() * params_->config.model.vocab_size); - cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, params_->cuda_stream); - CudaCheck() == cudaStreamSynchronize(params_->cuda_stream); + cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, GetStream()); + CudaCheck() == cudaStreamSynchronize(GetStream()); auto beam_scores = beam_scorer_->GetNextScores(); // Add beam score to next token scores. 
Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) cuda::LaunchAddProbsKernel(softmax_buffer_.get(), beam_scores.Span().data(), - params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->cuda_stream); + params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, GetStream()); if (params_->search.num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; @@ -120,11 +123,11 @@ void BeamSearch_Cuda::SelectTop() { topk_next_scores_.get(), topk_next_tokens_.get(), topk_next_indices_.get(), - params_->cuda_stream); + GetStream()); } else assert(false); - CudaCheck() == cudaStreamSynchronize(params_->cuda_stream); + CudaCheck() == cudaStreamSynchronize(GetStream()); size_t size = params_->BatchBeamSize() * 2; std::span next_scores{topk_next_scores_.get(), size}; @@ -146,13 +149,13 @@ void BeamSearch_Cuda::SelectTop() { void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { std::span scores = next_token_scores_.Span(); assert(scores.size() == params_->search.batch_size * params_->config.model.vocab_size); - cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), + cuda::GetSample(samplingdata_.get(), GetStream(), next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), params_->search.batch_size, k, p, temperature); // Check for EOS assert(next_tokens_.size() == eos_meet_.size()); // Don't replace EOS with pad for batch_size == 1 for continuous decoding mode - cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->search.batch_size > 1 ? params_->config.model.pad_token_id : params_->config.model.eos_token_id, done_cpu_.get(), params_->cuda_stream); + cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->search.batch_size > 1 ? 
params_->config.model.pad_token_id : params_->config.model.eos_token_id, done_cpu_.get(), GetStream()); // Append tokens cuda::Launch_AppendNextTokensToSequences(next_tokens_buffer_.Span(), sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream()); @@ -207,7 +210,7 @@ std::span Search_Cuda::GetScores() { // Set user input tokens (batch_beam_size, sequence_length) void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { - cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream()); *done_cpu_ = false; auto next_tokens_gpu = next_tokens.Span(); @@ -221,7 +224,7 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { return; } - cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream()); *done_cpu_ = false; } @@ -234,12 +237,12 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { } void GreedySearch_Cuda::RewindTo(size_t index) { - cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream()); *done_cpu_ = false; if (index > 0) cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast(params_->BatchBeamSize()), static_cast(index), sequences_.max_length_, GetStream()); else - cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), params_->cuda_stream); + cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), GetStream()); sequences_.RewindTo(index); } @@ -247,7 +250,7 @@ void Search_Cuda::ApplyMinLength(int min_length) { if (sequences_.GetSequenceLength() >= min_length) return; - cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->config.model.vocab_size, params_->config.model.eos_token_id, std::numeric_limits::lowest(), params_->cuda_stream); + cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->config.model.vocab_size, params_->config.model.eos_token_id, std::numeric_limits::lowest(), GetStream()); } void Search_Cuda::ApplyRepetitionPenalty(float penalty) { @@ -256,7 +259,7 @@ void Search_Cuda::ApplyRepetitionPenalty(float penalty) { cuda::LaunchRepetitionPenaltyProcessor(sequences_.GetSequences().Span().data(), GetScores().data(), params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, - params_->search.max_length, GetSequenceLength(), penalty, params_->cuda_stream); + params_->search.max_length, GetSequenceLength(), penalty, GetStream()); } } // namespace Generators \ No newline at end of file diff --git a/src/cuda/search_cuda.cu b/src/cuda/search_cuda.cu index f8c9ed3bf..fcb21200d 100644 --- a/src/cuda/search_cuda.cu +++ b/src/cuda/search_cuda.cu @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include diff --git a/src/cuda/search_cuda.cuh b/src/cuda/search_cuda.cuh index f21237c41..a0a07fb9f 100644 --- a/src/cuda/search_cuda.cuh +++ b/src/cuda/search_cuda.cuh @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ namespace Generators { namespace cuda { diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index 2e0ec4610..acdd0525f 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -1,4 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once +#include #include "search_cuda.cuh" #include "cuda_sampling.cuh" @@ -12,7 +16,7 @@ struct Search_Cuda : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const { - cudaStreamSynchronize(params_->cuda_stream); + cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp new file mode 100644 index 000000000..5f6c40c6a --- /dev/null +++ b/src/dml/interface.cpp @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "../generators.h" +#include "../search.h" +#include "../cpu/interface.h" +#include "interface.h" +#include + +#include +#include "dml_provider_factory.h" +#include "../dml/dml_helpers.h" +#include "../dml/dml_execution_context.h" +#include "../dml/dml_pooled_upload_heap.h" +#include "../dml/dml_readback_heap.h" + +std::string CurrentModulePath(); + +namespace Generators { +namespace Dml { // If this was in a shared library it wouldn't need to be in its own namespace + +Ort::Allocator* ort_allocator_{}; +const char* device_label = "dml"; + +wil::unique_hmodule smart_directml_dll_; +DmlObjects dml_objects_; +const OrtDmlApi* dml_api_{}; +std::unique_ptr dml_pooled_upload_heap_; +std::unique_ptr dml_execution_context_; +std::unique_ptr dml_readback_heap_; +ComPtr dml_device_; + +struct GpuMemory final : DeviceBuffer { + GpuMemory(size_t size) : owned_{true} { + size_in_bytes_ = size; + p_device_ = static_cast(ort_allocator_->Alloc(size_in_bytes_)); + Ort::ThrowOnError(dml_api_->GetD3D12ResourceFromAllocation(ort_allocator_, p_device_, &gpu_resource_)); + } + + GpuMemory(void* p, size_t size) : owned_{false} { + size_in_bytes_ = size; + p_device_ = static_cast(p); + Ort::ThrowOnError(dml_api_->GetD3D12ResourceFromAllocation(ort_allocator_, p_device_, &gpu_resource_)); + } + + ~GpuMemory() override { + if (owned_) + ort_allocator_->Free(p_device_); + if (p_cpu_) + free(p_cpu_); + } + + const char* GetType() const override { return device_label; } + + void AllocateCpu() override { + if (!p_cpu_) + p_cpu_ = static_cast(malloc(size_in_bytes_)); + } + + void CopyDeviceToCpu() override { + AllocateCpu(); + dml_readback_heap_->ReadbackFromGpu(std::span(p_cpu_, size_in_bytes_), gpu_resource_.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + } + + void CopyCpuToDevice() override { + assert(p_cpu_); + auto source = std::span(p_cpu_, size_in_bytes_); + dml_pooled_upload_heap_->BeginUploadToGpu(gpu_resource_.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, source); + } + + void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { + if (source.GetType() == device_label) { + auto& source_gpu = dynamic_cast(source); + dml_execution_context_->CopyBufferRegion( + gpu_resource_.Get(), + begin_dest, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + source_gpu.gpu_resource_.Get(), + begin_source, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + size_in_bytes); + } else + CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes); + } + + void Zero() override { + // TODO: Implement a zeroing that runs directly on DML vs going through CPU + 
AllocateCpu(); + memset(p_cpu_, 0, size_in_bytes_); + CopyCpuToDevice(); + } + + ComPtr gpu_resource_; + bool owned_; // If we own the memory, we delete it on destruction +}; + +struct InterfaceImpl : DeviceInterface { + InterfaceImpl(LUID* p_device_luid) { + Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&dml_api_))); + if (!dml_api_) { + throw std::runtime_error("Unexpected nullptr getting OrtDmlApi"); + } + + dml_objects_ = DmlHelpers::CreateDmlObjects(CurrentModulePath(), p_device_luid); + + constexpr auto directml_dll = "DirectML.dll"; + smart_directml_dll_ = wil::unique_hmodule{LoadLibraryEx(directml_dll, nullptr, 0)}; + if (!smart_directml_dll_) + throw std::runtime_error("DirectML.dll not found"); + + auto dml_create_device1_fn = reinterpret_cast(GetProcAddress(smart_directml_dll_.get(), "DMLCreateDevice1")); + THROW_LAST_ERROR_IF(!dml_create_device1_fn); + THROW_IF_FAILED(dml_create_device1_fn(dml_objects_.d3d12_device.Get(), DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(&dml_device_))); + + Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&dml_api_))); + } + + DeviceType GetType() const override { return DeviceType::DML; } + + void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override { + Ort::api = &api; + assert(!ort_allocator_); + ort_allocator_ = &allocator; + + dml_execution_context_ = std::make_unique( + dml_objects_.d3d12_device.Get(), + dml_device_.Get(), + dml_objects_.command_queue.Get(), + *ort_allocator_, + dml_api_); + + dml_pooled_upload_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); + dml_readback_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); + } + + Ort::Allocator& GetAllocator() override { + return *ort_allocator_; + } + + std::shared_ptr AllocateBase(size_t size) override { + return std::make_shared(size); + } + + std::shared_ptr WrapMemoryBase(void* p, size_t size) override { + return std::make_shared(p, size); + } + + std::unique_ptr CreateGreedy(const GeneratorParams& params) override { + return GetCpuInterface()->CreateGreedy(params); + } + + std::unique_ptr CreateBeam(const GeneratorParams& params) override { + return GetCpuInterface()->CreateBeam(params); + } + +#if 0 + void UpdatePositionIDs() { + ComPtr target_resource; + Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); + + dml_update_position_ids_kernel_ = DmlIncrementValuesKernel( + model_.GetD3D12Device(), + model_.GetDmlExecutionContext(), + static_cast(position_ids_shape_[0]), + type_, + target_resource.Get()); + + // Execute the cached command list + ComPtr fence; + uint64_t completion_value; + model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value); + } + + void UpdateAttentionMask(int total_length) { + ComPtr attention_mask_resource; + Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); + ComPtr attention_mask_next_resource; + Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource)); + if (is_first_mask_update_) { + dml_update_mask_kernel_ = DmlUpdateMaskKernel( + 
model_.GetD3D12Device(), + model_.GetDmlExecutionContext(), + static_cast(attention_mask_shape_[0]), + static_cast(attention_mask_shape_[1]), + type_, + total_length, + attention_mask_resource.Get(), + attention_mask_next_resource.Get()); + is_second_mask_update_ = true; + } else if (is_second_mask_update_) { + dml_update_mask_kernel_ = DmlUpdateMaskKernel( + model_.GetD3D12Device(), + model_.GetDmlExecutionContext(), + static_cast(attention_mask_shape_[0]), + static_cast(attention_mask_shape_[1]), + type_, + 1, + attention_mask_resource.Get(), + attention_mask_next_resource.Get()); + is_second_mask_update_ = false; + } + ComPtr fence; + uint64_t completion_value; + model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value); + } +#endif + + void Synchronize() override { + } +}; + +} // namespace Dml + +std::unique_ptr g_dml_device; + +void InitDmlInterface(LUID* p_device_luid) { + if (!g_dml_device) + g_dml_device = std::make_unique(p_device_luid); +} + +void SetDmlProvider(OrtSessionOptions& session_options) { + Ort::ThrowOnError(Dml::dml_api_->SessionOptionsAppendExecutionProvider_DML1(&session_options, Dml::dml_device_.Get(), Dml::dml_objects_.command_queue.Get())); +} + +DeviceInterface* GetDmlInterface() { + return g_dml_device.get(); +} + +} // namespace Generators diff --git a/src/dml/interface.h b/src/dml/interface.h new file mode 100644 index 000000000..9ef2c5785 --- /dev/null +++ b/src/dml/interface.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _WIN32 +typedef struct _LUID { + uint32_t LowPart; + int32_t HighPart; +} LUID, *PLUID; +#endif + +namespace Generators { + +void InitDmlInterface(LUID* p_device_luid); +void SetDmlProvider(OrtSessionOptions& options); + +DeviceInterface* GetDmlInterface(); + +} // namespace Generators diff --git a/src/generators.cpp b/src/generators.cpp index 961795fba..fd550421b 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -9,11 +9,11 @@ #include "search.h" #include "cpu/interface.h" #include "cuda/interface.h" -#if USE_CUDA -#include "models/kernels.h" -#endif +#include "dml/interface.h" +#include "qnn/interface.h" +#include "webgpu/interface.h" -#if _WIN32 +#if defined(_WIN32) EXTERN_C IMAGE_DOS_HEADER __ImageBase; std::string CurrentModulePath() { @@ -39,11 +39,6 @@ void ThrowErrorIfSessionTerminated(bool is_session_terminated) { namespace Generators { -#if USE_CUDA -// TODO: Remove once we remove all dependencies -void OnCudaError(cudaError_t error) { assert(false); } -#endif - static bool _ = (Ort::InitApi(), false); static OrtLoggingLevel GetDefaultOrtLoggingLevel() { @@ -57,6 +52,9 @@ OrtGlobals::OrtGlobals() auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1); Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()}; env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config); + + // Init the CPU device (special case because it always exists, and its allocator is special + GetDeviceInterface(DeviceType::CPU)->InitOrt(*Ort::api, allocator_cpu); } // Ensure Shutdown() has been called before process exit @@ -96,46 +94,66 @@ OrtEnv& GetOrtEnv() { return *GetOrtGlobals()->env_; } +// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device) +void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) { + source.CopyDeviceToCpu(); + auto 
source_span = std::span(source.p_cpu_ + begin_source, size_in_bytes); + // If we're overwriting the entire destination + if (dest.size_in_bytes_ == size_in_bytes) + dest.AllocateCpu(); + else + dest.CopyDeviceToCpu(); // Overwriting part of destination, so copy over initial contents first + std::copy(source_span.begin(), source_span.end(), dest.p_cpu_ + begin_dest); + dest.CopyCpuToDevice(); +} + struct GenaiInterfaceImpl : GenaiInterface { #if _WIN32 void* HeapAllocate(size_t size) override { return std::malloc(size); } void HeapFree(void* p) override { std::free(p); } #endif + void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { + return Generators::CopyThroughCpu(dest, begin_dest, source, begin_source, size_in_bytes); + } + Generators::LogItems& GetLogItems() override { return g_log; } std::ostream& operator_leftshift(std::ostream& stream, Generators::SGR sgr_code) override { return stream << sgr_code; } std::ostream& Log(std::string_view label, std::string_view text = {}) override { return Log(label, text); } - void DumpSpan(std::ostream& stream, std::span values) override { return DumpSpan(stream, values); } - void DumpSpan(std::ostream& stream, std::span values) override { return DumpSpan(stream, values); } + void DumpSpan(std::ostream& stream, std::span values) override { return Generators::DumpSpan(stream, values); } + void DumpSpan(std::ostream& stream, std::span values) override { return Generators::DumpSpan(stream, values); } void Sequences_AfterAppendNextTokens(Sequences* p_this, DeviceSpan next_tokens, size_t batch_beam_size) override { return p_this->AfterAppendNextTokens(next_tokens, batch_beam_size); } void Sequences_RewindTo(Sequences* p_this, size_t new_length) override { return p_this->RewindTo(new_length); } } g_genai; -#if USE_CUDA -CudaInterface* GetCudaInterface() { +DeviceInterface* GetCudaInterface() { // Load the shared library onnxruntime-genai-cuda.dll // This is a workaround to avoid linking the CUDA library to the generator library // The CUDA library is only needed for the CUDA allocator -#ifdef _WIN32 +#if defined(_WIN32) static std::unique_ptr cuda_library{LoadLibrary((CurrentModulePath() + "onnxruntime-genai-cuda.dll").c_str()), [](void* h) { FreeLibrary(reinterpret_cast(h)); }}; -#else +#elif defined(__linux__) && !defined(__ANDROID__) static std::unique_ptr cuda_library{dlopen((Ort::GetCurrentModuleDir() + "/libonnxruntime-genai-cuda.so").c_str(), RTLD_NOW | RTLD_DEEPBIND), [](void* h) { dlclose(h); }}; +#else + static std::unique_ptr cuda_library{nullptr, [](void* h) {}}; #endif if (!cuda_library) { throw std::runtime_error("Cuda interface not available."); } - Generators::CudaInterface* GetInterface(GenaiInterface * p_genai); - static CudaInterface* cuda_interface{[] { -#ifdef _WIN32 + Generators::DeviceInterface* GetInterface(GenaiInterface * p_genai); + static DeviceInterface* cuda_interface{[] { +#if defined(_WIN32) auto get_cuda_fn = reinterpret_cast(GetProcAddress(reinterpret_cast(cuda_library.get()), "GetInterface")); -#else +#elif defined(__linux__) && !defined(__ANDROID__) auto get_cuda_fn = reinterpret_cast(dlsym(cuda_library.get(), "GetInterface")); +#else + auto get_cuda_fn = [](GenaiInterface*) { return nullptr; }; #endif return get_cuda_fn(&g_genai); }()}; @@ -143,30 +161,6 @@ CudaInterface* GetCudaInterface() { return cuda_interface; } -namespace cuda { -void LaunchInt32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) { 
GetCudaInterface()->Int32ToInt64(input, output, count, stream); } -void LaunchFp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) { GetCudaInterface()->Fp16ToFp32(input, output, count, stream); } -void LaunchFp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) { GetCudaInterface()->Fp32ToFp16(input, output, count, stream); } -void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { GetCudaInterface()->LaunchExpandAndInt32ToInt64(src, dst, num_beams, batch_size, sequence_length, stream); } -void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { GetCudaInterface()->LaunchExpand(src, dst, num_beams, batch_size, sequence_length, stream); } -template <> -void Launch_UpdatePositionIds(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) { GetCudaInterface()->Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); } -template <> -void Launch_UpdatePositionIds(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) { GetCudaInterface()->Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); } -template <> -void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) { GetCudaInterface()->Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); } -template <> -void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) { GetCudaInterface()->Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); } -void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) { GetCudaInterface()->LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream); } -void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) { GetCudaInterface()->UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream); } -void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) { GetCudaInterface()->ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, stream); } -template <> -void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) { GetCudaInterface()->LaunchCopyCrossQKSingleDecodeStep(stream, cross_qk_buffer_data, qk_layer_pointers, token_index, 
batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length); } -template <> -void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) { GetCudaInterface()->LaunchFinalizeCrossQK(stream, iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data); } -} // namespace cuda -#endif - std::string to_string(DeviceType device_type) { switch (device_type) { case DeviceType::CPU: @@ -179,8 +173,9 @@ std::string to_string(DeviceType device_type) { return "WebGpu"; case DeviceType::QNN: return "QnnWithSharedMemory"; + default: + throw std::runtime_error("Unknown device type"); } - throw std::runtime_error("Unknown device type"); } DeviceInterface* GetDeviceInterface(DeviceType type) { @@ -188,10 +183,16 @@ DeviceInterface* GetDeviceInterface(DeviceType type) { default: case DeviceType::CPU: return GetCpuInterface(); -#if USE_CUDA case DeviceType::CUDA: return GetCudaInterface(); +#if USE_DML + case DeviceType::DML: + return GetDmlInterface(); #endif + case DeviceType::WEBGPU: + return GetWebGPUInterface(); + case DeviceType::QNN: + return GetQNNInterface(); } } @@ -202,9 +203,7 @@ GeneratorParams::GeneratorParams(const Config& config) GeneratorParams::GeneratorParams(const Model& model) : config{*model.config_.get()}, - p_device{model.p_device_}, - device_type{model.device_type_}, - cuda_stream{model.cuda_stream_}, + p_device{model.p_device_inputs_}, is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} { use_cuda_graph = is_cuda_graph_enabled_; if (use_cuda_graph) { @@ -213,12 +212,12 @@ GeneratorParams::GeneratorParams(const Model& model) } void GeneratorParams::TryGraphCapture(int max_bs) { - if (!is_cuda_graph_enabled_ || device_type == DeviceType::CPU) { + if (!is_cuda_graph_enabled_ || p_device->GetType() == DeviceType::CPU) { // no-op return; } - if (DeviceType::CUDA == device_type || DeviceType::DML == device_type) { + if (DeviceType::CUDA == p_device->GetType() || DeviceType::DML == p_device->GetType()) { if (max_bs == 0) { throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set."); } @@ -323,8 +322,8 @@ void Generator::AppendTokens(cpu_span input_ids) { constexpr std::array devices_supporting_continuous_decoding{DeviceType::CPU, DeviceType::CUDA, DeviceType::WEBGPU}; if (search_->GetSequenceLength() != 0 && std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(), - [this](DeviceType device_type) { return device_type == state_->params_->device_type; })) - throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->device_type) + + [this](DeviceType device_type) { return device_type == state_->params_->p_device->GetType(); })) + throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->p_device->GetType()) + "). 
Please recreate the generator instance to avoid using continuous decoding."); if (last_action_ == Action::generated) { @@ -485,10 +484,3 @@ DeviceSpan Generator::GetSequence(size_t index) const { } } // namespace Generators - -#if USE_CUDA -cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemcpyAsync(dst, src, count, kind, stream); } -cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) { return Generators::GetCudaInterface()->cudaMemcpy(dst, src, count, kind); } -cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemsetAsync(ptr, value, count, stream); } -cudaError_t cudaMemset(void* ptr, int value, size_t count) { return Generators::GetCudaInterface()->cudaMemset(ptr, value, count); } -#endif diff --git a/src/generators.h b/src/generators.h index 87468bb03..50962b744 100644 --- a/src/generators.h +++ b/src/generators.h @@ -23,16 +23,10 @@ #include #include #include -#if USE_CUDA -#include -#else -// If we don't include cuda_runtime.h, we define this to avoid lots of extra #ifdefs -using cudaStream_t = void*; -#endif #include "leakcheck.h" -#include "smartptrs.h" #include "models/onnxruntime_api.h" +#include "smartptrs.h" #include "models/debugging.h" #include "config.h" #include "logging.h" @@ -49,20 +43,27 @@ struct Tokenizer; template DeviceSpan WrapTensor(DeviceInterface& device, OrtValue& value) { - return device.WrapMemory(std::span{value.GetTensorMutableData(), value.GetTensorTypeAndShapeInfo()->GetElementCount()}); + auto info = value.GetTensorTypeAndShapeInfo(); + assert(info->GetElementType() == Ort::TypeToTensorType>); + return device.WrapMemory(std::span{value.GetTensorMutableData(), info->GetElementCount()}); } -// OgaSequences are a vector of int32 vectors -using TokenSequences = std::vector>; +DeviceSpan ByteWrapTensor(DeviceInterface& device, OrtValue& value); + +template +struct OrtTensor { + OrtTensor(std::unique_ptr ort_value, DeviceInterface& device) + : ort_value_{std::move(ort_value)}, device_span_{WrapTensor(device, *ort_value_)} {} -enum struct DeviceType { - CPU, - CUDA, - DML, - WEBGPU, - QNN, + operator OrtValue*() { return ort_value_.get(); } + + std::unique_ptr ort_value_; + DeviceSpan device_span_; }; +// OgaSequences are a vector of int32 vectors +using TokenSequences = std::vector>; + std::string to_string(DeviceType device_type); DeviceInterface* GetDeviceInterface(DeviceType type); @@ -77,9 +78,7 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec bool use_cuda_graph{}; int BatchBeamSize() const { return search.num_beams * search.batch_size; } - DeviceInterface* p_device{}; - DeviceType device_type{DeviceType::CPU}; - cudaStream_t cuda_stream{}; + DeviceInterface* p_device{}; // Scoring device (usually CPU, but can be CUDA) cpu_span aux_input_ids{}; // Intermediate solution to be used with SetInputs function for multimodal and whisper models @@ -141,10 +140,8 @@ struct OrtGlobals { OrtGlobals(); std::unique_ptr env_; -#if USE_CUDA - std::unique_ptr memory_info_cuda_; - std::unique_ptr allocator_cuda_; -#endif + std::unique_ptr allocator_device_[static_cast(DeviceType::MAX)]; + private: OrtGlobals(const OrtGlobals&) = delete; void operator=(const OrtGlobals&) = delete; @@ -160,6 +157,9 @@ std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(const Config& config); // For benchmarking 
purposes only std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); +// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device) +void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes); + float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction } // namespace Generators diff --git a/src/models/adapters.cpp b/src/models/adapters.cpp index 3840c9b88..13791386c 100644 --- a/src/models/adapters.cpp +++ b/src/models/adapters.cpp @@ -34,9 +34,9 @@ void Adapters::LoadAdapter(const char* adapter_file_path, const std::string& ada } adapters_.emplace(adapter_name, std::make_unique(adapter_file_path, - model_->allocator_device_ == &model_->allocator_cpu_ - ? nullptr - : model_->allocator_device_)); + model_->p_device_->GetType() == DeviceType::CUDA + ? &model_->p_device_->GetAllocator() + : nullptr)); } void Adapters::UnloadAdapter(const std::string& adapter_name) { diff --git a/src/models/audio_processor.cpp b/src/models/audio_processor.cpp index 1c6c88d14..70c8493fb 100644 --- a/src/models/audio_processor.cpp +++ b/src/models/audio_processor.cpp @@ -30,7 +30,7 @@ std::unique_ptr ProcessMel(ort_extensions::OrtxObjectPtr& allocator.GetInfo(), std::span(const_cast(mel_data), input_features_value->GetTensorTypeAndShapeInfo()->GetElementCount()), shape_span); - ConvertFp32ToFp16(allocator, *input_features_fp32, input_features_value, DeviceType::CPU, nullptr); + Cast(*input_features_fp32, input_features_value, *GetDeviceInterface(DeviceType::CPU), Ort::TypeToTensorType); } return input_features_value; diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp index 84c8bec11..a5cea3701 100644 --- a/src/models/captured_graph_pool.cpp +++ b/src/models/captured_graph_pool.cpp @@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf } CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const { - if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) { + if (!params.use_cuda_graph || (model.p_device_->GetType() != DeviceType::CUDA)) { return nullptr; } @@ -48,12 +48,6 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, size_t max_beam_batch_size = static_cast(params.search.num_beams) * params.max_batch_size; new_captured_graph->sb_input_ids_ = std::make_unique(allocator_device_, max_beam_batch_size); -#if USE_DML - if (model.device_type_ == DeviceType::DML) { - new_captured_graph->sb_input_ids_int32_ = std::make_unique(allocator_device_, max_beam_batch_size); - } -#endif - // Create the static buffers for the cache int layer_count = config_->model.decoder.num_hidden_layers; new_captured_graph->sb_kv_caches_.reserve(layer_count * 2); @@ -70,13 +64,6 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, // Create the static buffer for the attention mask, if needed if (session_info_->HasInput(config_->model.decoder.inputs.attention_mask)) { new_captured_graph->sb_attention_mask_ = std::make_unique(allocator_device_, max_beam_batch_size); - -#if USE_DML - // DML currently needs an additional static buffer for the mask - if (model.device_type_ == DeviceType::DML) { - new_captured_graph->sb_attention_mask_next_ = std::make_unique(allocator_device_, 
max_beam_batch_size); - } -#endif } auto output_type = session_info_->GetOutputDataType(config_->model.decoder.outputs.logits); diff --git a/src/models/captured_graph_pool.h b/src/models/captured_graph_pool.h index 42e3be51d..4d405018e 100644 --- a/src/models/captured_graph_pool.h +++ b/src/models/captured_graph_pool.h @@ -145,11 +145,6 @@ struct CapturedGraphInfo { std::unique_ptr sb_embeddings_; std::unique_ptr key_; -#if USE_DML - std::unique_ptr sb_attention_mask_next_; - std::unique_ptr sb_input_ids_int32_; -#endif - // Generates a unique annotation ID across different captured graph objects. This is necessary because different // generators could be alive at the same time and run the same batch size but with different static buffers, so // they need to have different annotation IDs. diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp index f986b8688..0c6a60329 100644 --- a/src/models/debugging.cpp +++ b/src/models/debugging.cpp @@ -3,15 +3,7 @@ #include "../generators.h" #include "utils.h" #include - -#if USE_CUDA -#include "../cuda/cuda_common.h" -#endif - -#if USE_DML -#include "../dml/dml_helpers.h" #include "model.h" -#endif namespace Generators { static constexpr size_t c_value_count = 10; // Dump this many values from the start of a tensor @@ -88,48 +80,22 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool stream << SGR::Fg_Green << " Location: " << SGR::Reset; const auto& memory_info = value->GetTensorMemoryInfo(); - auto device_type = memory_info.GetDeviceType(); - if (device_type == OrtMemoryInfoDeviceType_CPU) { - stream << "CPU\r\n"; - DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count); - } else if (device_type == OrtMemoryInfoDeviceType_GPU) { - stream << "GPU\r\n"; -#if USE_CUDA - auto type = type_info->GetElementType(); - size_t element_size = SizeOf(type); - auto cpu_copy = std::make_unique(element_size * element_count); - CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost); - DumpValues(stream, type, cpu_copy.get(), element_count); -#else - throw std::runtime_error("Unexpected error. 
Trying to access GPU memory but the project is not compiled with CUDA."); -#endif - } else if (static_cast(device_type) == 4) { - stream << "DML\r\n"; -#if USE_DML - auto type = type_info->GetElementType(); - size_t element_size = SizeOf(type); - auto cpu_copy = std::make_unique(element_size * element_count); - - if (value->GetTensorMutableRawData()) { - ComPtr gpu_resource; - Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( - model.allocator_device_, - value->GetTensorMutableRawData(), - &gpu_resource)); - - model.GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(cpu_copy.get(), element_size * element_count), - gpu_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + switch (memory_info.GetDeviceType()) { + case OrtMemoryInfoDeviceType_CPU: + stream << "CPU\r\n"; + DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count); + break; + case OrtMemoryInfoDeviceType_GPU: { + stream << "GPU\r\n"; + auto type = type_info->GetElementType(); + auto tensor_span = std::span{const_cast(value)->GetTensorMutableData(), SizeOf(type) * element_count}; + auto device_span = model.p_device_->WrapMemory(tensor_span); + DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count); + break; } - - DumpValues(stream, type, cpu_copy.get(), element_count); -#else - throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML."); -#endif - } else { - stream << "Unhandled device type: " << static_cast(device_type) << "\r\n"; + default: + stream << "Unhandled device type: " << static_cast(memory_info.GetDeviceType()) << "\r\n"; + break; } } diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index b18fea179..8ae9246f9 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -12,7 +12,7 @@ DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr confi sessions_.emplace_back(OrtSession::Create(ort_env, (config_->config_path / fs::path(model.filename)).c_str(), GetSessionOptions(model.model_id))); - if (!allocator_device_ && model.session_options.has_value()) { + if (!p_device_inputs_ && model.session_options.has_value()) { const auto& provider_options = (*model.session_options).provider_options; if (std::any_of(provider_options.begin(), provider_options.end(), [](const auto& elem) { return !elem.name.empty(); })) { @@ -21,7 +21,7 @@ DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr confi } } - if (!allocator_device_) { + if (!p_device_inputs_) { // If the device allocator has not been created, it implies all // sessions are configured to run on CPU. // Pick any session to create the device allocator. @@ -58,9 +58,9 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const { } bool IntermediatePipelineState::SupportsPrimaryDevice() const { - if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN) { + if (model_.p_device_->GetType() == DeviceType::CPU || model_.p_device_->GetType() == DeviceType::QNN) { return true; - } else if (model_.device_type_ == DeviceType::CUDA) { + } else if (model_.p_device_->GetType() == DeviceType::CUDA) { if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) { // No session options, so this session uses the default session options. // Default session options supports the cuda device type. 
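Note on the debugging.cpp change above: the CUDA memcpy and DML readback-heap paths are replaced by a single device-agnostic round trip through DeviceSpan. A minimal sketch of that pattern, using the same helpers the hunk relies on (WrapMemory, CopyDeviceToCpu, SizeOf, DumpValues); illustrative only, not part of the patch:

    // Device-agnostic readback of an OrtValue's payload for dumping.
    void DumpAnyDeviceTensor(DeviceInterface& device, OrtValue& value, std::ostream& stream) {
      auto info = value.GetTensorTypeAndShapeInfo();
      const auto type = info->GetElementType();
      const size_t element_count = info->GetElementCount();
      // Wrap the raw device allocation as bytes; no copy happens here.
      auto device_span = device.WrapMemory<uint8_t>(
          std::span<uint8_t>{value.GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count});
      // CopyDeviceToCpu() is a no-op on the CPU device and a download everywhere else.
      DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
    }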
@@ -134,7 +134,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan if (!pipeline_state->SupportsPrimaryDevice()) { std::ostringstream oss; oss << "Managed input " << input_name << " resides on the primary device type (" - << to_string(model_.device_type_) << "). " + << to_string(model_.p_device_->GetType()) << "). " << "But the pipeline model " << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id << " is expecting it to reside elsewhere."; @@ -159,7 +159,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan if (!pipeline_state->SupportsPrimaryDevice()) { std::ostringstream oss; oss << "Managed output " << output_name << " resides on the primary device type (" - << to_string(model_.device_type_) << "). " + << to_string(model_.p_device_->GetType()) << "). " << "But the pipeline model " << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id << " is expecting it to reside elsewhere."; @@ -178,7 +178,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan if (!pipeline_state->SupportsPrimaryDevice()) { std::ostringstream oss; oss << "Managed input " << input_name << " resides on the primary device type (" - << to_string(model_.device_type_) << "). " + << to_string(model_.p_device_->GetType()) << "). " << "But the pipeline model " << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id << " is expecting it to reside elsewhere."; diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index 01b494078..4b38db207 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -25,7 +25,7 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get(); } - embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + embeddings_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_); } } @@ -54,7 +54,7 @@ void Embeddings::UpdateSequenceLength(size_t new_length) { if (mode_ == Embeddings::Mode::Input) { if (!sb_embeddings_) { - embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + embeddings_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_); } else { embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_); } diff --git a/src/models/extra_inputs.cpp b/src/models/extra_inputs.cpp index 6fabefcce..7c4f78b3f 100644 --- a/src/models/extra_inputs.cpp +++ b/src/models/extra_inputs.cpp @@ -1,7 +1,6 @@ #include "../generators.h" #include "model.h" #include "extra_inputs.h" -#include "kernels.h" namespace Generators { @@ -68,11 +67,6 @@ ExtraInputs::ExtraInputs(State& state) } } -#pragma warning(push) -#pragma warning(disable : 4065) // switch statement contains 'default' but no 'case' labels -#pragma warning(disable : 4189) // local variable is initialized but not referenced -#pragma warning(disable : 4702) // unreachable code - void ExtraInputs::Add() { // Add extra user inputs for (int i = 0; i < state_.params_->extra_inputs.size(); ++i) { @@ -82,46 +76,13 @@ void ExtraInputs::Add() { // Copy the data from the CPU-backed ORT value to the static buffers for (int i = 0; i < sb_extra_inputs_.size(); ++i) { - auto type_and_shape_info = extra_inputs_[i]->GetTensorTypeAndShapeInfo(); - auto shape = type_and_shape_info->GetShape(); - auto element_count = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); - auto copy_size_in_bytes = element_count * 
SizeOf(type_and_shape_info->GetElementType()); - - switch (model_.device_type_) { -#if USE_DML - case DeviceType::DML: { - ComPtr target_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, extra_inputs_[i]->GetTensorMutableRawData(), &target_resource)); - - auto source = std::span(state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorData(), copy_size_in_bytes); - - model_.GetDmlUploadHeap()->BeginUploadToGpu( - target_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - source); - } break; -#endif - -#if USE_CUDA - case DeviceType::CUDA: { - cudaMemcpyAsync( - extra_inputs_[i]->GetTensorMutableRawData(), - state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorMutableRawData(), - copy_size_in_bytes, - cudaMemcpyHostToDevice, - model_.cuda_stream_); - } break; -#endif - - default: - throw std::runtime_error("Unsupported device for graph capture"); - } + auto tensor = ByteWrapTensor(*model_.p_device_, *extra_inputs_[i]); + auto source = std::span{state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorData(), tensor.size()}; + copy(source, tensor.CpuSpan()); + tensor.CopyCpuToDevice(); } registrar_.Add(); } -#pragma warning(pop) - } // namespace Generators diff --git a/src/models/image_features.cpp b/src/models/image_features.cpp index 71dee99fa..6f51abd0f 100644 --- a/src/models/image_features.cpp +++ b/src/models/image_features.cpp @@ -26,7 +26,7 @@ ImageFeatures::ImageFeatures(State& state, ImageFeatures::Mode mode, const std:: // 4) Created as an input for embedding model (num_image_tokens = 0) // The tensor does not need to be pre-allocated because it will be created during (2). if (mode == ImageFeatures::Mode::Output) { - image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + image_features_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_); } } @@ -50,7 +50,7 @@ void ImageFeatures::Update(bool is_prompt) { // num_image_tokens will be 0 when no image is provided if (!is_prompt && shape_[0] > 0) { // if num_image_tokens > 0 shape_[0] = 0; - image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + image_features_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_); state_.inputs_[index_] = image_features_.get(); } } diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index a309b4f85..6a32a6233 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -1,14 +1,13 @@ #include "../generators.h" #include "model.h" #include "input_ids.h" -#include "kernels.h" namespace Generators { DefaultInputIDs::DefaultInputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); - shape_ = {state_.params_->search.batch_size, 0}; + shape_ = {state_.params_->BatchBeamSize(), 0}; type_ = model_.session_info_->GetInputDataType(name_); if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) && @@ -45,182 +44,55 @@ void DefaultInputIDs::Add() { } } -void DefaultInputIDs::Update(DeviceSpan& new_tokens) { - // There are three scopes involved when the Update function is called: - // 1. A new Generator state has been just created. This is a prompt stage, and value_ is a nullptr. - // i.e. this is the very first time ever that Update is being called for this Generator. - // 2. We move to the token generation stage. value_ has already been previously created in the prompt stage. 
- // Update is called on every new token generated. - // 3. We move from the token generation stage back to the prompt stage (e.g. in continous decoding). value_ is already created. - - // For instances where the value_ is not created, we need handle graph capture correctly. - // For subsequent prompt stages, the limiting factor is that the subsequent prompts can not - // be larger than the first prompt (when graph capture is enabled). - if (!value_) { - shape_[1] = static_cast(new_tokens.size()) / shape_[0]; - - // If 64-bit, convert from 32-bit to 64-bit - auto input_ids = new_tokens.CopyDeviceToCpu(); - if (type_ == Ort::TypeToTensorType) { - value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); - auto* p_data = value_->GetTensorMutableData(); - for (auto v : input_ids) { - *p_data++ = v; - } - } else { - if (type_ != Ort::TypeToTensorType) - throw std::runtime_error("InputIDs must be int64 or int32"); - value_ = OrtValue::CreateTensor(model_.allocator_cpu_.GetInfo(), input_ids, shape_); - } - - value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); - shape_[0] *= state_.params_->search.num_beams; - - if (state_.GetCapturedGraphInfo()) { - sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); - -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get(); - } -#endif - } +void DefaultInputIDs::Update(DeviceSpan new_tokens) { + auto new_tokens_cpu = new_tokens.CopyDeviceToCpu(); - is_prompt_ = false; - state_.inputs_[input_index_] = value_.get(); - return; - } - - const auto get_unpadded_sequence_length = [](std::span input_ids, - int32_t pad_token_id) { - int32_t seq_length = 0; + const auto get_unpadded_sequence_length = [](std::span input_ids, int32_t pad_token_id) { for (int32_t i = 0; i < input_ids.size(); i++) { - if (input_ids[i] == pad_token_id) { - break; - } - seq_length++; + if (input_ids[i] == pad_token_id) + return i; } - return seq_length; + return static_cast(input_ids.size()); }; if (current_sequence_length_ && past_sequence_length_) { if (state_.params_->BatchBeamSize() != 1) { throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs"); } - auto new_sequence_length = get_unpadded_sequence_length(new_tokens.CpuSpan(), model_.config_->model.pad_token_id); + auto new_sequence_length = get_unpadded_sequence_length(new_tokens_cpu, model_.config_->model.pad_token_id); *current_sequence_length_->GetTensorMutableData() += new_sequence_length; *past_sequence_length_->GetTensorMutableData() += new_sequence_length; } - // Resize input_ids shape based on new_tokens - // For beam search + // For beam search, resize input_ids shape based on new_tokens size_t sequence_length = static_cast(new_tokens.size()) / state_.params_->BatchBeamSize(); if (is_prompt_ && state_.params_->search.num_beams > 1) sequence_length = static_cast(new_tokens.size()) / state_.params_->search.batch_size; if (static_cast(shape_[1]) != sequence_length) { shape_[1] = sequence_length; - if (!sb_input_ids_) { - value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - } -#endif - } else { - value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_); - -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - value_int32_ = 
sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType); - } -#endif - } - + value_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_); state_.inputs_[input_index_] = value_.get(); } - // Update input_ids with next tokens, converting from 32-bit to 64-bit - if (type_ == Ort::TypeToTensorType) { - switch (model_.device_type_) { - case DeviceType::CUDA: { -#if USE_CUDA - auto* data = value_->GetTensorMutableData(); - auto next_tokens = new_tokens.Span(); - // For beam search - if (is_prompt_ && state_.params_->search.num_beams > 1) - cuda::LaunchExpandAndInt32ToInt64(next_tokens.data(), data, state_.params_->search.num_beams, state_.params_->search.batch_size, static_cast(sequence_length), model_.cuda_stream_); - else - cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast(next_tokens.size()), model_.cuda_stream_); -#endif - } break; - - case DeviceType::DML: { -#if USE_DML - ComPtr source_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource)); - - auto source = std::span( - reinterpret_cast(new_tokens.CpuSpan().data()), - new_tokens.CpuSpan().size_bytes()); - - model_.GetDmlUploadHeap()->BeginUploadToGpu( - source_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - source); - - DmlHelpers::DmlCastInputToOutput( - model_.GetDmlExecutionContext(), - *model_.allocator_device_, - *value_int32_, - value_, - model_.GetDmlDevice(), - model_.GetOrtDmlApi(), - input_ids_cast_command_list_state_); -#endif - } break; - default: { - // CPU, WEBGPU - auto* data = value_->GetTensorMutableData(); - auto next_tokens = new_tokens.Span(); - for (int b = 0; b < shape_[0]; b++) { - for (int i = 0; i < shape_[1]; i++) { - // For beam search - int32_t next_token; - if (is_prompt_ && state_.params_->search.num_beams > 1) - next_token = next_tokens[(b / state_.params_->search.num_beams) * shape_[1] + i]; - else - next_token = next_tokens[b * shape_[1] + i]; - data[b * shape_[1] + i] = next_token; - } - } - } + // Update input_ids with next tokens + auto data_span = WrapTensor(*model_.p_device_inputs_, *value_); + + // For beam search + if (is_prompt_ && state_.params_->search.num_beams > 1) { + int row_size = static_cast(shape_[1]); + for (int b = 0; b < shape_[0]; b++) { + int in_offset = (b / state_.params_->search.num_beams) * row_size; + int out_offset = b * row_size; + data_span.subspan(out_offset, row_size).CopyFrom(new_tokens.subspan(in_offset, row_size)); } } else { - auto* data = value_->GetTensorMutableData(); -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - if (is_prompt_ && state_.params_->search.num_beams > 1) { - cuda::LaunchExpand(new_tokens.Span().data(), data, state_.params_->search.num_beams, state_.params_->search.batch_size, static_cast(sequence_length), model_.cuda_stream_); - } else { - cudaMemcpyAsync(data, new_tokens.Span().data(), shape_[0] * shape_[1] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - } - } else -#endif - { - // For beam search - if (is_prompt_ && state_.params_->search.num_beams > 1) { - for (int b = 0; b < shape_[0]; b++) { - int in_offset = (b / state_.params_->search.num_beams) * static_cast(shape_[1]); - int out_offset = b * static_cast(shape_[1]); - memcpy(data + out_offset, new_tokens.Span().data() + in_offset, shape_[1] * sizeof(int32_t)); - } - } else { - memcpy(data, new_tokens.Span().data(), shape_[0] * shape_[1] * sizeof(int32_t)); - } - } + 
data_span.CopyFrom(new_tokens); + } + + if (type_ == Ort::TypeToTensorType) { + Cast(*value_, cast_value_, *model_.p_device_inputs_, type_); + state_.inputs_[input_index_] = cast_value_.get(); } is_prompt_ = false; @@ -253,7 +125,7 @@ void WindowedInputIDs::Add() { state_.input_names_.push_back(name_); } -void WindowedInputIDs::Update(DeviceSpan& new_tokens) { +void WindowedInputIDs::Update(DeviceSpan new_tokens) { if (window_index_ == 0) { num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_; diff --git a/src/models/input_ids.h b/src/models/input_ids.h index d7a229911..fd4b159e6 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -8,7 +8,7 @@ struct InputIDs { virtual ~InputIDs() = default; virtual void Add() = 0; virtual std::array GetShape() const = 0; - virtual void Update(DeviceSpan& next_tokens) = 0; + virtual void Update(DeviceSpan next_tokens) = 0; }; struct DefaultInputIDs : InputIDs { @@ -21,7 +21,7 @@ struct DefaultInputIDs : InputIDs { void Add() override; // Resize input_ids based on size of next_tokens. // Update value with next_tokens. - void Update(DeviceSpan& next_tokens) override; + void Update(DeviceSpan next_tokens) override; std::array GetShape() const override { return shape_; } const char* name_; @@ -38,15 +38,7 @@ struct DefaultInputIDs : InputIDs { std::array shape_{}; ONNXTensorElementDataType type_; std::unique_ptr value_; - - // Used for decoding runs with cuda graphs. - StaticBuffer* sb_input_ids_{}; - -#if USE_DML - std::unique_ptr value_int32_; - StaticBuffer* sb_input_ids_int32_{}; - DmlReusedCommandListState input_ids_cast_command_list_state_{}; -#endif + std::unique_ptr cast_value_; std::unique_ptr current_sequence_length_; std::unique_ptr past_sequence_length_; @@ -65,7 +57,7 @@ struct WindowedInputIDs : public InputIDs { WindowedInputIDs& operator=(const WindowedInputIDs&) = delete; void Add() override; - void Update(DeviceSpan& next_tokens) override; + void Update(DeviceSpan next_tokens) override; std::array GetShape() const override { return shape_; } private: diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index b8dbc3b6b..994e883ce 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -46,11 +46,11 @@ CombinedKeyValueCache::CombinedKeyValueCache(State& state) // Derive the KV data type from the KV input 0 type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); - empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + empty_past_ = OrtValue::CreateTensor(Allocator(), shape_, type_); shape_[3] = 0; for (int i = 0; i < layer_count_; ++i) { - presents_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)); + presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_)); } } @@ -82,7 +82,7 @@ void CombinedKeyValueCache::Update(DeviceSpan beam_indices, int total_l shape_[3] = total_length; for (int i = 0; i < layer_count_; i++) { - presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_); state_.outputs_[output_index_ + i] = presents_[i].get(); } @@ -119,22 +119,14 @@ void CombinedKeyValueCache::RewindPastTensorsTo(size_t index) { for (int i = 0; i < layer_count_; i++) { OrtValue& present = *presents_[i]; - std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + std::unique_ptr past = OrtValue::CreateTensor(Allocator(), shape_, type_); + auto present_span = WrapTensor(Device(), present); + 
auto past_span = WrapTensor(Device(), *past); + for (int j = 0; j < 2 * batch_x_num_heads; j++) { - auto present_data = present.GetTensorData() + j * old_length_x_head_size; - auto past_data = past->GetTensorMutableData() + j * new_length_x_head_size; -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - } else -#elif USE_DML - if (model_.device_type_ == DeviceType::DML) { - // TODO: Implement DML version - } else -#endif - { - copy(std::span(present_data, new_length_x_head_size), std::span(past_data, new_length_x_head_size)); - } + auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size); + auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size); + past_data.CopyFrom(present_data); } pasts_[i] = std::move(past); state_.inputs_[input_index_ + i] = pasts_[i].get(); @@ -147,38 +139,22 @@ void CombinedKeyValueCache::PickPastState(DeviceSpan beam_indices_devic std::span beam_indices = beam_indices_device.CopyDeviceToCpu(); auto block_size_per_beam = shape_[2] * shape_[3] * shape_[4]; auto past_key_size = shape_[1] * block_size_per_beam; - auto element_count = shape_[0] * past_key_size; - - const OrtValue& present = *presents_[index]; - std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_); - auto past_span = std::span(past->GetTensorMutableData(), element_count); - auto present_span = std::span(present.GetTensorData(), element_count); - -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - for (size_t j = 0; j < beam_indices.size(); j++) { - int32_t beam_index = beam_indices[j]; - auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); - - auto past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); - auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); - cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - } - } else -#endif - { - for (size_t j = 0; j < beam_indices.size(); j++) { - int32_t const beam_index = beam_indices[j]; - auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); - - auto past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); - auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); - copy(present_key, past_key); - copy(present_value, past_value); - } + + OrtValue& present = *presents_[index]; + std::unique_ptr past = OrtValue::CreateTensor(Allocator(), shape_); + + auto past_span = WrapTensor(Device(), *past); + auto present_span = WrapTensor(Device(), present); + + for (size_t j = 0; j < beam_indices.size(); j++) { + int32_t beam_index = beam_indices[j]; + auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); + + auto past_key = 
past_span.subspan(j * block_size_per_beam, block_size_per_beam); + auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); + past_key.CopyFrom(present_key); + past_value.CopyFrom(present_value); } pasts_[index] = std::move(past); @@ -214,7 +190,7 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) // Derive the KV data type from the KV input 0 type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); - empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + empty_past_ = OrtValue::CreateTensor(Allocator(), shape_, type_); // Set the size after empty_past_ has been created with 0 for this field if (past_present_share_buffer_) { @@ -228,23 +204,13 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) } } - auto kv_cache_size_bytes = SizeOf(type_) * shape_[0] * shape_[1] * shape_[2] * shape_[3]; try { for (int i = 0; i < layer_count_ * 2; ++i) { presents_.push_back( - sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_) + sb_kv_caches_.empty() ? OrtValue::CreateTensor(Allocator(), shape_, type_) : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - cudaMemsetAsync(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes, model_.cuda_stream_); - } else -#endif - { - if (model_.device_type_ == DeviceType::CPU) { - // FIXME: this is a device ternsor and we can only use memset for cpu. Revisit for other EPs. - memset(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes); - } - } + // Zero the memory so we don't leak any data from the previous run + ByteWrapTensor(Device(), *presents_.back()).Zero(); } } catch (const Ort::Exception&) { std::ostringstream oss; @@ -303,7 +269,7 @@ void DefaultKeyValueCache::Update(DeviceSpan beam_indices, int total_le shape_[2] = total_length; for (int i = 0; i < layer_count_ * 2; i++) { - presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_); state_.outputs_[output_index_ + i] = presents_[i].get(); } @@ -342,22 +308,15 @@ void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) { for (int i = 0; i < layer_count_ * 2; i++) { OrtValue& present = *presents_[i]; - std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_); + std::unique_ptr past = OrtValue::CreateTensor(Allocator(), shape_, type_); + + auto past_span = WrapTensor(Device(), *past); + auto present_span = WrapTensor(Device(), present); + for (int j = 0; j < batch_x_num_heads; j++) { - auto present_data = present.GetTensorData() + j * old_length_x_head_size; - auto past_data = past->GetTensorMutableData() + j * new_length_x_head_size; -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - } else -#elif USE_DML - if (model_.device_type_ == DeviceType::DML) { - // TODO: Implement DML copy - } else -#endif - { - copy(std::span(present_data, new_length_x_head_size), std::span(past_data, new_length_x_head_size)); - } + auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size); + auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size); + past_data.CopyFrom(present_data); } pasts_[i] = std::move(past); state_.inputs_[input_index_ + i] = pasts_[i].get(); 
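The kv_cache.cpp hunks above replace the per-backend cudaMemcpyAsync/memcpy branches with plain DeviceSpan arithmetic: wrap both tensors, take subspans, and let CopyFrom issue the right copy for the device. A condensed sketch of that idiom, assuming the WrapTensor/DeviceSpan semantics used throughout this diff (illustrative, not part of the patch):

    // Copy the first `new_len` elements of each of `rows` rows from `present`
    // into a freshly allocated `past`, mirroring what RewindPastTensorsTo now does.
    template <typename T>
    void ShrinkRows(DeviceInterface& device, Ort::Allocator& allocator,
                    OrtValue& present, std::unique_ptr<OrtValue>& past,
                    std::span<const int64_t> new_shape, int rows,
                    size_t old_len, size_t new_len) {
      past = OrtValue::CreateTensor<T>(allocator, new_shape);
      auto present_span = WrapTensor<T>(device, present);
      auto past_span = WrapTensor<T>(device, *past);
      for (int j = 0; j < rows; j++) {
        // subspan() is cheap index math; CopyFrom() performs the actual device copy.
        past_span.subspan(j * new_len, new_len)
            .CopyFrom(present_span.subspan(j * old_len, new_len));
      }
    }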
@@ -369,30 +328,18 @@ template void DefaultKeyValueCache::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.CopyDeviceToCpu(); auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3]; - auto element_count = shape_[0] * block_size_per_beam; - - const OrtValue& present_value = *presents_[index]; - std::unique_ptr past_value = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_); - auto past_span = std::span(past_value->GetTensorMutableData(), element_count); - auto present_span = std::span(present_value.GetTensorData(), element_count); - -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { - for (size_t j = 0; j < beam_indices.size(); j++) { - int32_t beam_index = beam_indices[j]; - auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam); - cudaMemcpyAsync(past.data(), present.data(), present.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_); - } - } else -#endif - { - for (size_t j = 0; j < beam_indices.size(); j++) { - int32_t const beam_index = beam_indices[j]; - auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam); - copy(present, past); - } + + OrtValue& present_value = *presents_[index]; + std::unique_ptr past_value = OrtValue::CreateTensor(Allocator(), shape_); + + auto past_span = WrapTensor(Device(), *past_value); + auto present_span = WrapTensor(Device(), present_value); + + for (size_t j = 0; j < beam_indices.size(); j++) { + int32_t beam_index = beam_indices[j]; + auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam); + past.CopyFrom(present); } pasts_[index] = std::move(past_value); @@ -424,8 +371,8 @@ CrossCache::CrossCache(State& state) type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); for (int i = 0; i < layer_count_; ++i) { - values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)); - values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)); + values_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_)); + values_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_)); } } @@ -475,21 +422,21 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state) for (int i = 0; i < layer_count_; ++i) { key_caches_in_.push_back( - OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_)); + OrtValue::CreateTensor(Allocator(), key_cache_shape_in_, type_)); std::fill_n(key_caches_in_[i]->GetTensorMutableData(), ElementCountFromShape(key_cache_shape_in_), static_cast(model_.config_->model.decoder.sliding_window->pad_value)); value_caches_in_.push_back( - OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_)); + OrtValue::CreateTensor(Allocator(), value_cache_shape_in_, type_)); std::fill_n(value_caches_in_[i]->GetTensorMutableData(), ElementCountFromShape(value_cache_shape_in_), static_cast(model_.config_->model.decoder.sliding_window->pad_value)); key_caches_out_.push_back( - OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_)); + OrtValue::CreateTensor(Allocator(), key_cache_shape_out_, type_)); value_caches_out_.push_back( - OrtValue::CreateTensor(*model_.allocator_device_, 
value_cache_shape_out_, type_)); + OrtValue::CreateTensor(Allocator(), value_cache_shape_out_, type_)); } } @@ -601,7 +548,7 @@ void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current ThreadPool thread_pool{static_cast(layer_count_)}; thread_pool.Compute([&](size_t layer_idx) { - std::unique_ptr key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_); + std::unique_ptr key_cache = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_in, type_); uint8_t* key_cache_data = key_cache->GetTensorMutableData(); uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); @@ -627,9 +574,9 @@ void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current } key_caches_in_[layer_idx] = std::move(key_cache); - key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_); + key_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_out, type_); - std::unique_ptr value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_); + std::unique_ptr value_cache = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_in, type_); uint8_t* value_cache_data = value_cache->GetTensorMutableData(); uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData(); @@ -655,7 +602,7 @@ void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current } value_caches_in_[layer_idx] = std::move(value_cache); - value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_); + value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_); }); window_size_ = 1; diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 0e871d938..9b8fe8e83 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -30,6 +30,9 @@ struct CombinedKeyValueCache : KeyValueCache { template void RewindPastTensorsTo(size_t index); + DeviceInterface& Device() { return *model_.p_device_kvcache_; } + Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); } + State& state_; const Model& model_{state_.model_}; int layer_count_; @@ -64,6 +67,9 @@ struct DefaultKeyValueCache : KeyValueCache { template void RewindPastTensorsTo(size_t index); + DeviceInterface& Device() { return *model_.p_device_kvcache_; } + Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); } + State& state_; const Model& model_{state_.model_}; int layer_count_; @@ -89,6 +95,9 @@ struct CrossCache { void AddInputs(); private: + DeviceInterface& Device() { return *model_.p_device_kvcache_; } + Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); } + State& state_; const Model& model_{state_.model_}; int layer_count_; @@ -113,6 +122,9 @@ struct WindowedKeyValueCache : KeyValueCache { } private: + DeviceInterface& Device() { return *model_.p_device_kvcache_; } + Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); } + void Slide(); State& state_; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index edaf95d1a..369bd8e39 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -3,10 +3,6 @@ #include "../generators.h" #include "model.h" #include "logits.h" -#if USE_CUDA -#include "../cuda/cuda_common.h" -#include "kernels.h" -#endif namespace Generators { @@ -14,23 +10,18 @@ Logits::Logits(State& state) : state_{state}, 
shape_{static_cast(state_.params_->BatchBeamSize()), 0, model_.config_->model.vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { - output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_); -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) { + if (model_.p_device_inputs_->GetType() == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) { auto& cpu_ids = model_.config_->model.eos_token_ids; - cuda_eos_token_ids_ = state_.params_->p_device->Allocate(cpu_ids.size()); + cuda_eos_token_ids_ = model_.p_device_->Allocate(cpu_ids.size()); copy(std::span{cpu_ids}, cuda_eos_token_ids_.CpuSpan()); cuda_eos_token_ids_.CopyCpuToDevice(); } -#endif input_sequence_lengths.resize(state_.params_->search.batch_size); } -#pragma warning(push) -#pragma warning(disable : 4189) // local variable is initialized but not referenced - DeviceSpan Logits::Get() { size_t element_count = shape_[0] * shape_[1] * shape_[2]; @@ -41,64 +32,28 @@ DeviceSpan Logits::Get() { const size_t seq_length = shape_[1]; const size_t vocab_size = shape_[2]; const size_t num_beams = state_.params_->search.num_beams; - const size_t element_count_last_token = shape_[0] * shape_[2]; // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it - output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_); + output_last_tokens_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_last, type_); if (type_ == Ort::TypeToTensorType) - logits_of_last_token_fp32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_); + logits_of_last_token_fp32_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_); logits_of_last_token = output_last_tokens_.get(); - size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; + size_t element_size = SizeOf(type_); size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process + auto logits_raw = ByteWrapTensor(*model_.p_device_inputs_, *output_raw_); + auto logits_last_tokens = ByteWrapTensor(*model_.p_device_inputs_, *logits_of_last_token); + for (int batch_index = 0; batch_index < state_.params_->search.batch_size; batch_index++) { // Find the first non pad token from the end size_t token_index = input_sequence_lengths[batch_index] - 1; for (int beam_index = 0; beam_index < num_beams; beam_index++) { - switch (model_.device_type_) { - case DeviceType::DML: { -#if USE_DML - ComPtr source_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource)); - - ComPtr target_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, logits_of_last_token->GetTensorMutableRawData(), &target_resource)); - - uint64_t source_offset = (vocab_index * seq_length + token_index * vocab_size) * element_size; - uint64_t target_offset = vocab_index * element_size; - uint64_t size_in_bytes = vocab_size * element_size; - - model_.GetDmlExecutionContext()->CopyBufferRegion( - target_resource.Get(), - target_offset, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - source_resource.Get(), - source_offset, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - size_in_bytes); -#endif - } break; - - default: { - // CPU, CUDA, WEBGPU - auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; - auto logits_last_tokens = std::span{logits_of_last_token->GetTensorMutableData(), element_count_last_token * element_size}; - auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size); - auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size); - if (model_.device_type_ == DeviceType::CUDA) -#if USE_CUDA - CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream); -#else - throw std::runtime_error("Unexpected CUDA device usage"); -#endif - else - copy(source, target); - } break; - } - + auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size); + auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size); + target.CopyFrom(source); vocab_index += vocab_size; } } @@ -108,79 +63,28 @@ DeviceSpan Logits::Get() { // Convert from float16 to float32 if necessary if (type_ == Ort::TypeToTensorType) { -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - DmlHelpers::DmlCastInputToOutput( - model_.GetDmlExecutionContext(), - *model_.allocator_device_, - *logits_of_last_token, - logits_of_last_token_fp32_, - model_.GetDmlDevice(), - model_.GetOrtDmlApi(), - logits_cast_command_list_state_); - - logits_of_last_token = logits_of_last_token_fp32_.get(); - } else -#endif - { - ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32_, model_.device_type_, model_.cuda_stream_); - logits_of_last_token = logits_of_last_token_fp32_.get(); - } + Cast(*logits_of_last_token, logits_of_last_token_fp32_, *model_.p_device_inputs_, Ort::TypeToTensorType); + logits_of_last_token = logits_of_last_token_fp32_.get(); } -#if USE_DML - // DML doesn't support 
on-device scoring yet, so we need to download some data to the CPU - if (model_.device_type_ == DeviceType::DML) { - value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_last); - } -#endif - if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data()) - logits_ = WrapTensor(*state_.params_->p_device, *logits_of_last_token); + logits_ = WrapTensor(*model_.p_device_inputs_, *logits_of_last_token); -#if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) { + if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) { if (!cuda_eos_token_ids_.empty()) - cuda::LaunchHandleEOSArray( + model_.p_device_inputs_->LaunchHandleEOSArray( logits_.Span().data(), static_cast(shape_[0]) /* batch_beam_size*/, static_cast(shape_[2]) /* vocab_size */, cuda_eos_token_ids_.Span().data(), - static_cast(cuda_eos_token_ids_.size()), - model_.cuda_stream_); - return logits_; - } -#endif -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - // DML doesn't support on-device scoring yet, so we transfer the data to the CPU - ComPtr gpu_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( - model_.allocator_device_, - logits_of_last_token->GetTensorMutableData(), - &gpu_resource)); - auto cpu_tensor = value32_cpu_->GetTensorMutableData(); - - model_.GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(reinterpret_cast(cpu_tensor), element_count * sizeof(float)), - gpu_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - - auto batched_logits_cpu = cpu_span{cpu_tensor, element_count}; - HandleEOSArray(batched_logits_cpu); - - logits_ = WrapTensor(*state_.params_->p_device, *value32_cpu_); + static_cast(cuda_eos_token_ids_.size())); return logits_; } -#endif HandleEOSArray(logits_.Span()); return logits_; } -#pragma warning(pop) - void Logits::Update(const DeviceSpan& next_tokens, size_t new_kv_length) { if (static_cast(output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1]) == new_kv_length && new_kv_length == 1) { return; @@ -203,21 +107,7 @@ void Logits::Update(const DeviceSpan& next_tokens, size_t new_kv_length } shape_[1] = new_kv_length; - StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; - output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) - : sb_logits->CreateTensorOnStaticBuffer(shape_, type_); - - if (state_.GetCapturedGraphInfo()) { - if (!sb_logits16_ && !sb_logits32_) { - if (type_ == Ort::TypeToTensorType) { - sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get(); - } - if (type_ == Ort::TypeToTensorType) { - sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get(); - } - } - } - + output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_); state_.outputs_[output_index_] = output_raw_.get(); } diff --git a/src/models/logits.h b/src/models/logits.h index f723c48ee..a9c658835 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -39,18 +39,7 @@ struct Logits { // OrtValue wrapped in a DeviceMemory object to make it universal DeviceSpan logits_; - // Used for decoding runs with cuda graphs. 
- StaticBuffer* sb_logits32_{}; - StaticBuffer* sb_logits16_{}; - -#if USE_CUDA DeviceSpan cuda_eos_token_ids_; // eos_token_ids from params, but in cuda accessible memory -#endif - -#if USE_DML - DmlReusedCommandListState logits_cast_command_list_state_{}; - std::unique_ptr value32_cpu_; -#endif }; } // namespace Generators diff --git a/src/models/model.cpp b/src/models/model.cpp index 749cac52b..266569419 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -13,16 +13,9 @@ #include "gpt.h" #include "decoder_only.h" #include "whisper.h" -#include "kernels.h" #include "multi_modal_vision_model.h" #include "decoder_only_pipeline.h" -#if USE_DML -#include -#include "dml_provider_factory.h" -#include "../dml/dml_helpers.h" - -std::string CurrentModulePath(); -#endif +#include "../dml/interface.h" namespace Generators { @@ -224,20 +217,29 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const { return token_id; } -#if USE_CUDA // Since Python/Others can and will hold onto a generator object past the model object's lifetime we need to ensure // the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory // has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the // arena already being destroyed. -Ort::Allocator* GetCudaAllocator(OrtSession& session) { - auto& globals = *GetOrtGlobals(); - if (!globals.allocator_cuda_) { - globals.memory_info_cuda_ = OrtMemoryInfo::Create("Cuda", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - globals.allocator_cuda_ = Ort::Allocator::Create(session, *globals.memory_info_cuda_); - } - return globals.allocator_cuda_.get(); +void EnsureDeviceOrtInit(OrtSession& session, DeviceType type) { + // CPU Allocator is a special case, it's not in the owned 'allocator_device_' table below so we handle it separately + if (type == DeviceType::CPU) + return; + + auto& device = GetOrtGlobals()->allocator_device_[static_cast(type)]; + if (device) + return; + + static const char* device_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QnnHtpShared"}; + static_assert(std::size(device_type_names) == static_cast(DeviceType::MAX)); + + auto name = device_type_names[static_cast(type)]; + auto memory_info = OrtMemoryInfo::Create(name, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); + device = Ort::Allocator::Create(session, *memory_info); + if (!device) + throw std::runtime_error("Unexpected failure to create device memory allocator for " + std::string(name)); + GetDeviceInterface(type)->InitOrt(*Ort::api, *device); // Necessary for any shared library providers so they can access Ort::api } -#endif SessionInfo::SessionInfo(OrtSession& session) { Add(session); @@ -299,41 +301,19 @@ Model::Model(std::unique_ptr config) : config_{std::move(config)} { Model::~Model() = default; void Model::InitDeviceAllocator(OrtSession& session) { - allocator_device_ = &allocator_cpu_; - allocator_kvcache_ = &allocator_cpu_; -#if USE_CUDA - if (device_type_ == DeviceType::CUDA) { - allocator_device_ = GetCudaAllocator(session); - allocator_kvcache_ = allocator_device_; - } -#endif + EnsureDeviceOrtInit(session, p_device_->GetType()); -#if USE_DML - if (device_type_ == DeviceType::DML) { - memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); - 
allocator_device_ = owned_allocator_device_.get(); - allocator_kvcache_ = allocator_device_; - } -#endif + // Only CUDA does every input on the device + if (p_device_->GetType() == DeviceType::CUDA) + p_device_inputs_ = p_device_; + else + p_device_inputs_ = GetDeviceInterface(DeviceType::CPU); - if (device_type_ == DeviceType::WEBGPU) { - // for webgpu we only use device memory for kv_cache - memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_kvcache_ = owned_allocator_device_.get(); - } - - if (device_type_ == DeviceType::QNN) { - memory_info_device_ = OrtMemoryInfo::Create("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, - OrtMemType::OrtMemTypeDefault); - owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_device_ = owned_allocator_device_.get(); - allocator_kvcache_ = allocator_device_; - } + // The kvcache is always allocated in device memory + p_device_kvcache_ = p_device_; session_info_ = std::make_unique(session); - captured_graph_pool_ = std::make_shared(config_.get(), session_info_.get(), allocator_device_); + captured_graph_pool_ = std::make_shared(config_.get(), session_info_.get(), &p_device_->GetAllocator()); } void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_session_options, @@ -433,12 +413,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ // Device type determines the scoring device. // Only use the primary session options to determine the device type if (is_primary_session_options) { - device_type_ = DeviceType::CUDA; // Scoring will use CUDA - p_device_ = GetDeviceInterface(device_type_); + p_device_ = GetDeviceInterface(DeviceType::CUDA); // Create and set our cudaStream_t - cuda_stream_ = p_device_->GetCudaStream(); - ort_provider_options->UpdateValue("user_compute_stream", cuda_stream_); + ort_provider_options->UpdateValue("user_compute_stream", p_device_->GetCudaStream()); } session_options.AppendExecutionProvider_CUDA_V2(*ort_provider_options); @@ -456,64 +434,26 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ session_options.AppendExecutionProvider_ROCM(ort_provider_options); #if USE_DML } else if (provider_options.name == "dml") { - if (!p_dml_api_) { - auto current_module_path = CurrentModulePath(); - - bool contains_device_luid = false; + if (!GetDmlInterface()) { LUID device_luid{}; + LUID* p_device_luid{}; for (const auto& [name, value] : provider_options.options) { if (name == "luid") { if (auto separator_position = value.find(":"); separator_position != std::string::npos) { device_luid.HighPart = std::stol(value.substr(0, separator_position)); device_luid.LowPart = std::stol(value.substr(separator_position + 1)); - contains_device_luid = true; + p_device_luid = &device_luid; } } } - if (contains_device_luid) { - dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid); - } else { - dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); - } - - constexpr auto directml_dll = "DirectML.dll"; - wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0)); - THROW_LAST_ERROR_IF(!smart_directml_dll); - - if (LoadLibraryEx(directml_dll, nullptr, 0) == NULL) { - throw std::runtime_error("DirectML.dll not found"); - } - - auto dml_create_device1_fn = 
reinterpret_cast(GetProcAddress(smart_directml_dll.get(), "DMLCreateDevice1")); - THROW_LAST_ERROR_IF(!dml_create_device1_fn); - THROW_IF_FAILED(dml_create_device1_fn(dml_objects_.d3d12_device.Get(), DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(&dml_device_))); - - Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&p_dml_api_))); - if (!p_dml_api_) { - throw std::runtime_error("Unexpected nullptr getting OrtDmlApi"); - } - - dml_execution_context_ = std::make_unique( - dml_objects_.d3d12_device.Get(), - dml_device_.Get(), - dml_objects_.command_queue.Get(), - *allocator_device_, - p_dml_api_); - - dml_pooled_upload_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); - dml_readback_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); + InitDmlInterface(p_device_luid); } - if (!disable_graph_capture) { - session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1"); - session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1"); - } - - p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(&session_options, dml_device_.Get(), dml_objects_.command_queue.Get()); + SetDmlProvider(session_options); if (is_primary_session_options) - device_type_ = DeviceType::DML; // We use a DML allocator for input/output caches, but other tensors will use CPU tensors + p_device_ = GetDeviceInterface(DeviceType::DML); // We use a DML allocator for input/output caches, but other tensors will use CPU tensors #endif } else if (provider_options.name == "qnn") { session_options.AddConfigEntry("ep.share_ep_contexts", "1"); @@ -527,12 +467,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ // on the other hand, not sure if is_primary_session_options is the right thing to check here. 
if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator"); opt_it != opts.end() && opt_it->second == "1") { - device_type_ = DeviceType::QNN; + p_device_ = GetDeviceInterface(DeviceType::QNN); } session_options.AppendExecutionProvider("QNN", opts); } else if (provider_options.name == "webgpu") { - device_type_ = DeviceType::WEBGPU; + p_device_ = GetDeviceInterface(DeviceType::WEBGPU); std::unordered_map opts; for (auto& option : provider_options.options) { opts.emplace(option.first, option.second); @@ -542,14 +482,14 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ throw std::runtime_error("Unknown provider type: " + provider_options.name); } - // If no device is set, create it, default to CPU - if (!p_device_) { - p_device_ = GetDeviceInterface(device_type_); - } + // Fallback to CPU if no provider specific interface was set + if (!p_device_) + p_device_ = GetDeviceInterface(DeviceType::CPU); } void Model::CreateSessionOptions() { session_options_ = OrtSessionOptions::Create(); + CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *session_options_, true, false); for (auto& pipeline_model : config_->model.decoder.pipeline) { @@ -612,133 +552,44 @@ std::shared_ptr CreateGeneratorParams(const Config& config) { return std::make_shared(config); } -void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, DeviceType device_type, cudaStream_t stream) { - auto shape_info = in.GetTensorTypeAndShapeInfo(); - auto shape = shape_info->GetShape(); - assert(shape_info->GetElementType() == Ort::TypeToTensorType); - - bool allocate_p_out = p_out == nullptr; - if (p_out) { - auto out_shape_info = p_out->GetTensorTypeAndShapeInfo(); - auto out_shape = out_shape_info->GetShape(); - allocate_p_out = shape != out_shape; - } - - if (allocate_p_out) - p_out = OrtValue::CreateTensor(allocator, shape); - - int count = static_cast(shape_info->GetElementCount()); - auto* fp16 = in.GetTensorData(); - auto* fp32 = p_out->GetTensorMutableData(); - - switch (device_type) { - case DeviceType::WEBGPU: - case DeviceType::DML: - // DML, WebGpu doesn't currently support on-device scoring, so we fall back to the CPU - case DeviceType::CPU: - for (int i = 0; i < count; i++) - fp32[i] = FastFloat16ToFloat32(fp16[i]); - break; - -#if USE_CUDA - case DeviceType::CUDA: - cuda::LaunchFp16ToFp32(fp16, fp32, count, stream); - break; -#endif - - default: - throw std::runtime_error("ConvertFp16ToFp32 - Unsupported device type"); - } -} - -void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, - DeviceType device_type, cudaStream_t stream) { - auto shape_info = in.GetTensorTypeAndShapeInfo(); - auto shape = shape_info->GetShape(); - assert(shape_info->GetElementType() == Ort::TypeToTensorType); - - bool allocate_p_out = p_out == nullptr; - if (p_out) { - auto out_shape_info = p_out->GetTensorTypeAndShapeInfo(); - auto out_shape = out_shape_info->GetShape(); - allocate_p_out = shape != out_shape; - } - - if (allocate_p_out) - p_out = OrtValue::CreateTensor(allocator, shape); - - int count = static_cast(shape_info->GetElementCount()); - auto* fp32 = in.GetTensorData(); - auto* fp16 = p_out->GetTensorMutableData(); +void Cast(OrtValue& input, std::unique_ptr& output, DeviceInterface& device, ONNXTensorElementDataType output_type) { + auto input_info = input.GetTensorTypeAndShapeInfo(); + auto shape = input_info->GetShape(); - switch (device_type) { - case DeviceType::DML: - case DeviceType::CPU: - for (int i = 0; 
i < count; i++) - fp16[i] = FastFloat32ToFloat16(fp32[i]); - break; - -#if USE_CUDA - case DeviceType::CUDA: - cuda::LaunchFp32ToFp16(fp32, fp16, count, stream); -#endif + if (output && shape != output->GetTensorTypeAndShapeInfo()->GetShape()) + output = nullptr; + if (!output) + output = OrtValue::CreateTensor(device.GetAllocator(), shape, output_type); - default: - throw std::runtime_error("ConvertFp32ToFp16 - Unsupported device type"); - } + if (!device.Cast(input, *output)) + GetDeviceInterface(DeviceType::CPU)->Cast(input, *output); } std::unique_ptr Model::ExpandInputs(std::unique_ptr& input, int num_beams) const { // Input shape (batch_size, sequence_length). The input is required with data type T. // Output shape (batch_size * num_beams, sequence_length) - // If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later. - // DML doesn't currently support on-device scoring, so we go the same route as the CPU - if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML || device_type_ == DeviceType::WEBGPU)) { + // When num_beams == 1, we don't need to expand the input, but the expand has a side effect of copying from + // CPU memory to device memory, so we can skip if the p_device_inputs_ is the CPU device + if (num_beams == 1 && p_device_inputs_ == GetDeviceInterface(DeviceType::CPU)) return std::move(input); - } auto input_type_info = input->GetTensorTypeAndShapeInfo(); auto element_type = input_type_info->GetElementType(); - auto element_size = SizeOf(element_type); auto input_shape = input_type_info->GetShape(); const int64_t batch_size = input_shape[0]; - const int64_t data_size_bytes = input_type_info->GetElementCount() * element_size / batch_size; + const int64_t data_size_bytes = input_type_info->GetElementCount() * SizeOf(element_type) / batch_size; input_shape[0] *= num_beams; - auto& allocator = device_type_ == DeviceType::DML ? 
allocator_cpu_ : *allocator_device_; - auto expanded = OrtValue::CreateTensor(allocator, input_shape, element_type); - const auto* input_data = reinterpret_cast(input->GetTensorRawData()); - auto* expanded_data = reinterpret_cast(expanded->GetTensorMutableRawData()); - auto* target = expanded_data; - - switch (device_type_) { - case DeviceType::WEBGPU: - case DeviceType::DML: - case DeviceType::QNN: - // DML and WebGpu doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs - case DeviceType::CPU: - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < num_beams; j++) { - memcpy(target, input_data + i * data_size_bytes, data_size_bytes); - target += data_size_bytes; - } - } - break; - -#if USE_CUDA - case DeviceType::CUDA: - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < num_beams; j++) { - cudaMemcpyAsync(target, input_data + i * data_size_bytes, data_size_bytes, cudaMemcpyHostToDevice, cuda_stream_); - target += data_size_bytes; - } - } - break; -#endif - default: - throw std::runtime_error("ExpandInputs - Unsupported device type"); + auto input_span = ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *input); + auto expanded = OrtValue::CreateTensor(p_device_inputs_->GetAllocator(), input_shape, element_type); + auto expanded_span = ByteWrapTensor(*p_device_inputs_, *expanded); + + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes)); + } } return expanded; } diff --git a/src/models/model.h b/src/models/model.h index 2af4b2315..c17736b73 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -8,22 +8,11 @@ #include "audio_processor.h" #include "adapters.h" -#if USE_DML -#include "dml_provider_factory.h" -#include "../dml/dml_helpers.h" -#include "../dml/dml_execution_context.h" -#include "../dml/dml_pooled_upload_heap.h" -#include "../dml/dml_readback_heap.h" -#endif - namespace Generators { struct Tokenizer; -void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, DeviceType device_type, cudaStream_t stream); - -void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, DeviceType device_type, cudaStream_t stream); - +void Cast(OrtValue& input, std::unique_ptr& output, DeviceInterface& device, ONNXTensorElementDataType type); void CheckResult(extError_t error); struct State { @@ -147,26 +136,16 @@ struct Model : std::enable_shared_from_this, LeakChecked { std::unique_ptr config_; std::unique_ptr session_options_; - cudaStream_t cuda_stream_{}; - DeviceInterface* p_device_{}; - DeviceType device_type_{DeviceType::CPU}; - Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()}; - Ort::Allocator* allocator_device_{}; // Can be CUDA or CPU based on the DeviceType in the model - Ort::Allocator* allocator_kvcache_{}; // keep allocator for kv_cache seperate to allow that only kv_cache is on device + mutable DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices + mutable DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently) + mutable DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?) 
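+  // Concretely (per Model::InitDeviceAllocator in model.cpp): with the CUDA provider all three
+  // pointers reference the CUDA interface, while for DML/WebGPU/QNN only p_device_ and
+  // p_device_kvcache_ use the device interface and p_device_inputs_ falls back to the CPU interface.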
+ + Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()}; std::unique_ptr session_info_; std::shared_ptr external_owner_; // Set to 'this' when created by the C API to preserve lifetime -#if USE_DML - DmlExecutionContext* GetDmlExecutionContext() const { return dml_execution_context_.get(); } - DmlReadbackHeap* GetDmlReadbackHeap() const { return dml_readback_heap_.get(); } - DmlPooledUploadHeap* GetDmlUploadHeap() const { return dml_pooled_upload_heap_.get(); } - const OrtDmlApi* GetOrtDmlApi() const { return p_dml_api_; } - IDMLDevice* GetDmlDevice() const { return dml_device_.Get(); } - ID3D12Device* GetD3D12Device() const { return dml_objects_.d3d12_device.Get(); } -#endif - protected: void InitDeviceAllocator(OrtSession& session); void CreateSessionOptions(); @@ -176,18 +155,6 @@ struct Model : std::enable_shared_from_this, LeakChecked { bool is_primary_session_options, bool disable_graph_capture); -#if USE_DML - mutable DmlObjects dml_objects_; - const OrtDmlApi* p_dml_api_{}; - std::unique_ptr dml_pooled_upload_heap_; - std::unique_ptr dml_execution_context_; - std::unique_ptr dml_readback_heap_; - ComPtr dml_device_; -#endif - - std::unique_ptr owned_allocator_device_{}; // nullptr if n/a - std::unique_ptr memory_info_device_{}; // nullptr if n/a - std::shared_ptr captured_graph_pool_; std::map> pipeline_session_options_; }; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index fde1ed7a9..ce81bae66 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -1,11 +1,6 @@ #include "../generators.h" #include "model.h" #include "position_inputs.h" -#include "kernels.h" - -#if USE_DML -#include "../dml/dml_update_mask_kernel.h" -#endif namespace Generators { @@ -49,12 +44,6 @@ DefaultPositionInputs::DefaultPositionInputs(const Model& model, State& state, D } if (has_mask_input_) { sb_attention_mask_ = state_.GetCapturedGraphInfo()->sb_attention_mask_.get(); - -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get(); - } -#endif } } } @@ -104,9 +93,7 @@ void DefaultPositionInputs::RewindTo(size_t index) { // Rewind the mask input to a previous state } else if (has_mask_input_) { if (attention_mask_shape_[0] == 1) { -#if USE_CUDA RewindMask(index); -#endif } else throw std::runtime_error("DefaultPositionInputs::RewindTo - Unsupported batch size"); } @@ -126,44 +113,21 @@ void DefaultPositionInputs::AddPositionIDs() { state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); } -#if USE_CUDA || USE_DML -void DefaultPositionInputs::CopyNextPositionIDsToCurrent() { -#if USE_CUDA - assert(model_.device_type_ == DeviceType::CUDA); - cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(), - position_ids_next_->GetTensorMutableRawData(), - (type_ == Ort::TypeToTensorType ? sizeof(int32_t) : sizeof(int64_t)) * position_ids_shape_[0], - cudaMemcpyDeviceToDevice, - model_.cuda_stream_); -#elif USE_DML - assert(model_.device_type_ == DeviceType::DML); - ComPtr target_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); - auto source = std::span(position_ids_next_->GetTensorData(), (type_ == Ort::TypeToTensorType ? 
sizeof(int32_t) : sizeof(int64_t)) * position_ids_shape_[0]); - model_.GetDmlUploadHeap()->BeginUploadToGpu( - target_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - source); -#endif -} -#endif - void DefaultPositionInputs::CreateNextPositionIDsTensor() { if (!sb_position_ids_) { if (position_ids_shape_[1] == 1 && position_ids_next_) { position_ids_ = std::move(position_ids_next_); position_ids_next_ = nullptr; } else { - position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_); + position_ids_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), position_ids_shape_, type_); } } else { -#if USE_CUDA || USE_DML position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_); if (position_ids_shape_[1] == 1) { - CopyNextPositionIDsToCurrent(); + auto position_ids_span = ByteWrapTensor(*model_.p_device_inputs_, *position_ids_); + auto position_ids_next_span = ByteWrapTensor(*model_.p_device_inputs_, *position_ids_next_); + position_ids_span.CopyFrom(position_ids_next_span); } -#endif } } @@ -178,116 +142,51 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt state_.inputs_[posid_input_index_] = position_ids_.get(); } - switch (model_.device_type_) { - case DeviceType::WEBGPU: - case DeviceType::CPU: { - type_ == Ort::TypeToTensorType ? UpdatePositionIDsImpl(total_length, new_kv_length) - : UpdatePositionIDsImpl(total_length, new_kv_length); - break; - } -#if USE_CUDA - case DeviceType::CUDA: { - if (type_ == Ort::TypeToTensorType) - cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), static_cast(position_ids_shape_[0]), total_length, new_kv_length, model_.cuda_stream_); - else - cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), static_cast(position_ids_shape_[0]), total_length, new_kv_length, model_.cuda_stream_); - break; - } -#elif USE_DML - case DeviceType::DML: { - UpdatePositionIDsImplDML(); - break; - } -#endif - default: - throw std::runtime_error("PositionIDs::Update - Unsupported device type"); + if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) + model_.p_device_inputs_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast(position_ids_shape_[0]), total_length, new_kv_length, type_); + else { + type_ == Ort::TypeToTensorType ? UpdatePositionIDsImpl(total_length, new_kv_length) + : UpdatePositionIDsImpl(total_length, new_kv_length); } } void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) { if (!sb_attention_mask_) { attention_mask_shape_[1] = total_length; - attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); -#if USE_DML - if (model_.device_type_ == DeviceType::DML) - attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); -#endif + attention_mask_next_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), attention_mask_shape_, type_); } else { -#if USE_CUDA attention_mask_shape_[1] = state_.params_->search.max_length; attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); if (is_first_mask_update_) { - cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(), - 0, - (type_ == Ort::TypeToTensorType ? 
sizeof(int32_t) : sizeof(int64_t)) * attention_mask_shape_[0] * attention_mask_shape_[1], - model_.cuda_stream_); + ByteWrapTensor(*model_.p_device_inputs_, *attention_mask_next_).Zero(); } -#elif USE_DML - attention_mask_shape_[1] = state_.params_->search.max_length; - attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); - attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -#endif } } void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); - if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); CreateNextAttentionMaskTensor(total_length); state_.inputs_[mask_input_index_] = attention_mask_.get(); - switch (model_.device_type_) { - case DeviceType::WEBGPU: - case DeviceType::CPU: - case DeviceType::QNN: { - type_ == Ort::TypeToTensorType ? UpdateAttentionMaskImpl(total_length) - : UpdateAttentionMaskImpl(total_length); - break; - } -#if USE_CUDA - case DeviceType::CUDA: { - int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length; - bool update_only = sb_attention_mask_ && !is_first_mask_update_; - if (type_ == Ort::TypeToTensorType) { - cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), - attention_mask_->GetTensorData(), - static_cast(attention_mask_shape_[0]), - new_kv_length, - total_length, - max_length, - update_only, - model_.cuda_stream_); - } else { - cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), - attention_mask_->GetTensorData(), - static_cast(attention_mask_shape_[0]), - new_kv_length, - total_length, - max_length, - update_only, - model_.cuda_stream_); - } - break; - } -#elif USE_DML - case DeviceType::DML: { - UpdateAttentionMaskImplDML(total_length); - break; - } -#endif - default: - throw std::runtime_error("DefaultPositionInputs::Update - Unsupported device type"); - } -#if USE_DML - if (model_.device_type_ != DeviceType::DML) { - attention_mask_ = std::move(attention_mask_next_); + if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) { + int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length; + bool update_only = sb_attention_mask_ && !is_first_mask_update_; + model_.p_device_inputs_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(), + attention_mask_->GetTensorRawData(), + static_cast(attention_mask_shape_[0]), + new_kv_length, + total_length, + max_length, + update_only, + type_); + } else { + type_ == Ort::TypeToTensorType ? 
UpdateAttentionMaskImpl(total_length) + : UpdateAttentionMaskImpl(total_length); } -#else + attention_mask_ = std::move(attention_mask_next_); -#endif state_.inputs_[mask_input_index_] = attention_mask_.get(); is_first_mask_update_ = false; } @@ -367,25 +266,6 @@ void DefaultPositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_l } } -#if USE_DML -void DefaultPositionInputs::UpdatePositionIDsImplDML() { - ComPtr target_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); - - dml_update_position_ids_kernel_ = DmlIncrementValuesKernel( - model_.GetD3D12Device(), - model_.GetDmlExecutionContext(), - static_cast(position_ids_shape_[0]), - type_, - target_resource.Get()); - - // Execute the cached command list - ComPtr fence; - uint64_t completion_value; - model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value); -} -#endif - template void DefaultPositionInputs::UpdateAttentionMaskImpl(int total_length) { auto* data = attention_mask_next_->GetTensorMutableData(); @@ -405,44 +285,10 @@ void DefaultPositionInputs::UpdateAttentionMaskImpl(int total_length) { } } -#if USE_DML -void DefaultPositionInputs::UpdateAttentionMaskImplDML(int total_length) { - ComPtr attention_mask_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); - ComPtr attention_mask_next_resource; - Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource)); - if (is_first_mask_update_) { - dml_update_mask_kernel_ = DmlUpdateMaskKernel( - model_.GetD3D12Device(), - model_.GetDmlExecutionContext(), - static_cast(attention_mask_shape_[0]), - static_cast(attention_mask_shape_[1]), - type_, - total_length, - attention_mask_resource.Get(), - attention_mask_next_resource.Get()); - is_second_mask_update_ = true; - } else if (is_second_mask_update_) { - dml_update_mask_kernel_ = DmlUpdateMaskKernel( - model_.GetD3D12Device(), - model_.GetDmlExecutionContext(), - static_cast(attention_mask_shape_[0]), - static_cast(attention_mask_shape_[1]), - type_, - 1, - attention_mask_resource.Get(), - attention_mask_next_resource.Get()); - is_second_mask_update_ = false; - } - ComPtr fence; - uint64_t completion_value; - model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value); -} -#endif - -#if USE_CUDA void DefaultPositionInputs::RewindMask(size_t index) { if (sb_attention_mask_ && !is_first_mask_update_) { + throw std::runtime_error("PositionInputs::RewindMask - Static buffer is not supported for continuous decoding."); +#if 0 // TODO: Fix implementation, cudaMemsetAsync of 1 is setting bytes of 1 vs int32's of 1 int past_length = static_cast(index); int max_length = static_cast(state_.params_->search.max_length); cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), @@ -453,9 +299,9 @@ void DefaultPositionInputs::RewindMask(size_t index) { 1, (type_ == Ort::TypeToTensorType ? 
sizeof(int32_t) : sizeof(int64_t)) * past_length, model_.cuda_stream_); +#endif } } -#endif WindowedPositionInputs::WindowedPositionInputs(State& state) : state_{state} { diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 4365e0ee4..5712be992 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -1,12 +1,6 @@ #pragma once - #include "static_buffer.h" -#if USE_DML -#include "../dml/dml_update_mask_kernel.h" -#include "../dml/dml_increment_values_kernel.h" -#endif - namespace Generators { struct PositionInputs { @@ -46,18 +40,7 @@ struct DefaultPositionInputs : PositionInputs { template void UpdateAttentionMaskImpl(int total_length); -#if USE_CUDA || USE_DML - void CopyNextPositionIDsToCurrent(); -#endif - -#if USE_DML - void UpdatePositionIDsImplDML(); - void UpdateAttentionMaskImplDML(int total_length); -#endif - -#if USE_CUDA void RewindMask(size_t index); -#endif const Model& model_; State& state_; @@ -84,13 +67,6 @@ struct DefaultPositionInputs : PositionInputs { bool is_first_mask_update_{true}; bool is_first_update_{true}; - -#if USE_DML - std::optional dml_update_mask_kernel_; - StaticBuffer* sb_attention_mask_next_{}; - std::optional dml_update_position_ids_kernel_; - bool is_second_mask_update_{}; -#endif }; // Certain models can only process a fixed number of tokens at a time. diff --git a/src/models/prompt_image_processor.cpp b/src/models/prompt_image_processor.cpp index 249711a0b..33ed2cfa3 100644 --- a/src/models/prompt_image_processor.cpp +++ b/src/models/prompt_image_processor.cpp @@ -89,7 +89,7 @@ std::unique_ptr ProcessPixelValues(ortc::Tensor* pixel_values, allocator.GetInfo(), std::span(const_cast(pixel_values->Data()), pixel_values->NumberOfElement()), pixel_values->Shape()); - ConvertFp32ToFp16(allocator, *pixel_values_fp32, pixel_values_value, DeviceType::CPU, nullptr); + Cast(*pixel_values_fp32, pixel_values_value, *GetDeviceInterface(DeviceType::CPU), Ort::TypeToTensorType); } return pixel_values_value; diff --git a/src/models/utils.cpp b/src/models/utils.cpp index 7f4d43629..dd4bef813 100644 --- a/src/models/utils.cpp +++ b/src/models/utils.cpp @@ -1,9 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
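+// ByteWrapTensor (defined just below) views an OrtValue's raw buffer as a DeviceSpan<uint8_t>,
+// so callers can Zero() it or CopyFrom() another tensor regardless of which device owns the
+// memory, e.g. ByteWrapTensor(*model_.p_device_, *cache_indirection_).Zero(); in whisper.cpp.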
#include "../generators.h" +#include "utils.h" namespace Generators { +DeviceSpan ByteWrapTensor(DeviceInterface& device, OrtValue& value) { + auto info = value.GetTensorTypeAndShapeInfo(); + return device.WrapMemory(std::span{value.GetTensorMutableData(), info->GetElementCount() * SizeOf(info->GetElementType())}); +} + size_t SizeOf(ONNXTensorElementDataType type) { switch (type) { case Ort::TypeToTensorType: diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 2988b9108..05fc20171 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -3,10 +3,6 @@ #include "../generators.h" #include "whisper.h" #include -#include "kernels.h" -#if USE_CUDA -#include "../cuda/cuda_common.h" -#endif namespace Generators { @@ -42,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan seq } if (inputs.alignment_heads != nullptr) { -#if USE_CUDA +#if 0 // USE_CUDA auto alignment_heads_type_and_shape_info = inputs.alignment_heads->ort_tensor_->GetTensorTypeAndShapeInfo(); auto alignment_heads_type = alignment_heads_type_and_shape_info->GetElementType(); // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 auto alignment_heads_shape = alignment_heads_type_and_shape_info->GetShape(); @@ -67,7 +63,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan seq auto hidden_states_type = model_.session_info_->GetOutputDataType("encoder_hidden_states"); auto encoder_hidden_states_shape = std::array{decoder_input_ids_.GetShape()[0], 1500, static_cast(model_.config_->model.decoder.num_attention_heads) * model_.config_->model.decoder.head_size}; - encoder_hidden_states_ = OrtValue::CreateTensor(*model_.allocator_device_, encoder_hidden_states_shape, hidden_states_type); + encoder_hidden_states_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), encoder_hidden_states_shape, hidden_states_type); auto sequence_lengths = sequence_lengths_unk.CpuSpan(); for (int i = 0; i < decoder_input_ids_.GetShape()[0]; i++) { @@ -94,14 +90,14 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan seq auto type = model_.session_info_->GetOutputDataType(output_names_[kv_cache_indices]); for (int i = 0; i < layer_count * 2; i++) { - init_presents_.emplace_back(OrtValue::CreateTensor(*model_.allocator_device_, shape, type)); + init_presents_.emplace_back(OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape, type)); presents_.emplace_back(outputs_[kv_cache_indices + i]); outputs_[kv_cache_indices + i] = init_presents_.back().get(); } } } -#if USE_CUDA +#if 0 // USE_CUDA template void TransposeKCacheForDMMHA(T* dest_data, T* temp_buffer, @@ -147,7 +143,7 @@ DeviceSpan Whisper_State::Run(int current_length, DeviceSpan& ne const auto copy_data_size_all = src_shape_info->GetElementCount() * SizeOf(src_shape_info->GetElementType()); -#if USE_CUDA +#if 0 // USE_CUDA const auto src_dims = src_shape_info->GetShape(); const auto src_element_type = src_shape_info->GetElementType(); const auto src_element_size = SizeOf(src_element_type); @@ -186,8 +182,8 @@ DeviceSpan Whisper_State::Run(int current_length, DeviceSpan& ne auto src_data = init_presents_[i]->GetTensorRawData(); auto dest_data = presents_[i]->GetTensorMutableRawData(); - switch (model_.device_type_) { -#if USE_CUDA + switch (model_.p_device_inputs_->GetType()) { +#if 0 // USE_CUDA case DeviceType::CUDA: if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { // CUDA EP + FP16 precision == `DecoderMaskedMultiHeadAttention` op is used @@ -228,7 +224,7 @@ DeviceSpan 
Whisper_State::Run(int current_length, DeviceSpan& ne } } -#if USE_CUDA +#if 0 // USE_CUDA if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 && model_.device_type_ == DeviceType::CUDA) { // Transpose cross attention K caches for `DecoderMaskedMultiHeadAttention` @@ -272,16 +268,12 @@ DeviceSpan Whisper_State::Run(int current_length, DeviceSpan& ne } if (model_.session_info_->HasInput("cache_indirection")) { -#if USE_CUDA - cache_indirection_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array{params_->search.batch_size, params_->search.num_beams, params_->search.max_length}); + cache_indirection_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), std::array{params_->search.batch_size, params_->search.num_beams, params_->search.max_length}); cache_indirection_index_ = inputs_.size(); input_names_.push_back("cache_indirection"); inputs_.push_back(cache_indirection_.get()); - auto data = gpu_span{cache_indirection_->GetTensorMutableData(), - static_cast(params_->BatchBeamSize()) * params_->search.max_length}; - CudaCheck() == cudaMemsetAsync(data.data(), 0, data.size_bytes(), params_->cuda_stream); -#endif + ByteWrapTensor(*model_.p_device_, *cache_indirection_).Zero(); } if (model_.session_info_->HasOutput("output_cross_qk_0")) { @@ -292,7 +284,7 @@ DeviceSpan Whisper_State::Run(int current_length, DeviceSpan& ne char string[64]; snprintf(string, std::size(string), "output_cross_qk_%d", i); output_cross_qk_names_.emplace_back(string); - output_cross_qk_.emplace_back(OrtValue::CreateTensor(*model_.allocator_device_, shape, type)); + output_cross_qk_.emplace_back(OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape, type)); output_names_.emplace_back(output_cross_qk_names_.back().c_str()); outputs_.emplace_back(output_cross_qk_.back().get()); @@ -335,7 +327,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan& next_tokens, Device } if (cache_indirection_) { -#if USE_CUDA +#if 0 // USE_CUDA auto beam_indices_gpu = gpu_span{beam_indices.Span()}; if (beam_indices_gpu.empty()) { auto beam_indices_cpu = beam_indices.CpuSpan(); @@ -363,7 +355,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan& next_tokens, Device } if (output_cross_qk_.size() && alignment_heads_) { -#if USE_CUDA +#if 0 // USE_CUDA // Collect a GPU array of float* pointers from the vector of OrtValues to pass to the kernel auto output_cross_qk_ptrs = cross_qk_ptrs_gpu_.CpuSpan(); assert(output_cross_qk_ptrs.size() == output_cross_qk_.size()); @@ -394,7 +386,7 @@ void Whisper_State::Initialize(DeviceSpan& next_tokens, int total_lengt void Whisper_State::Finalize() { if (output_cross_qk_.size() && alignment_heads_) { -#if USE_CUDA +#if 0 // USE_CUDA int decoded_length = *(past_sequence_length_->GetTensorMutableData()) + 1; auto output_cross_qk_dims = output_cross_qk_[0]->GetTensorTypeAndShapeInfo()->GetShape(); diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index f6cedcd22..82fac1c7d 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -331,42 +331,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator auto& generator = *reinterpret_cast(oga_generator); auto* ortvalue_output = generator.state_->GetOutput(name); auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo(); - std::unique_ptr ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_, - type_info->GetShape(), - type_info->GetElementType()); + auto ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_, type_info->GetShape(), 
type_info->GetElementType()); + // Copy data to ortvalue_clone - auto element_size = Generators::SizeOf(type_info->GetElementType()); - auto data_size = type_info->GetElementCount() * element_size; - const auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType(); - if (device_type == OrtMemoryInfoDeviceType_CPU) { - std::copy(static_cast(ortvalue_output->GetTensorMutableRawData()), - static_cast(ortvalue_output->GetTensorMutableRawData()) + data_size, - static_cast(ortvalue_clone->GetTensorMutableRawData())); - } else if (device_type == OrtMemoryInfoDeviceType_GPU) { -#if USE_CUDA - cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost); -#else - throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA."); -#endif - } else if (static_cast(device_type) == 4) { -#if USE_DML - ComPtr gpu_resource; - Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation( - generator.model_->allocator_device_, - ortvalue_output->GetTensorMutableRawData(), - &gpu_resource)); - auto cpu_tensor = ortvalue_clone->GetTensorMutableRawData(); - generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(reinterpret_cast(cpu_tensor), data_size), - gpu_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); -#else - throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML."); -#endif - } else { - throw std::runtime_error("Unsupported device type: " + std::to_string(static_cast(device_type))); - } + bool is_cpu = ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU; + auto output_span = Generators::ByteWrapTensor(is_cpu ? 
*Generators::GetDeviceInterface(Generators::DeviceType::CPU) : *generator.model_->p_device_, *ortvalue_output); + auto copy_span = Generators::ByteWrapTensor(*Generators::GetDeviceInterface(Generators::DeviceType::CPU), *ortvalue_clone); + copy_span.CopyFrom(output_span); auto tensor = std::make_shared(std::move(ortvalue_clone)); tensor->external_owner_ = tensor; diff --git a/src/python/python.cpp b/src/python/python.cpp index cdef19830..0ae6e6236 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -12,10 +12,6 @@ #include "../logging.h" #include "../smartptrs.h" -#if USE_CUDA -#include "../cuda/cuda_common.h" -#endif - using namespace pybind11::literals; // If a parameter to a C++ function is an array of float16, this type will let pybind11::array_t map to numpy's float16 format @@ -143,50 +139,21 @@ pybind11::array ToNumpy(OrtValue* v, const Generators::Model& model) { auto shape = type_info->GetShape(); auto type = type_info->GetElementType(); auto element_size = Generators::SizeOf(type); - auto data = v->GetTensorMutableRawData(); - - std::unique_ptr cpu_copy; - -#if USE_DML - // TODO: DML version of this - if (v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && model.device_type_ == Generators::DeviceType::DML) { - auto data_size = type_info->GetElementCount() * element_size; - cpu_copy = std::make_unique(data_size); - - ComPtr gpu_resource; - Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( - model.allocator_device_, - data, - &gpu_resource)); - - model.GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(reinterpret_cast(cpu_copy.get()), data_size), - gpu_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - data = cpu_copy.get(); - } -#endif -#if USE_CUDA - if (v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && model.device_type_ == Generators::DeviceType::CUDA) { - auto data_size = type_info->GetElementCount() * element_size; - cpu_copy = std::make_unique(data_size); - Generators::CudaCheck() == cudaMemcpy(cpu_copy.get(), data, data_size, cudaMemcpyDeviceToHost); - data = cpu_copy.get(); - } -#endif std::vector strides(shape.size()); { - auto size = Generators::SizeOf(type); + auto size = element_size; for (size_t i = strides.size(); i-- > 0;) { strides[i] = size; size *= shape[i]; } } + bool is_cpu = v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU; + auto device_span = Generators::ByteWrapTensor(is_cpu ? 
*Generators::GetDeviceInterface(Generators::DeviceType::CPU) : *model.p_device_, *v); + pybind11::buffer_info bufinfo{ - data, // Pointer to memory buffer + device_span.CopyDeviceToCpu().data(), // Pointer to memory buffer static_cast(element_size), // Size of underlying scalar type ToFormatDescriptor(type), // Python struct-style format descriptor static_cast(shape.size()), // Number of dimensions @@ -445,7 +412,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { })) .def_property_readonly("type", [](const Model& model) { return model.config_->model.type; }) .def_property_readonly( - "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") + "device_type", [](const Model& model) { return to_string(model.p_device_->GetType()); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); pybind11::class_(m, "Generator") diff --git a/src/qnn/interface.cpp b/src/qnn/interface.cpp new file mode 100644 index 000000000..0fc746437 --- /dev/null +++ b/src/qnn/interface.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "../generators.h" +#include "../search.h" +#include "interface.h" + +namespace Generators { +namespace QNN { + +static Ort::Allocator* ort_allocator_{}; +const char* device_label = "qnn"; + +struct QnnMemory final : DeviceBuffer { + QnnMemory(size_t size) : owned_{true} { + size_in_bytes_ = size; + p_cpu_ = p_device_ = static_cast(ort_allocator_->Alloc(size_in_bytes_)); + } + + QnnMemory(void* p, size_t size) : owned_{false} { + size_in_bytes_ = size; + p_cpu_ = p_device_ = static_cast(p); + } + + ~QnnMemory() override { + if (owned_) + ort_allocator_->Free(p_device_); + } + + const char* GetType() const override { return device_label; } + void AllocateCpu() override {} // Nothing to do, device memory is CPU accessible + void CopyDeviceToCpu() override {} // Nothing to do, device memory is CPU accessible + void CopyCpuToDevice() override {} // Nothing to do, device memory is CPU accessible + void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { + CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes); + } + + void Zero() override { + memset(p_device_, 0, size_in_bytes_); + } + + bool owned_; +}; + +struct InterfaceImpl : DeviceInterface { + InterfaceImpl() { + } + + DeviceType GetType() const override { return DeviceType::QNN; } + + void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override { + assert(!ort_allocator_); + ort_allocator_ = &allocator; + } + + Ort::Allocator& GetAllocator() override { + return *ort_allocator_; + } + + std::shared_ptr AllocateBase(size_t size) override { + return std::make_shared(size); + } + + std::shared_ptr WrapMemoryBase(void* p, size_t size) override { + return std::make_shared(p, size); + } + + std::unique_ptr CreateGreedy(const GeneratorParams& params) override { return std::make_unique(params); } + std::unique_ptr CreateBeam(const GeneratorParams& params) override { return std::make_unique(params); } + + void Synchronize() override {} // Nothing to do +}; + +} // namespace QNN + +DeviceInterface* GetQNNInterface() { + static std::unique_ptr g_device = std::make_unique(); + return g_device.get(); +} + +} // namespace Generators diff --git a/src/qnn/interface.h b/src/qnn/interface.h new file mode 100644 index 000000000..fcbfe1f64 --- 
/dev/null +++ b/src/qnn/interface.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Generators { + +DeviceInterface* GetQNNInterface(); + +} // namespace Generators \ No newline at end of file diff --git a/src/search.cpp b/src/search.cpp index a1a9f6890..30cc14c15 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -2,16 +2,18 @@ #include "softmax.h" #include "search.h" #include "beam_search_scorer.h" +#include "cpu/interface.h" #include #include namespace Generators { Search_Cpu::Search_Cpu(const GeneratorParams& params) - : Search{params} { + : Search{params}, + cpu_device_{*GetCpuInterface()} { auto batch_beam_size = params.BatchBeamSize(); - sequence_lengths_ = params.p_device->Allocate(batch_beam_size); + sequence_lengths_ = cpu_device_.Allocate(batch_beam_size); } GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params) @@ -26,9 +28,9 @@ GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params) gen_.seed(seq); } - next_tokens_ptr_ = params.p_device->Allocate(params.search.batch_size); + next_tokens_ptr_ = cpu_device_.Allocate(params.search.batch_size); + next_tokens_ptr_.Zero(); next_tokens_ = cpu_span(next_tokens_ptr_.Span()); - memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); eos_seen_buffer_ = AllocateArray(params.search.batch_size, &eos_seen_); memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); @@ -371,7 +373,6 @@ DeviceSpan BeamSearch_Cpu::GetSequence(size_t index) { return beam_scorer_->GetBeamHypotheses(batch_id, beam_id); } -// TODO(aciddelgado): my question is, should this return copy or reference? A: A copy, as with DeviceSpan it's like a span DeviceSpan BeamSearch_Cpu::GetSequence(size_t batch_id, size_t beam_id) { Finalize(params_->search.num_return_sequences); return beam_scorer_->GetBeamHypotheses(batch_id, beam_id); diff --git a/src/search.h b/src/search.h index 35fc521d8..9d456368c 100644 --- a/src/search.h +++ b/src/search.h @@ -52,6 +52,8 @@ struct Search_Cpu : Search { std::span GetScores(int batch_beam_index); + DeviceInterface& cpu_device_; + DeviceSpan sequence_lengths_; // shape (beam_size*batch_size) cpu_span next_tokens_; // shape (beam_size*batch_size) @@ -82,7 +84,6 @@ struct GreedySearch_Cpu : Search_Cpu { bool PadIfAlreadyEOS(size_t batch_id); - std::unique_ptr next_tokens_buffer_; DeviceSpan next_tokens_ptr_; std::unique_ptr temp_topk_buffer_; diff --git a/src/smartptrs.h b/src/smartptrs.h index ac7037957..642196ca3 100644 --- a/src/smartptrs.h +++ b/src/smartptrs.h @@ -5,6 +5,10 @@ #include #include "span.h" +namespace Ort { +struct Allocator; +} + namespace Generators { struct Search; struct Sequences; @@ -21,6 +25,7 @@ struct DeviceBuffer : std::enable_shared_from_this { virtual void CopyDeviceToCpu() = 0; // Allocates p_cpu_ if necessary and copies p_device_ memory into it virtual void CopyCpuToDevice() = 0; virtual void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) = 0; + virtual void Zero() = 0; // Zero out the device memory uint8_t* p_device_{}; uint8_t* p_cpu_{}; @@ -38,6 +43,8 @@ struct DeviceSpan { bool empty() const { return length_ == 0; } size_t size() const { return length_; } + operator DeviceSpan() const { return DeviceSpan(*p_device_memory_, begin_, length_); } + DeviceSpan subspan(size_t begin, size_t length) { return DeviceSpan(*p_device_memory_, begin_ + begin, length); } // Return the device accessible memory. 
Should only be done in device specific code, as it's not CPU accessible
@@ -58,24 +65,47 @@ struct DeviceSpan {
   // Copy CPU memory to device memory, typically used after calling CpuSpan or CopyDeviceToCpu to update the device memory with the modifications made
   void CopyCpuToDevice() { p_device_memory_->CopyCpuToDevice(); }
 
+  // Zero out the device memory
+  void Zero() { p_device_memory_->Zero(); }
+
+  void CopyFrom(const DeviceSpan<T>& source) {
+    assert(source.size() == size());  // Spans must be the same size to copy
+    p_device_memory_->CopyFrom(begin_ * sizeof(T), *source.p_device_memory_, source.begin_ * sizeof(T), length_ * sizeof(T));
+  }
+
  private:
   DeviceSpan(DeviceBuffer& memory, size_t begin, size_t length)
       : p_device_memory_{memory.shared_from_this()}, begin_{begin}, length_{length} {}
 
   std::shared_ptr<DeviceBuffer> p_device_memory_;
   size_t begin_{}, length_{};  // Subspan of p_device_memory_, relative to original memory block
+  template <typename U>
+  friend struct DeviceSpan;  // All DeviceSpans are friends
+};
+
+enum struct DeviceType {
+  CPU,
+  CUDA,
+  DML,
+  WEBGPU,
+  QNN,
+  MAX
 };
 
 struct DeviceInterface {
   virtual ~DeviceInterface() {}
 
+  virtual DeviceType GetType() const = 0;
+  virtual void InitOrt(const OrtApi& api, Ort::Allocator& allocator) = 0;
+  virtual Ort::Allocator& GetAllocator() = 0;
+
   template <typename T>
-  DeviceSpan<T> Allocate(size_t count, bool cpu_accessible = false) { return DeviceSpan<T>(AllocateBase(sizeof(T) * count, cpu_accessible)); }
-  virtual std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) = 0;
+  DeviceSpan<T> Allocate(size_t count) { return DeviceSpan<T>(AllocateBase(sizeof(T) * count)); }
+  virtual std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) = 0;
 
   // Wraps an existing memory block, useful for tensors. Use WrapTensor for OrtValue vs calling this directly
   template <typename T>
-  DeviceSpan<T> WrapMemory(std::span<T> memory) { return DeviceSpan<T>(WrapMemoryBase(memory.data(), memory.size_bytes())); }
+  DeviceSpan<T> WrapMemory(std::span<T> memory) { return DeviceSpan<T>(WrapMemoryBase(const_cast<std::remove_const_t<T>*>(memory.data()), memory.size_bytes())); }
   virtual std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* memory, size_t size) = 0;
 
   virtual std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) = 0;
@@ -83,7 +113,18 @@ struct DeviceInterface {
   virtual void Synchronize() = 0;  // Synchronize the device, typically used for timing or debugging
 
-  virtual cudaStream_t GetCudaStream() {
+  virtual bool Cast(OrtValue& /*input*/, OrtValue& /*output*/) { return false; }
+
+  virtual void UpdatePositionIds(void* /*position_ids*/, int /*batch_beam_size*/, int /*total_length*/, int /*new_kv_length*/, ONNXTensorElementDataType /*type*/) { assert(false); }
+  virtual void UpdateAttentionMask(void* /*mask_data*/, const void* /*old_data*/, int /*batch_beam_size*/, int /*new_kv_length*/, int /*total_length*/, int /*max_length*/, bool /*update_only*/, ONNXTensorElementDataType /*type*/) { assert(false); }
+
+  virtual void LaunchHandleEOSArray(float* /*batch_logits*/, int /*batch_beam_size*/, int /*vocab_size*/, const int32_t* /*eos_token_ids*/, int /*eos_token_ids_count*/) { assert(false); }
+  virtual void UpdateCacheIndirectionKernelLauncher(int32_t* /*tgt_indir_cache*/, const int32_t* /*src_indir_cache*/, const int32_t* /*beam_ids*/, int /*batch_size*/, int /*beam_width*/, int /*input_seq_length*/, int /*max_seq_length*/, int /*current_length*/) { assert(false); }
+  virtual void ReorderPastStatesKernelLauncher(void* /*out_buffer*/, const void* /*in_buffer*/, int /*batch_size*/, int /*num_heads*/, int /*max_length*/, int /*head_size*/, int /*chunk_size*/) {
assert(false); } + virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false); } + virtual void LaunchFinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const float* /*cross_qk_buffer_data*/, float* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); } + + virtual void* GetCudaStream() { assert(false); return nullptr; } // Temporary until we fully factor out providers diff --git a/src/webgpu/interface.cpp b/src/webgpu/interface.cpp new file mode 100644 index 000000000..3b1fd9c25 --- /dev/null +++ b/src/webgpu/interface.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "../generators.h" +#include "../search.h" +#include "interface.h" + +namespace Generators { +namespace WebGPU { + +static Ort::Allocator* ort_allocator_{}; +const char* device_label = "WebGPU"; + +struct WebGPUMemory final : DeviceBuffer { + WebGPUMemory(size_t size) : owned_{true} { + size_in_bytes_ = size; + p_cpu_ = p_device_ = static_cast(ort_allocator_->Alloc(size_in_bytes_)); + } + + WebGPUMemory(void* p, size_t size) : owned_{false} { + size_in_bytes_ = size; + p_cpu_ = p_device_ = static_cast(p); + } + + ~WebGPUMemory() override { + if (owned_) + ort_allocator_->Free(p_device_); + } + + const char* GetType() const override { return device_label; } + void AllocateCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); } + void CopyDeviceToCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); } + void CopyCpuToDevice() override { throw std::runtime_error("CPU can't access WebGPU memory"); } + void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { + throw std::runtime_error("CPU can't access WebGPU memory"); + } + + void Zero() override { + throw std::runtime_error("Zeroing not implemented for WebGPU memory"); + } + + bool owned_; +}; + +struct InterfaceImpl : DeviceInterface { + InterfaceImpl() { + } + + DeviceType GetType() const override { return DeviceType::WEBGPU; } + + void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override { + assert(!ort_allocator_); + ort_allocator_ = &allocator; + } + + Ort::Allocator& GetAllocator() override { + return *ort_allocator_; + } + + std::shared_ptr AllocateBase(size_t size) override { + return std::make_shared(size); + } + + std::shared_ptr WrapMemoryBase(void* p, size_t size) override { + return std::make_shared(p, size); + } + + std::unique_ptr CreateGreedy(const GeneratorParams& params) override { return std::make_unique(params); } + std::unique_ptr CreateBeam(const GeneratorParams& params) override { return std::make_unique(params); } + + void Synchronize() override {} // Nothing to do? 
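+
+  // Note: WebGPUMemory is device-only - AllocateCpu, CopyDeviceToCpu, CopyCpuToDevice and CopyFrom
+  // all throw - so generic helpers that stage through the CPU (CopyThroughCpu) cannot be used here,
+  // and these buffers are only suitable for data that stays on the device (primarily the KV cache).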
+}; + +} // namespace WebGPU + +DeviceInterface* GetWebGPUInterface() { + static std::unique_ptr g_device = std::make_unique(); + return g_device.get(); +} + +} // namespace Generators diff --git a/src/webgpu/interface.h b/src/webgpu/interface.h new file mode 100644 index 000000000..204b4dfed --- /dev/null +++ b/src/webgpu/interface.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Generators { + +DeviceInterface* GetWebGPUInterface(); + +} // namespace Generators \ No newline at end of file diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 40835a42a..56e40a2c1 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index 1350981ec..747675a1a 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -90,11 +90,12 @@ private static string GetDirectoryInTreeThatContains(string currentDirectory, st public OnnxRuntimeGenAITests(ITestOutputHelper o) { + Console.WriteLine("**** Running OnnxRuntimeGenAITests constructor"); // Initialize GenAI and register a handler to dispose it on process exit ogaHandle = new OgaHandle(); AppDomain.CurrentDomain.ProcessExit += (sender, e) => ogaHandle.Dispose(); - this.output = o; + Console.WriteLine("**** OnnxRuntimeGenAI constructor completed"); } private class IgnoreOnModelAbsenceFact : FactAttribute @@ -578,6 +579,8 @@ public IgnoreOnAdaptersAbsentFact() [IgnoreOnAdaptersAbsentFact(DisplayName = "TestAdapters")] public void TestAdapters() { + Console.WriteLine("**** Running TestAdapters"); + string modelPath = _adaptersPath; string adapterPath = Path.Combine(modelPath, "adapters.onnx_adapter"); diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 321d1ac46..9482a0998 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -18,12 +18,6 @@ #define PHI2_PATH MODEL_PATH "phi-2/int4/cpu" #endif #endif -#if USE_DML -#include -#include -#include -#include -#endif // To generate this file: // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20 @@ -35,7 +29,7 @@ static const std::pair c_tiny_gpt2_model_paths[] = { #if USE_DML TEST(ModelTests, DMLAdapterSelection) { -#if TEST_PHI2 +#if 0 // TEST_PHI2 TODO: Remove this? Can't access the device directly anymore. 
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 40835a42a..56e40a2c1 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include
 #include
 #include
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index 1350981ec..747675a1a 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -90,11 +90,12 @@ private static string GetDirectoryInTreeThatContains(string currentDirectory, st
 
         public OnnxRuntimeGenAITests(ITestOutputHelper o)
         {
+            Console.WriteLine("**** Running OnnxRuntimeGenAITests constructor");
             // Initialize GenAI and register a handler to dispose it on process exit
             ogaHandle = new OgaHandle();
             AppDomain.CurrentDomain.ProcessExit += (sender, e) => ogaHandle.Dispose();
-            this.output = o;
+            Console.WriteLine("**** OnnxRuntimeGenAI constructor completed");
         }
 
         private class IgnoreOnModelAbsenceFact : FactAttribute
@@ -578,6 +579,8 @@ public IgnoreOnAdaptersAbsentFact()
         [IgnoreOnAdaptersAbsentFact(DisplayName = "TestAdapters")]
         public void TestAdapters()
         {
+            Console.WriteLine("**** Running TestAdapters");
+
             string modelPath = _adaptersPath;
             string adapterPath = Path.Combine(modelPath, "adapters.onnx_adapter");
diff --git a/test/model_tests.cpp b/test/model_tests.cpp
index 321d1ac46..9482a0998 100644
--- a/test/model_tests.cpp
+++ b/test/model_tests.cpp
@@ -18,12 +18,6 @@
 #define PHI2_PATH MODEL_PATH "phi-2/int4/cpu"
 #endif
 #endif
-#if USE_DML
-#include
-#include
-#include
-#include
-#endif
 
 // To generate this file:
 // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
@@ -35,7 +29,7 @@ static const std::pair c_tiny_gpt2_model_paths[] = {
 
 #if USE_DML
 TEST(ModelTests, DMLAdapterSelection) {
-#if TEST_PHI2
+#if 0 // TEST_PHI2 TODO: Remove this? Can't access the device directly anymore.
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), PHI2_PATH);
   auto d3d12Device = model->GetD3D12Device();
diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp
index eb7f04cd9..28137fdc9 100644
--- a/test/sampling_benchmark.cpp
+++ b/test/sampling_benchmark.cpp
@@ -32,7 +32,6 @@ struct SamplingBenchmark {
     params->search.max_length = 10;
     params->search.batch_size = batch_size_;
     params->p_device = Generators::GetDeviceInterface(device_type_);
-    params->device_type = device_type_;
 
     std::random_device rd;
     std::mt19937 engine(rd());
diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 6ca0080eb..6fa206873 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -39,7 +39,6 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) {
   params->search.top_p = 0.25f;
   params->search.batch_size = 4;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits = params->p_device->WrapMemory<float>(logits_cpu);
   generator->SetLogits(logits);
@@ -66,7 +65,6 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) {
   params->search.top_k = 2;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits_copy = logits_cpu;
   auto logits = params->p_device->WrapMemory<float>(logits_copy);
@@ -101,7 +99,6 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits_copy = logits_cpu;
   auto logits = params->p_device->WrapMemory<float>(logits_copy);
@@ -152,7 +149,6 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   params->search.top_p = 0.95f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
   std::random_device rd;
   std::mt19937 engine(rd());
@@ -205,7 +201,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
 
   // Create data structures for testing
   std::random_device rd;
@@ -270,7 +265,6 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
   params->search.top_p = p;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
   std::random_device rd;
   std::mt19937 engine(rd());
@@ -317,7 +311,6 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -345,7 +338,6 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) {
   params->search.top_k = 2;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -378,7 +370,6 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -406,7 +397,6 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
   params->search.top_p = 0.95f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
 
@@ -417,8 +407,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
@@ -448,7 +438,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
 
   // Create data structures for testing
   std::random_device rd;
@@ -512,7 +501,6 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   params->search.top_p = p;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
   std::random_device rd;
@@ -522,8 +510,8 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
@@ -549,7 +537,6 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) {
   params->search.max_length = 10;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
   std::random_device rd;
@@ -558,8 +545,8 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) {
   int num_iter = 100;
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     auto generator = Generators::CreateGenerator(*model, *params);
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
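The sampling-test hunks above are two mechanical changes applied throughout: the redundant `params->device_type` assignment is dropped (the device is now fully described by `params->p_device`), and CUDA streams are fetched from the device interface rather than from `GeneratorParams`. A small illustrative helper follows; the API surface it assumes (`p_device`, `GetType()`, and the temporary `GetCudaStream()` accessor) comes from earlier in this diff, but the helper itself is not code from this PR.

// Hypothetical helper: branch on the device without the removed GeneratorParams::device_type field.
#include "generators.h"  // GeneratorParams, DeviceInterface, DeviceType (assumed location)

void* GetStreamIfCuda(Generators::GeneratorParams& params) {
  if (params.p_device->GetType() == Generators::DeviceType::CUDA)
    return params.p_device->GetCudaStream();  // temporary accessor added earlier in this diff
  return nullptr;                             // non-CUDA back ends have no stream to hand out
}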
diff --git a/test/tests_helper.cu b/test/tests_helper.cu
index c15539d02..28f97b0e3 100644
--- a/test/tests_helper.cu
+++ b/test/tests_helper.cu
@@ -21,10 +21,10 @@ __global__ void GeometricDecayKernel(float* logits, int vocab_size, int num_larg
   }
 }
 
-void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, cudaStream_t stream) {
+void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, void* stream) {
   int num_threads = 256;
   int num_blocks = batch_size;
-  GeometricDecayKernel<<<num_blocks, num_threads, 0, stream>>>(logits, vocab_size, num_large, large_val);
+  GeometricDecayKernel<<<num_blocks, num_threads, 0, static_cast<cudaStream_t>(stream)>>>(logits, vocab_size, num_large, large_val);
 }
 
 __global__ void FisherYatesKernel(float* logits, int* indices, int vocab_size, curandState* states) {
@@ -74,13 +74,13 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t
   PopulateIndices<<<num_blocks, num_threads, 0, stream>>>(indices, size, batch_size);
 }
 
-void LaunchFisherYatesKernel(float* logits, int* indices_buffer, int vocab_size, int batch_size, cudaStream_t stream) {
+void LaunchFisherYatesKernel(float* logits, int* indices_buffer, int vocab_size, int batch_size, void* stream) {
   int num_threads = 256;
   int num_blocks = batch_size;
   curandState *random_states;
   cudaMalloc((void **)&random_states, num_threads * sizeof(curandState));
   std::span<float> logits_span{logits, static_cast<size_t>(vocab_size * batch_size)};
   std::span<int> indices{indices_buffer, static_cast<size_t>(vocab_size * batch_size)};
-  LaunchPopulateIndices(indices.data(), vocab_size, batch_size, stream);
-  FisherYatesKernel<<<num_blocks, num_threads, 0, stream>>>(logits_span.data(), indices.data(), vocab_size, random_states);
+  LaunchPopulateIndices(indices.data(), vocab_size, batch_size, static_cast<cudaStream_t>(stream));
+  FisherYatesKernel<<<num_blocks, num_threads, 0, static_cast<cudaStream_t>(stream)>>>(logits_span.data(), indices.data(), vocab_size, random_states);
 }
diff --git a/test/tests_helper.cuh b/test/tests_helper.cuh
index ff0f3f319..d4dd48a92 100644
--- a/test/tests_helper.cuh
+++ b/test/tests_helper.cuh
@@ -1,5 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, cudaStream_t stream);
-void LaunchFisherYatesKernel(float* logits, int* indices, int vocab_size, int batch_size, cudaStream_t stream);
\ No newline at end of file
+void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, void* stream);
+void LaunchFisherYatesKernel(float* logits, int* indices, int vocab_size, int batch_size, void* stream);
\ No newline at end of file
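The tests_helper change is a small pattern worth noting: the header loses its CUDA dependency by accepting an opaque `void*` stream, and the `.cu` translation unit casts it back to `cudaStream_t` at the launch site. A standalone sketch of the same idiom, with hypothetical names rather than code from this PR:

// helper.cuh -- callers need no CUDA headers, only an opaque stream handle.
void LaunchFill(float* data, int n, float value, void* stream);

// helper.cu -- the cast back to cudaStream_t happens where the kernel is launched.
#include <cuda_runtime.h>

__global__ void FillKernel(float* data, int n, float value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    data[i] = value;
}

void LaunchFill(float* data, int n, float value, void* stream) {
  int num_threads = 256;
  int num_blocks = (n + num_threads - 1) / num_threads;
  FillKernel<<<num_blocks, num_threads, 0, static_cast<cudaStream_t>(stream)>>>(data, n, value);
}

Callers that already hold a cudaStream_t can pass it directly, since the conversion to void* is implicit; only the launch site needs the explicit cast back.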