From 7e4668b27883f67e5bcfd4a38c9a55ac792e77d2 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 22 Nov 2024 17:24:35 -0800
Subject: [PATCH 01/31] Use DeviceInterface for debugging

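Debug dumps of GPU tensors now go through the model's DeviceInterface instead
of provider-specific #if blocks. Roughly, the new path looks like the sketch
below (names taken from the diff; DumpValues and SizeOf are the helpers the
dump code already uses):

    auto type = type_info->GetElementType();
    auto bytes = std::span<uint8_t>{value->GetTensorMutableData<uint8_t>(),
                                    SizeOf(type) * element_count};
    // Wrap the device memory without copying, then pull it to the CPU to dump.
    auto device_span = model.p_device_->WrapMemory<uint8_t>(bytes);
    DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);

p_device_ in model.h is marked mutable so this path stays usable from the
const Model& that DumpTensor receives.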
---
 src/models/debugging.cpp | 15 +++++++--------
 src/models/model.h       |  2 +-
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp
index 597ded8de..67e0df3b1 100644
--- a/src/models/debugging.cpp
+++ b/src/models/debugging.cpp
@@ -3,15 +3,7 @@
 #include "../generators.h"
 #include "utils.h"
 #include <cinttypes>
-
-#if USE_CUDA
-#include "../cuda/cuda_common.h"
-#endif
-
-#if USE_DML
-#include "../dml/dml_helpers.h"
 #include "model.h"
-#endif
 
 namespace Generators {
 static constexpr size_t c_value_count = 10;  // Dump this many values from the start of a tensor
@@ -92,6 +84,12 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
       break;
     case OrtMemoryInfoDeviceType_GPU: {
       stream << "GPU\r\n";
+      auto type = type_info->GetElementType();
+      auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
+      auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
+      DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
+      break;
+#if 0
 #if USE_CUDA
       auto type = type_info->GetElementType();
       size_t element_size = SizeOf(type);
@@ -120,6 +118,7 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
       DumpValues(stream, type, cpu_copy.get(), element_count);
 #else
       stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?";
+#endif
 #endif
       break;
     }
diff --git a/src/models/model.h b/src/models/model.h
index 03d973afa..7f29c6b28 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -144,7 +144,7 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<OrtSessionOptions> session_options_;
 
   cudaStream_t cuda_stream_{};
-  DeviceInterface* p_device_{};
+  mutable DeviceInterface* p_device_{};
   DeviceType device_type_{DeviceType::CPU};
   Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()};
   Ort::Allocator* allocator_device_{};   // Can be CUDA or CPU based on the DeviceType in the model

From 3823664cfe1a390b1bd8089fc23bfe3c2a51ac22 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Sun, 15 Dec 2024 23:52:16 -0800
Subject: [PATCH 02/31] Remove provider #ifdefs and route memory handling
 through the device interface

Add a DML DeviceInterface and a DML DeviceBuffer implementation.
Remove the #if blocks that copy between device and CPU memory and use the
DeviceSpan interface instead.
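
Cross-device copies now funnel through one helper that stages data in CPU
memory. Condensed from the new code in src/generators.cpp (signature as added
there; the real body uses std::copy over spans):

    void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest,
                        DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
      source.CopyDeviceToCpu();   // make the source bytes visible on the CPU
      dest.AllocateCpu();         // ensure the destination has a CPU staging buffer
      std::copy_n(source.p_cpu_ + begin_source, size_in_bytes, dest.p_cpu_ + begin_dest);
      dest.CopyCpuToDevice();     // push the staged bytes back to the destination device
    }

Each DeviceBuffer implementation overrides Zero() and CopyFrom(), handling the
same-device case natively and falling back to CopyThroughCpu otherwise.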
---
 src/beam_search_scorer.cpp         |   2 +-
 src/cpu/interface.cpp              |  18 ++-
 src/cuda/interface.cpp             |  65 ++++-------
 src/cuda/interface.h               |   4 +-
 src/cuda/search_cuda.cpp           |   2 +-
 src/dml/interface.cpp              | 172 +++++++++++++++++++++++++++++
 src/dml/interface.h                |  11 ++
 src/generators.cpp                 |  26 ++++-
 src/generators.h                   |  25 ++++-
 src/models/captured_graph_pool.cpp |  15 +--
 src/models/extra_inputs.cpp        |  46 +-------
 src/models/input_ids.cpp           |  64 ++---------
 src/models/input_ids.h             |   6 -
 src/models/kv_cache.cpp            |  14 +--
 src/models/logits.cpp              | 102 ++---------------
 src/models/logits.h                |   5 -
 src/models/model.cpp               |  94 +++++-----------
 src/models/model.h                 |   5 -
 src/models/position_inputs.cpp     |   5 +-
 src/models/utils.cpp               |   6 +
 src/models/whisper.cpp             |   6 +-
 src/ort_genai_c.cpp                |  36 +-----
 src/python/python.cpp              |  43 +-------
 src/search.cpp                     |  11 +-
 src/search.h                       |   3 +-
 src/smartptrs.h                    |  24 +++-
 test/c_api_tests.cpp               |   3 +
 test/sampling_benchmark.cpp        |   3 +
 test/sampling_tests.cpp            |   3 +
 29 files changed, 365 insertions(+), 454 deletions(-)
 create mode 100644 src/dml/interface.cpp
 create mode 100644 src/dml/interface.h

diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp
index d2f056038..b9cbceffe 100644
--- a/src/beam_search_scorer.cpp
+++ b/src/beam_search_scorer.cpp
@@ -67,7 +67,7 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
 
   // Space to store intermediate sequence
   size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2;
-  hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam, true);
+  hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam);
 
   memset(next_beam_scores_.Span().data(), 0, next_beam_scores_.Span().size_bytes());
 
diff --git a/src/cpu/interface.cpp b/src/cpu/interface.cpp
index 420ba8f74..b27882668 100644
--- a/src/cpu/interface.cpp
+++ b/src/cpu/interface.cpp
@@ -7,6 +7,7 @@
 
 namespace Generators {
 
+static Ort::Allocator* ort_allocator_{};
 const char* label_cpu = "cpu";
 
 struct CpuMemory final : DeviceBuffer {
@@ -30,18 +31,23 @@ struct CpuMemory final : DeviceBuffer {
   void CopyDeviceToCpu() override {}  // Nothing to do, device is also CPU
   void CopyCpuToDevice() override {}  // Nothing to do, device is also CPU
   void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
-    if (GetType() == label_cpu)
-      memcpy(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes);
-    else
-      throw std::runtime_error("CpuMemory::CopyFromDevice not implemented for " + std::string(source.GetType()));
+    CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
+  }
+
+  void Zero() override {
+    memset(p_device_, 0, size_in_bytes_);
   }
 
   bool owned_;
 };
 
 struct CpuInterface : DeviceInterface {
-  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) override {
-    // cpu_accessible is ignored, as with the cpu, the device is also the cpu
+  void InitAllocator(Ort::Allocator& allocator) override {
+    assert(!ort_allocator_);
+    ort_allocator_ = &allocator;
+  }
+
+  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
     return std::make_shared<CpuMemory>(size);
   }
 
diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp
index b225604f1..0b708c94d 100644
--- a/src/cuda/interface.cpp
+++ b/src/cuda/interface.cpp
@@ -11,33 +11,9 @@
 
 namespace Generators {
 
+GenaiInterface* gp_genai{};
+Ort::Allocator* ort_allocator_{};
 const char* label_cuda = "cuda";
-const char* label_cuda_cpu = "cuda_cpu";
-
-struct HostMemory final : DeviceBuffer {
-  HostMemory(size_t size) {
-    size_in_bytes_ = size;
-    ::cudaHostAlloc(&p_device_, size, 0);
-    p_cpu_ = p_device_;  // CPU & GPU both access the same memory here
-  }
-
-  ~HostMemory() override {
-    ::cudaFreeHost(p_device_);
-  }
-
-  const char* GetType() const override { return label_cuda_cpu; }
-  void AllocateCpu() override {}      // Nothing to do, device is also CPU
-  void CopyDeviceToCpu() override {}  // Nothing to do, device is also CPU
-  void CopyCpuToDevice() override {}  // Nothing to do, device is also CPU
-  void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
-    if (source.GetType() == label_cuda_cpu)
-      ::memcpy(p_cpu_ + begin_dest, source.p_cpu_ + begin_source, size_in_bytes);
-    else if (source.GetType() == label_cuda)
-      ::cudaMemcpyAsync(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes, ::cudaMemcpyDeviceToHost, GetStream());
-    else
-      throw std::runtime_error("Cuda HostMemory::CopyFromDevice not implemented for " + std::string(source.GetType()));
-  }
-};
 
 struct GpuMemory final : DeviceBuffer {
   GpuMemory(size_t size) : owned_{true} {
@@ -66,21 +42,24 @@ struct GpuMemory final : DeviceBuffer {
 
   void CopyDeviceToCpu() override {
     AllocateCpu();
-    ::cudaMemcpy(p_cpu_, p_device_, size_in_bytes_, ::cudaMemcpyDeviceToHost);
+    ::cudaMemcpyAsync(p_cpu_, p_device_, size_in_bytes_, ::cudaMemcpyDeviceToHost, GetStream());
+    ::cudaStreamSynchronize(GetStream());
   }
 
   void CopyCpuToDevice() override {
     assert(p_cpu_);
-    ::cudaMemcpy(p_device_, p_cpu_, size_in_bytes_, ::cudaMemcpyHostToDevice);
+    ::cudaMemcpyAsync(p_device_, p_cpu_, size_in_bytes_, ::cudaMemcpyHostToDevice, GetStream());
   }
 
-  void CopyFrom(size_t begin_source, DeviceBuffer& source, size_t begin_dest, size_t size_in_bytes) override {
-    if (source.GetType() == label_cuda_cpu)
-      ::cudaMemcpyAsync(p_device_ + begin_source, source.p_device_ + begin_dest, size_in_bytes, ::cudaMemcpyHostToDevice, GetStream());
-    else if (source.GetType() == label_cuda)
-      ::cudaMemcpyAsync(p_device_ + begin_source, source.p_device_ + begin_dest, size_in_bytes, ::cudaMemcpyDeviceToDevice, GetStream());
+  void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
+    if (source.GetType() == label_cuda)
+      ::cudaMemcpyAsync(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes, ::cudaMemcpyDeviceToDevice, GetStream());
     else
-      throw std::runtime_error("Cuda GpuMemory::CopyFromDevice not implemented for " + std::string(source.GetType()));
+      gp_genai->CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
+  }
+
+  void Zero() override {
+    ::cudaMemsetAsync(p_device_, 0, size_in_bytes_, GetStream());
   }
 
   bool owned_;  // If we own the memory, we delete it on destruction
@@ -94,9 +73,12 @@ struct CudaInterfaceImpl : CudaInterface {
   ~CudaInterfaceImpl() {
   }
 
-  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) override {
-    if (cpu_accessible)
-      return std::make_shared<HostMemory>(size);
+  void InitAllocator(Ort::Allocator& allocator) override {
+    assert(!ort_allocator_);
+    ort_allocator_ = &allocator;
+  }
+
+  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
     return std::make_shared<GpuMemory>(size);
   }
 
@@ -180,18 +162,10 @@ struct CudaInterfaceImpl : CudaInterface {
     return ::cudaMemcpyAsync(dst, src, count, kind, stream);
   }
 
-  cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) override {
-    return ::cudaMemcpy(dst, src, count, kind);
-  }
-
   cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) override {
     return ::cudaMemsetAsync(ptr, value, count, stream);
   }
 
-  cudaError_t cudaMemset(void* ptr, int value, size_t count) override {
-    return ::cudaMemset(ptr, value, count);
-  }
-
  private:
   cuda_stream_holder cuda_stream_;
 };
@@ -201,7 +175,6 @@ std::unique_ptr<CudaInterface> g_cuda_device;
 DeviceInterface& GetCudaDeviceInterface() { return *g_cuda_device; }
 cudaStream_t GetStream() { return g_cuda_device->GetCudaStream(); }
 
-GenaiInterface* gp_genai{};
 LogItems& GetLogItems() { return gp_genai->GetLogItems(); }
 std::ostream& operator<<(std::ostream& stream, SGR sgr_code) { return gp_genai->operator_leftshift(stream, sgr_code); }
 std::ostream& Log(std::string_view label, std::string_view text) { return gp_genai->Log(label, text); }
diff --git a/src/cuda/interface.h b/src/cuda/interface.h
index d664277cc..24e6fb94f 100644
--- a/src/cuda/interface.h
+++ b/src/cuda/interface.h
@@ -7,6 +7,8 @@ struct GenaiInterface {
   virtual void HeapFree(void*) = 0;
 #endif
 
+  virtual void CopyThroughCpu(Generators::DeviceBuffer& dest, size_t begin_dest, Generators::DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) = 0;
+
   virtual Generators::LogItems& GetLogItems() = 0;
   virtual std::ostream& operator_leftshift(std::ostream& stream, Generators::SGR sgr_code) = 0;
   virtual std::ostream& Log(std::string_view label, std::string_view text = {}) = 0;
@@ -42,9 +44,7 @@ struct CudaInterface : DeviceInterface {
   virtual void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) = 0;
 
   virtual cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) = 0;
-  virtual cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) = 0;
   virtual cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) = 0;
-  virtual cudaError_t cudaMemset(void* ptr, int value, size_t count) = 0;
 };
 #endif
 }  // namespace Generators
diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp
index fa80f67dc..8d7bf6fe0 100644
--- a/src/cuda/search_cuda.cpp
+++ b/src/cuda/search_cuda.cpp
@@ -31,8 +31,8 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params)
 GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params)
     : Search_Cuda{params} {
   next_tokens_buffer_ = params.p_device->Allocate<int32_t>(params.search.batch_size);
+  next_tokens_buffer_.Zero();
   next_tokens_ = gpu_span<int32_t>(next_tokens_buffer_.Span());
-  cudaMemsetAsync(next_tokens_.data(), 0, next_tokens_.size_bytes(), params_->cuda_stream);
 
   unsigned long long random_seed;
   if (params_->search.random_seed != -1)
diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp
new file mode 100644
index 000000000..61a57d663
--- /dev/null
+++ b/src/dml/interface.cpp
@@ -0,0 +1,172 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "../generators.h"
+#include "../search.h"
+#include "../cpu/interface.h"
+#include "interface.h"
+#include <cstdarg>
+
+#include <wil/wrl.h>
+#include "dml_provider_factory.h"
+#include "../dml/dml_helpers.h"
+#include "../dml/dml_execution_context.h"
+#include "../dml/dml_pooled_upload_heap.h"
+#include "../dml/dml_readback_heap.h"
+
+std::string CurrentModulePath();
+
+namespace Generators {
+namespace Dml { // If this was in a shared library it wouldn't need to be in its own namespace
+
+Ort::Allocator* ort_allocator_{};
+const char* label_dml = "dml";
+
+wil::unique_hmodule smart_directml_dll_;
+DmlObjects dml_objects_;
+const OrtDmlApi* dml_api_{};
+std::unique_ptr<DmlPooledUploadHeap> dml_pooled_upload_heap_;
+std::unique_ptr<DmlExecutionContext> dml_execution_context_;
+std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
+ComPtr<IDMLDevice> dml_device_;
+
+struct GpuMemory final : DeviceBuffer {
+  GpuMemory(size_t size) : owned_{true} {
+    size_in_bytes_ = size;
+    p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
+    Ort::ThrowOnError(dml_api_->GetD3D12ResourceFromAllocation(ort_allocator_, p_device_, &gpu_resource_));
+  }
+
+  GpuMemory(void* p, size_t size) : owned_{false} {
+    size_in_bytes_ = size;
+    p_device_ = static_cast<uint8_t*>(p);
+    Ort::ThrowOnError(dml_api_->GetD3D12ResourceFromAllocation(ort_allocator_, p_device_, &gpu_resource_));
+  }
+
+  ~GpuMemory() override {
+    if (owned_)
+      ort_allocator_->Free(p_device_);
+    if (p_cpu_)
+      free(p_cpu_);
+  }
+
+  const char* GetType() const override { return label_dml; }
+
+  void AllocateCpu() override {
+    if (!p_cpu_)
+      p_cpu_ = static_cast<uint8_t*>(malloc(size_in_bytes_));
+  }
+
+  void CopyDeviceToCpu() override {
+    AllocateCpu();
+    dml_readback_heap_->ReadbackFromGpu(std::span(p_cpu_, size_in_bytes_), gpu_resource_.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+  }
+
+  void CopyCpuToDevice() override {
+    assert(p_cpu_);
+    auto source = std::span(p_cpu_, size_in_bytes_);
+    dml_pooled_upload_heap_->BeginUploadToGpu(gpu_resource_.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, source);
+  }
+
+  void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
+    if (source.GetType() == label_dml) {
+      auto& source_gpu = dynamic_cast<GpuMemory&>(source);
+      dml_execution_context_->CopyBufferRegion(
+          gpu_resource_.Get(),
+          begin_dest,
+          D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
+          source_gpu.gpu_resource_.Get(),
+          begin_source,
+          D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
+          size_in_bytes);
+    } else
+      CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
+  }
+
+  void Zero() override {
+    // TODO: Implement a zeroing that runs directly on DML vs going through CPU
+    AllocateCpu();
+    memset(p_cpu_, 0, size_in_bytes_);
+    CopyCpuToDevice();
+  }
+
+  ComPtr<ID3D12Resource> gpu_resource_;
+  bool owned_;  // If we own the memory, we delete it on destruction
+};
+
+struct DmlInterfaceImpl : DeviceInterface {
+
+  DmlInterfaceImpl(LUID* p_device_luid) {
+    Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
+    if (!dml_api_) {
+      throw std::runtime_error("Unexpected nullptr getting OrtDmlApi");
+    }
+
+    dml_objects_ = DmlHelpers::CreateDmlObjects(CurrentModulePath(), p_device_luid);
+
+    constexpr auto directml_dll = "DirectML.dll";
+    smart_directml_dll_ = wil::unique_hmodule{LoadLibraryEx(directml_dll, nullptr, 0)};
+    if (!smart_directml_dll_)
+      throw std::runtime_error("DirectML.dll not found");
+
+    auto dml_create_device1_fn = reinterpret_cast<decltype(&DMLCreateDevice1)>(GetProcAddress(smart_directml_dll_.get(), "DMLCreateDevice1"));
+    THROW_LAST_ERROR_IF(!dml_create_device1_fn);
+    THROW_IF_FAILED(dml_create_device1_fn(dml_objects_.d3d12_device.Get(), DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(&dml_device_)));
+
+    Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
+  }
+
+  void InitAllocator(Ort::Allocator& allocator) override {
+    assert(!ort_allocator_);
+    ort_allocator_ = &allocator;
+
+    dml_execution_context_ = std::make_unique<DmlExecutionContext>(
+        dml_objects_.d3d12_device.Get(),
+        dml_device_.Get(),
+        dml_objects_.command_queue.Get(),
+        *ort_allocator_,
+        dml_api_);
+
+    dml_pooled_upload_heap_ = std::make_unique<DmlPooledUploadHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
+    dml_readback_heap_ = std::make_unique<DmlReadbackHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
+  }
+
+  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
+    return std::make_shared<GpuMemory>(size);
+  }
+
+  std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override {
+    return std::make_shared<GpuMemory>(p, size);
+  }
+
+  std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override {
+    return GetCpuInterface()->CreateGreedy(params);
+  }
+
+  std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override {
+    return GetCpuInterface()->CreateBeam(params);
+  }
+
+  void Synchronize() override {
+  }
+};
+
+}  // namespace Dml
+
+std::unique_ptr<Dml::DmlInterfaceImpl> g_dml_device;
+
+void InitDmlInterface(LUID* p_device_luid) {
+  if (!g_dml_device)
+    g_dml_device = std::make_unique<Dml::DmlInterfaceImpl>(p_device_luid);
+}
+
+void SetDmlProvider(OrtSessionOptions& session_options) {
+  Ort::ThrowOnError(Dml::dml_api_->SessionOptionsAppendExecutionProvider_DML1(&session_options, Dml::dml_device_.Get(), Dml::dml_objects_.command_queue.Get()));
+}
+
+DeviceInterface* GetDmlInterface() {
+  assert(g_dml_device);
+  return g_dml_device.get();
+}
+
+}  // namespace Generators
diff --git a/src/dml/interface.h b/src/dml/interface.h
new file mode 100644
index 000000000..c70faf721
--- /dev/null
+++ b/src/dml/interface.h
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace Generators {
+
+void InitDmlInterface(LUID* p_device_luid);
+void SetDmlProvider(OrtSessionOptions& options);
+
+DeviceInterface* GetDmlInterface();
+
+}
\ No newline at end of file
diff --git a/src/generators.cpp b/src/generators.cpp
index 8c9ce39b3..ad34d1979 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -8,6 +8,7 @@
 #include "search.h"
 #include "cpu/interface.h"
 #include "cuda/interface.h"
+#include "dml/interface.h"
 #if USE_CUDA
 #include "models/kernels.h"
 #endif
@@ -89,18 +90,31 @@ OrtEnv& GetOrtEnv() {
   return *GetOrtGlobals()->env_;
 }
 
+// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device)
+void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
+  source.CopyDeviceToCpu();
+  auto source_span = std::span<const uint8_t>(source.p_cpu_ + begin_source, size_in_bytes);
+  dest.AllocateCpu();
+  std::copy(source_span.begin(), source_span.end(), dest.p_cpu_ + begin_dest);
+  dest.CopyCpuToDevice();
+}
+
 struct GenaiInterfaceImpl : GenaiInterface {
 #if _WIN32
   void* HeapAllocate(size_t size) override { return std::malloc(size); }
   void HeapFree(void* p) override { std::free(p); }
 #endif
 
+  void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
+    return Generators::CopyThroughCpu(dest, begin_dest, source, begin_source, size_in_bytes);
+  }
+
   Generators::LogItems& GetLogItems() override { return g_log; }
   std::ostream& operator_leftshift(std::ostream& stream, Generators::SGR sgr_code) override { return stream << sgr_code; }
   std::ostream& Log(std::string_view label, std::string_view text = {}) override { return Log(label, text); }
 
-  void DumpSpan(std::ostream& stream, std::span<const float> values) override { return DumpSpan(stream, values); }
-  void DumpSpan(std::ostream& stream, std::span<const int> values) override { return DumpSpan(stream, values); }
+  void DumpSpan(std::ostream& stream, std::span<const float> values) override { return Generators::DumpSpan(stream, values); }
+  void DumpSpan(std::ostream& stream, std::span<const int> values) override { return Generators::DumpSpan(stream, values); }
 
   void Sequences_AfterAppendNextTokens(Sequences* p_this, DeviceSpan<int32_t> next_tokens, size_t batch_beam_size) override { return p_this->AfterAppendNextTokens(next_tokens, batch_beam_size); }
   void Sequences_RewindTo(Sequences* p_this, size_t new_length) override { return p_this->RewindTo(new_length); }
@@ -136,6 +150,7 @@ CudaInterface* GetCudaInterface() {
   return cuda_interface;
 }
 
+
 namespace cuda {
 void LaunchInt32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) { GetCudaInterface()->Int32ToInt64(input, output, count, stream); }
 void LaunchFp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) { GetCudaInterface()->Fp16ToFp32(input, output, count, stream); }
@@ -160,6 +175,7 @@ void LaunchFinalizeCrossQK<float>(cudaStream_t stream, int iteration_number, int
 }  // namespace cuda
 #endif
 
+
 std::string to_string(DeviceType device_type) {
   switch (device_type) {
     case DeviceType::CPU:
@@ -182,6 +198,10 @@ DeviceInterface* GetDeviceInterface(DeviceType type) {
 #if USE_CUDA
     case DeviceType::CUDA:
       return GetCudaInterface();
+#endif
+#if USE_DML
+    case DeviceType::DML:
+      return GetDmlInterface();
 #endif
   }
 }
@@ -427,7 +447,5 @@ DeviceSpan<int32_t> Generator::GetSequence(size_t index) const {
 
 #if USE_CUDA
 cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemcpyAsync(dst, src, count, kind, stream); }
-cudaError_t cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) { return Generators::GetCudaInterface()->cudaMemcpy(dst, src, count, kind); }
 cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemsetAsync(ptr, value, count, stream); }
-cudaError_t cudaMemset(void* ptr, int value, size_t count) { return Generators::GetCudaInterface()->cudaMemset(ptr, value, count); }
 #endif
diff --git a/src/generators.h b/src/generators.h
index 3e4561b14..7275aa976 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -49,9 +49,24 @@ struct Tokenizer;
 
 template <typename T>
 DeviceSpan<T> WrapTensor(DeviceInterface& device, OrtValue& value) {
-  return device.WrapMemory(std::span<T>{value.GetTensorMutableData<T>(), value.GetTensorTypeAndShapeInfo()->GetElementCount()});
+  auto info = value.GetTensorTypeAndShapeInfo();
+  assert(info->GetElementType() == Ort::TypeToTensorType<std::remove_const_t<T>>);
+  return device.WrapMemory(std::span<T>{value.GetTensorMutableData<T>(), info->GetElementCount()});
 }
 
+DeviceSpan<uint8_t> ByteWrapTensor(DeviceInterface& device, OrtValue& value);
+
+template<typename T>
+struct OrtTensor {
+  OrtTensor(std::unique_ptr<OrtValue> ort_value, DeviceInterface& device)
+    : ort_value_{std::move(ort_value)}, device_span_{WrapTensor<T>(device, *ort_value_)} {}
+
+  operator OrtValue*() { return ort_value_.get(); }
+
+  std::unique_ptr<OrtValue> ort_value_;
+  DeviceSpan<T> device_span_;
+};
+
 // OgaSequences are a vector of int32 vectors
 using TokenSequences = std::vector<std::vector<int32_t>>;
 
@@ -60,6 +75,7 @@ enum struct DeviceType {
   CUDA,
   DML,
   WEBGPU,
+  MAX
 };
 
 std::string to_string(DeviceType device_type);
@@ -136,10 +152,7 @@ struct OrtGlobals {
   OrtGlobals();
 
   std::unique_ptr<OrtEnv> env_;
-#if USE_CUDA
-  std::unique_ptr<OrtMemoryInfo> memory_info_cuda_;
-  std::unique_ptr<Ort::Allocator> allocator_cuda_;
-#endif
+  std::unique_ptr<Ort::Allocator> allocator_device_[static_cast<int>(DeviceType::MAX)];
  private:
   OrtGlobals(const OrtGlobals&) = delete;
   void operator=(const OrtGlobals&) = delete;
@@ -155,6 +168,8 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);  // For benchmarking purposes only
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
 
+void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes);
+
 float Float16ToFloat32(uint16_t v);  // v is an IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
 void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);
 
diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp
index 84c8bec11..baaa9243e 100644
--- a/src/models/captured_graph_pool.cpp
+++ b/src/models/captured_graph_pool.cpp
@@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf
 }
 
 CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
-  if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
+  if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA)) {
     return nullptr;
   }
 
@@ -48,12 +48,6 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
     size_t max_beam_batch_size = static_cast<size_t>(params.search.num_beams) * params.max_batch_size;
     new_captured_graph->sb_input_ids_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
 
-#if USE_DML
-    if (model.device_type_ == DeviceType::DML) {
-      new_captured_graph->sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
-    }
-#endif
-
     // Create the static buffers for the cache
     int layer_count = config_->model.decoder.num_hidden_layers;
     new_captured_graph->sb_kv_caches_.reserve(layer_count * 2);
@@ -70,13 +64,6 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
     // Create the static buffer for the attention mask, if needed
     if (session_info_->HasInput(config_->model.decoder.inputs.attention_mask)) {
       new_captured_graph->sb_attention_mask_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
-
-#if USE_DML
-      // DML currently needs an additional static buffer for the mask
-      if (model.device_type_ == DeviceType::DML) {
-        new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
-      }
-#endif
     }
 
     auto output_type = session_info_->GetOutputDataType(config_->model.decoder.outputs.logits);
diff --git a/src/models/extra_inputs.cpp b/src/models/extra_inputs.cpp
index 6827042d3..a4bab8ce7 100644
--- a/src/models/extra_inputs.cpp
+++ b/src/models/extra_inputs.cpp
@@ -28,11 +28,6 @@ ExtraInputs::ExtraInputs(State& state)
   }
 }
 
-#pragma warning(push)
-#pragma warning(disable : 4065)  // switch statement contains 'default' but no 'case' labels
-#pragma warning(disable : 4189)  // local variable is initialized but not referenced
-#pragma warning(disable : 4702)  // unreachable code
-
 void ExtraInputs::Add() {
   // Add extra user inputs
   for (int i = 0; i < state_.params_->extra_inputs.size(); ++i) {
@@ -42,44 +37,11 @@ void ExtraInputs::Add() {
 
   // Copy the data from the CPU-backed ORT value to the static buffers
   for (int i = 0; i < sb_extra_inputs_.size(); ++i) {
-    auto type_and_shape_info = extra_inputs_[i]->GetTensorTypeAndShapeInfo();
-    auto shape = type_and_shape_info->GetShape();
-    auto element_count = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
-    auto copy_size_in_bytes = element_count * SizeOf(type_and_shape_info->GetElementType());
-
-    switch (model_.device_type_) {
-#if USE_DML
-      case DeviceType::DML: {
-        ComPtr<ID3D12Resource> target_resource;
-        Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, extra_inputs_[i]->GetTensorMutableRawData(), &target_resource));
-
-        auto source = std::span(state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorData<const uint8_t>(), copy_size_in_bytes);
-
-        model_.GetDmlUploadHeap()->BeginUploadToGpu(
-            target_resource.Get(),
-            0,
-            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-            source);
-      } break;
-#endif
-
-#if USE_CUDA
-      case DeviceType::CUDA: {
-        cudaMemcpyAsync(
-            extra_inputs_[i]->GetTensorMutableRawData(),
-            state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorMutableRawData(),
-            copy_size_in_bytes,
-            cudaMemcpyHostToDevice,
-            model_.cuda_stream_);
-      } break;
-#endif
-
-      default:
-        throw std::runtime_error("Unsupported device for graph capture");
-    }
+    auto tensor = ByteWrapTensor(*model_.p_device_, *extra_inputs_[i]);
+    auto source = std::span{state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorData<uint8_t>(), tensor.size()};
+    copy(source, tensor.CpuSpan());
+    tensor.CopyCpuToDevice();
   }
 }
 
-#pragma warning(pop)
-
 }  // namespace Generators
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 71f051bfc..47233f3d8 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -13,12 +13,6 @@ InputIDs::InputIDs(State& state)
 
   if (state_.GetCapturedGraphInfo()) {
     sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();
-
-#if USE_DML
-    if (model_.device_type_ == DeviceType::DML) {
-      sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
-    }
-#endif
   }
 
   if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
@@ -56,29 +50,26 @@ void InputIDs::Add() {
 }
 
 void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
-  const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids,
-                                               int32_t pad_token_id) {
-    int32_t seq_length = 0;
+  auto new_tokens_cpu = new_tokens.CopyDeviceToCpu();
+
+  const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids, int32_t pad_token_id) {
     for (int32_t i = 0; i < input_ids.size(); i++) {
-      if (input_ids[i] == pad_token_id) {
-        break;
-      }
-      seq_length++;
+      if (input_ids[i] == pad_token_id)
+        return i;
     }
-    return seq_length;
+    return static_cast<int32_t>(input_ids.size());
   };
 
   if (current_sequence_length_ && past_sequence_length_) {
     if (state_.params_->BatchBeamSize() != 1) {
       throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs");
     }
-    auto new_sequence_length = get_unpadded_sequence_length(new_tokens.CpuSpan(), model_.config_->model.pad_token_id);
+    auto new_sequence_length = get_unpadded_sequence_length(new_tokens_cpu, model_.config_->model.pad_token_id);
     *current_sequence_length_->GetTensorMutableData<int32_t>() += new_sequence_length;
     *past_sequence_length_->GetTensorMutableData<int32_t>() += new_sequence_length;
   }
 
-  // Resize input_ids shape based on new_tokens
-  // For beam search
+  // For beam search, resize input_ids shape based on new_tokens
   size_t sequence_length = static_cast<size_t>(new_tokens.size()) / state_.params_->BatchBeamSize();
   if (is_prompt_ && state_.params_->search.num_beams > 1)
     sequence_length = static_cast<size_t>(new_tokens.size()) / state_.params_->search.batch_size;
@@ -87,20 +78,8 @@ void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
     shape_[1] = sequence_length;
     if (!sb_input_ids_) {
       value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
-
-#if USE_DML
-      if (model_.device_type_ == DeviceType::DML) {
-        value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
-      }
-#endif
     } else {
       value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);
-
-#if USE_DML
-      if (model_.device_type_ == DeviceType::DML) {
-        value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>);
-      }
-#endif
     }
 
     state_.inputs_[input_index_] = value_.get();
@@ -121,33 +100,8 @@ void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
 #endif
       } break;
 
-      case DeviceType::DML: {
-#if USE_DML
-        ComPtr<ID3D12Resource> source_resource;
-        Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource));
-
-        auto source = std::span<const uint8_t>(
-            reinterpret_cast<const uint8_t*>(new_tokens.CpuSpan().data()),
-            new_tokens.CpuSpan().size_bytes());
-
-        model_.GetDmlUploadHeap()->BeginUploadToGpu(
-            source_resource.Get(),
-            0,
-            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-            source);
-
-        DmlHelpers::DmlCastInputToOutput(
-            model_.GetDmlExecutionContext(),
-            *model_.allocator_device_,
-            *value_int32_,
-            value_,
-            model_.GetDmlDevice(),
-            model_.GetOrtDmlApi(),
-            input_ids_cast_command_list_state_);
-#endif
-      } break;
       default: {
-        // CPU, WEBGPU
+        // CPU, DML, WEBGPU
         auto* data = value_->GetTensorMutableData<int64_t>();
         auto next_tokens = new_tokens.Span();
         for (int b = 0; b < shape_[0]; b++) {
diff --git a/src/models/input_ids.h b/src/models/input_ids.h
index dd364212f..02af3a98a 100644
--- a/src/models/input_ids.h
+++ b/src/models/input_ids.h
@@ -35,12 +35,6 @@ struct InputIDs {
   // Used for decoding runs with cuda graphs.
   StaticBuffer* sb_input_ids_{};
 
-#if USE_DML
-  std::unique_ptr<OrtValue> value_int32_;
-  StaticBuffer* sb_input_ids_int32_{};
-  DmlReusedCommandListState input_ids_cast_command_list_state_{};
-#endif
-
   std::unique_ptr<OrtValue> current_sequence_length_;
   std::unique_ptr<OrtValue> past_sequence_length_;
 };
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 5cb66ade3..4fa3243e2 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -220,23 +220,13 @@ KV_Cache::KV_Cache(State& state)
     }
   }
 
-  auto kv_cache_size_bytes = SizeOf(type_) * shape_[0] * shape_[1] * shape_[2] * shape_[3];
   try {
     for (int i = 0; i < layer_count_ * 2; ++i) {
       presents_.push_back(
           sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)
                                 : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
-#if USE_CUDA
-      if (model_.device_type_ == DeviceType::CUDA) {
-        cudaMemsetAsync(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes, model_.cuda_stream_);
-      } else
-#endif
-      {
-        if (model_.device_type_ == DeviceType::CPU) {
-          // FIXME: this is a device ternsor and we can only use memset for cpu. Revisit for other EPs.
-          memset(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes);
-        }
-      }
+      // Zero the memory so we don't leak any data from the previous run
+      ByteWrapTensor(*model_.p_device_, *presents_.back()).Zero();
     }
   } catch (const Ort::Exception&) {
     std::ostringstream oss;
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index d4821c4b1..ae7d9bdcd 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -37,9 +37,6 @@ Logits::Logits(State& state)
   input_sequence_lengths.resize(state_.params_->search.batch_size);
 }
 
-#pragma warning(push)
-#pragma warning(disable : 4189)  // local variable is initialized but not referenced
-
 DeviceSpan<float> Logits::Get() {
   size_t element_count = shape_[0] * shape_[1] * shape_[2];
 
@@ -50,7 +47,6 @@ DeviceSpan<float> Logits::Get() {
     const size_t seq_length = shape_[1];
     const size_t vocab_size = shape_[2];
     const size_t num_beams = state_.params_->search.num_beams;
-    const size_t element_count_last_token = shape_[0] * shape_[2];
 
     // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it
     output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_);
@@ -60,54 +56,20 @@ DeviceSpan<float> Logits::Get() {
 
     logits_of_last_token = output_last_tokens_.get();
 
-    size_t element_size = type_ == Ort::TypeToTensorType<float> ? 4 : 2;
+    size_t element_size = SizeOf(type_);
     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
 
+    auto logits_raw = ByteWrapTensor(*state_.params_->p_device, *output_raw_);
+    auto logits_last_tokens = ByteWrapTensor(*state_.params_->p_device, *logits_of_last_token);
+
     for (int batch_index = 0; batch_index < state_.params_->search.batch_size; batch_index++) {
       // Find the first non pad token from the end
       size_t token_index = input_sequence_lengths[batch_index] - 1;
       for (int beam_index = 0; beam_index < num_beams; beam_index++) {
-        switch (model_.device_type_) {
-          case DeviceType::DML: {
-#if USE_DML
-            ComPtr<ID3D12Resource> source_resource;
-            Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource));
-
-            ComPtr<ID3D12Resource> target_resource;
-            Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, logits_of_last_token->GetTensorMutableRawData(), &target_resource));
-
-            uint64_t source_offset = (vocab_index * seq_length + token_index * vocab_size) * element_size;
-            uint64_t target_offset = vocab_index * element_size;
-            uint64_t size_in_bytes = vocab_size * element_size;
-
-            model_.GetDmlExecutionContext()->CopyBufferRegion(
-                target_resource.Get(),
-                target_offset,
-                D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-                source_resource.Get(),
-                source_offset,
-                D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-                size_in_bytes);
-#endif
-          } break;
-
-          default: {
-            // CPU, CUDA, WEBGPU
-            auto logits_raw = std::span<const uint8_t>{output_raw_->GetTensorMutableData<uint8_t>(), element_count * element_size};
-            auto logits_last_tokens = std::span<uint8_t>{logits_of_last_token->GetTensorMutableData<uint8_t>(), element_count_last_token * element_size};
-            auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
-            auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size);
-            if (model_.device_type_ == DeviceType::CUDA)
-#if USE_CUDA
-              CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream);
-#else
-              throw std::runtime_error("Unexpected CUDA device usage");
-#endif
-            else
-              copy(source, target);
-          } break;
-        }
 
+        auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
+        auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size);
+        target.CopyFrom(source);
         vocab_index += vocab_size;
       }
     }
@@ -117,35 +79,10 @@ DeviceSpan<float> Logits::Get() {
 
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-#if USE_DML
-    if (model_.device_type_ == DeviceType::DML) {
-      DmlHelpers::DmlCastInputToOutput(
-          model_.GetDmlExecutionContext(),
-          *model_.allocator_device_,
-          *logits_of_last_token,
-          logits_of_last_token_fp32_,
-          model_.GetDmlDevice(),
-          model_.GetOrtDmlApi(),
-          logits_cast_command_list_state_);
-
-      logits_of_last_token = logits_of_last_token_fp32_.get();
-    } else
-#endif
-    {
       ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32_, model_.device_type_, model_.cuda_stream_);
       logits_of_last_token = logits_of_last_token_fp32_.get();
-    }
   }
 
-  assert(shape_[1] == 1);
-
-#if USE_DML
-  // DML doesn't support on-device scoring yet, so we need to download some data to the CPU
-  if (model_.device_type_ == DeviceType::DML) {
-    value32_cpu_ = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_last);
-  }
-#endif
-
   if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data())
     logits_ = WrapTensor<float>(*state_.params_->p_device, *logits_of_last_token);
 
@@ -162,36 +99,11 @@ DeviceSpan<float> Logits::Get() {
     return logits_;
   }
 #endif
-#if USE_DML
-  if (model_.device_type_ == DeviceType::DML) {
-    // DML doesn't support on-device scoring yet, so we transfer the data to the CPU
-    ComPtr<ID3D12Resource> gpu_resource;
-    Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
-        model_.allocator_device_,
-        logits_of_last_token->GetTensorMutableData<float>(),
-        &gpu_resource));
-    auto cpu_tensor = value32_cpu_->GetTensorMutableData<float>();
-
-    model_.GetDmlReadbackHeap()->ReadbackFromGpu(
-        std::span(reinterpret_cast<uint8_t*>(cpu_tensor), element_count * sizeof(float)),
-        gpu_resource.Get(),
-        0,
-        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
-
-    auto batched_logits_cpu = cpu_span<float>{cpu_tensor, element_count};
-    HandleEOSArray(batched_logits_cpu);
-
-    logits_ = WrapTensor<float>(*state_.params_->p_device, *value32_cpu_);
-    return logits_;
-  }
-#endif
 
   HandleEOSArray(logits_.Span());
   return logits_;
 }
 
-#pragma warning(pop)
-
 void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length) {
   if (static_cast<size_t>(output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1]) == new_kv_length && new_kv_length == 1) {
     return;
diff --git a/src/models/logits.h b/src/models/logits.h
index f723c48ee..187ecd8cb 100644
--- a/src/models/logits.h
+++ b/src/models/logits.h
@@ -46,11 +46,6 @@ struct Logits {
 #if USE_CUDA
   DeviceSpan<int32_t> cuda_eos_token_ids_;  // eos_token_ids from params, but in cuda accessible memory
 #endif
-
-#if USE_DML
-  DmlReusedCommandListState logits_cast_command_list_state_{};
-  std::unique_ptr<OrtValue> value32_cpu_;
-#endif
 };
 
 }  // namespace Generators
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 4430fdb41..9270a2cc7 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -15,11 +15,7 @@
 #include "multi_modal_vision_model.h"
 #include "decoder_only_pipeline.h"
 #if USE_DML
-#include <wil/wrl.h>
-#include "dml_provider_factory.h"
-#include "../dml/dml_helpers.h"
-
-std::string CurrentModulePath();
+#include "../dml/interface.h"
 #endif
 
 namespace Generators {
@@ -222,20 +218,20 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const {
   return token_id;
 }
 
-#if USE_CUDA
 // Since Python/Others can and will hold onto a generator object past the model object's lifetime we need to ensure
 // the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory
 // has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the
 // arena already being destroyed.
-Ort::Allocator* GetCudaAllocator(OrtSession& session) {
-  auto& globals = *GetOrtGlobals();
-  if (!globals.allocator_cuda_) {
-    globals.memory_info_cuda_ = OrtMemoryInfo::Create("Cuda", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
-    globals.allocator_cuda_ = Ort::Allocator::Create(session, *globals.memory_info_cuda_);
+Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
+  auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
+  if (!device) {
+    static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU", "Cuda", "DML", "WebGPU Buffer"};
+    auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+    device = Ort::Allocator::Create(session, *memory_info);
+    GetDeviceInterface(type)->InitAllocator(*device);
   }
-  return globals.allocator_cuda_.get();
+  return device.get();
 }
-#endif
 
 SessionInfo::SessionInfo(OrtSession& session) {
   Add(session);
@@ -290,27 +286,15 @@ Model::~Model() = default;
 
 void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {
   allocator_device_ = &allocator_cpu_;
-#if USE_CUDA
-  if (device_type_ == DeviceType::CUDA) {
-    allocator_device_ = GetCudaAllocator(session);
-  }
-#endif
-#if USE_DML
-  if (device_type_ == DeviceType::DML) {
-    memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
-    dml_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
-    allocator_device_ = dml_owned_allocator_.get();
-  }
-#endif
+  if (device_type_ == DeviceType::CUDA)
+    allocator_device_ = GetDeviceAllocator(session, device_type_);
+
   allocator_kvcache_ = allocator_device_;
-#if USE_WEBGPU
-  if (device_type_ == DeviceType::WEBGPU) {
-    // for webgpu we only use device memory for kv_cache
-    memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
-    webgpu_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
-    allocator_kvcache_ = webgpu_owned_allocator_.get();
+  if (device_type_ == DeviceType::WEBGPU || device_type_ == DeviceType::DML) {
+    // for dml and webgpu we only use device memory for kv_cache
+    allocator_kvcache_ = GetDeviceAllocator(session, device_type_);
   }
-#endif
+
   session_info_ = std::make_unique<SessionInfo>(session);
   captured_graph_pool_ = std::make_shared<CapturedGraphPool>(config_.get(), session_info_.get(), allocator_device_);
 }
@@ -436,52 +420,19 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 #if USE_DML
     } else if (provider_options.name == "dml") {
       if (!p_dml_api_) {
-        auto current_module_path = CurrentModulePath();
-
-        bool contains_device_luid = false;
         LUID device_luid{};
+        LUID* p_device_luid{};
         for (const auto& [name, value] : provider_options.options) {
           if (name == "luid") {
             if (auto separator_position = value.find(":"); separator_position != std::string::npos) {
               device_luid.HighPart = std::stol(value.substr(0, separator_position));
               device_luid.LowPart = std::stol(value.substr(separator_position + 1));
-              contains_device_luid = true;
+              p_device_luid = &device_luid;
             }
           }
         }
 
-        if (contains_device_luid) {
-          dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid);
-        } else {
-          dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path);
-        }
-
-        constexpr auto directml_dll = "DirectML.dll";
-        wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0));
-        THROW_LAST_ERROR_IF(!smart_directml_dll);
-
-        if (LoadLibraryEx(directml_dll, nullptr, 0) == NULL) {
-          throw std::runtime_error("DirectML.dll not found");
-        }
-
-        auto dml_create_device1_fn = reinterpret_cast<decltype(&DMLCreateDevice1)>(GetProcAddress(smart_directml_dll.get(), "DMLCreateDevice1"));
-        THROW_LAST_ERROR_IF(!dml_create_device1_fn);
-        THROW_IF_FAILED(dml_create_device1_fn(dml_objects_.d3d12_device.Get(), DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(&dml_device_)));
-
-        Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&p_dml_api_)));
-        if (!p_dml_api_) {
-          throw std::runtime_error("Unexpected nullptr getting OrtDmlApi");
-        }
-
-        dml_execution_context_ = std::make_unique<DmlExecutionContext>(
-            dml_objects_.d3d12_device.Get(),
-            dml_device_.Get(),
-            dml_objects_.command_queue.Get(),
-            *allocator_device_,
-            p_dml_api_);
-
-        dml_pooled_upload_heap_ = std::make_unique<DmlPooledUploadHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
-        dml_readback_heap_ = std::make_unique<DmlReadbackHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
+        InitDmlInterface(p_device_luid);
       }
 
       if (!disable_graph_capture) {
@@ -489,7 +440,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
         session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1");
       }
 
-      p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(&session_options, dml_device_.Get(), dml_objects_.command_queue.Get());
+      SetDmlProvider(session_options);
 
       if (is_primary_session_options)
         device_type_ = DeviceType::DML;  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
@@ -520,6 +471,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
 void Model::CreateSessionOptions() {
   session_options_ = OrtSessionOptions::Create();
+#if 0
+  ClearProviders(*config_);
+  SetProviderOption(*config_, "dml", {}, {});
+#endif
+
   CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *session_options_, true, false);
 
   for (auto& pipeline_model : config_->model.decoder.pipeline) {
diff --git a/src/models/model.h b/src/models/model.h
index 573dd5b6c..55ad960b3 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -181,14 +181,9 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<DmlExecutionContext> dml_execution_context_;
   std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
   ComPtr<IDMLDevice> dml_device_;
-  std::unique_ptr<Ort::Allocator> dml_owned_allocator_;
 #endif
 #if USE_WEBGPU
-  std::unique_ptr<Ort::Allocator> webgpu_owned_allocator_;
   std::unique_ptr<OrtIoBinding> webgpu_io_binding_;
-#endif
-#if USE_DML || USE_WEBGPU
-  std::unique_ptr<OrtMemoryInfo> memory_info_device_;
 #endif
   std::shared_ptr<CapturedGraphPool> captured_graph_pool_;
   std::map<std::string, std::unique_ptr<OrtSessionOptions>> pipeline_session_options_;
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index a6180e4ee..5d4c82684 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -219,10 +219,7 @@ void PositionInputs::CreateNextAttentionMaskTensor(int total_length) {
     attention_mask_shape_[1] = state_.params_->search.max_length;
     attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
     if (is_first_mask_update_) {
-      cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(),
-                      0,
-                      (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * attention_mask_shape_[0] * attention_mask_shape_[1],
-                      model_.cuda_stream_);
+      ByteWrapTensor(*model_.p_device_, *attention_mask_next_).Zero();
     }
 #elif USE_DML
     attention_mask_shape_[1] = state_.params_->search.max_length;
diff --git a/src/models/utils.cpp b/src/models/utils.cpp
index 7f4d43629..dd4bef813 100644
--- a/src/models/utils.cpp
+++ b/src/models/utils.cpp
@@ -1,9 +1,15 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 #include "../generators.h"
+#include "utils.h"
 
 namespace Generators {
 
+DeviceSpan<uint8_t> ByteWrapTensor(DeviceInterface& device, OrtValue& value) {
+  auto info = value.GetTensorTypeAndShapeInfo();
+  return device.WrapMemory(std::span<uint8_t>{value.GetTensorMutableData<uint8_t>(), info->GetElementCount() * SizeOf(info->GetElementType())});
+}
+
 size_t SizeOf(ONNXTensorElementDataType type) {
   switch (type) {
     case Ort::TypeToTensorType<uint8_t>:
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index 2988b9108..723b7aeda 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -272,16 +272,12 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
       }
 
       if (model_.session_info_->HasInput("cache_indirection")) {
-#if USE_CUDA
         cache_indirection_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, std::array<int64_t, 3>{params_->search.batch_size, params_->search.num_beams, params_->search.max_length});
         cache_indirection_index_ = inputs_.size();
         input_names_.push_back("cache_indirection");
         inputs_.push_back(cache_indirection_.get());
 
-        auto data = gpu_span<int32_t>{cache_indirection_->GetTensorMutableData<int32_t>(),
-                                      static_cast<size_t>(params_->BatchBeamSize()) * params_->search.max_length};
-        CudaCheck() == cudaMemsetAsync(data.data(), 0, data.size_bytes(), params_->cuda_stream);
-#endif
+        ByteWrapTensor(*model_.p_device_, *cache_indirection_).Zero();
       }
 
       if (model_.session_info_->HasOutput("output_cross_qk_0")) {
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 87cf6f4c5..8b9a87852 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -331,37 +331,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
   auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
   auto* ortvalue_output = generator.state_->GetOutput(name);
   auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo();
-  std::unique_ptr<OrtValue> ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_,
-                                                                    type_info->GetShape(),
-                                                                    type_info->GetElementType());
+  auto ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_, type_info->GetShape(), type_info->GetElementType());
+
   // Copy data to ortvalue_clone
-  auto element_size = Generators::SizeOf(type_info->GetElementType());
-  auto data_size = type_info->GetElementCount() * element_size;
-  if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
-#if USE_CUDA
-    cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost);
-#endif
-  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
-#if USE_DML
-    ComPtr<ID3D12Resource> gpu_resource;
-    Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
-        generator.model_->allocator_device_,
-        ortvalue_output->GetTensorMutableRawData(),
-        &gpu_resource));
-    auto cpu_tensor = ortvalue_clone->GetTensorMutableRawData();
-    generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu(
-        std::span(reinterpret_cast<uint8_t*>(cpu_tensor), data_size),
-        gpu_resource.Get(),
-        0,
-        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
-#endif
-  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
-    std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
-              static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
-              static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
-  } else {
-    throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType()));
-  }
+  bool is_cpu = ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU;
+  auto output_span = Generators::ByteWrapTensor(is_cpu ? *Generators::GetDeviceInterface(Generators::DeviceType::CPU) : *generator.model_->p_device_, *ortvalue_output);
+  auto copy_span = Generators::ByteWrapTensor(*Generators::GetDeviceInterface(Generators::DeviceType::CPU), *ortvalue_clone);
+  copy_span.CopyFrom(output_span);
 
   auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
   tensor->external_owner_ = tensor;
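
The same helpers give a device-agnostic readback. A rough sketch (illustration only; `value` stands for any OrtValue that may live on the CPU or on the model's device, and `generator` is as in the function above):

  bool on_cpu = value->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU;
  auto& device = on_cpu ? *Generators::GetDeviceInterface(Generators::DeviceType::CPU)
                        : *generator.model_->p_device_;
  auto host = Generators::ByteWrapTensor(device, *value).CopyDeviceToCpu();
  // `host` is a host-visible byte span over a CPU copy of the tensor's contents.
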
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 4ee8f87ed..0d825665d 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -12,10 +12,6 @@
 #include "../logging.h"
 #include "../smartptrs.h"
 
-#if USE_CUDA
-#include "../cuda/cuda_common.h"
-#endif
-
 using namespace pybind11::literals;
 
 // If a parameter to a C++ function is an array of float16, this type will let pybind11::array_t<Ort::Float16_t> map to numpy's float16 format
@@ -143,50 +139,21 @@ pybind11::array ToNumpy(OrtValue* v, const Generators::Model& model) {
   auto shape = type_info->GetShape();
   auto type = type_info->GetElementType();
   auto element_size = Generators::SizeOf(type);
-  auto data = v->GetTensorMutableRawData();
-
-  std::unique_ptr<uint8_t[]> cpu_copy;
-
-#if USE_DML
-  // TODO: DML version of this
-  if (v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && model.device_type_ == Generators::DeviceType::DML) {
-    auto data_size = type_info->GetElementCount() * element_size;
-    cpu_copy = std::make_unique<uint8_t[]>(data_size);
-
-    ComPtr<ID3D12Resource> gpu_resource;
-    Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
-        model.allocator_device_,
-        data,
-        &gpu_resource));
-
-    model.GetDmlReadbackHeap()->ReadbackFromGpu(
-        std::span(reinterpret_cast<uint8_t*>(cpu_copy.get()), data_size),
-        gpu_resource.Get(),
-        0,
-        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
-    data = cpu_copy.get();
-  }
-#endif
-#if USE_CUDA
-  if (v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && model.device_type_ == Generators::DeviceType::CUDA) {
-    auto data_size = type_info->GetElementCount() * element_size;
-    cpu_copy = std::make_unique<uint8_t[]>(data_size);
-    Generators::CudaCheck() == cudaMemcpy(cpu_copy.get(), data, data_size, cudaMemcpyDeviceToHost);
-    data = cpu_copy.get();
-  }
-#endif
 
   std::vector<int64_t> strides(shape.size());
   {
-    auto size = Generators::SizeOf(type);
+    auto size = element_size;
     for (size_t i = strides.size(); i-- > 0;) {
       strides[i] = size;
       size *= shape[i];
     }
   }
 
+  bool is_cpu = v->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU;
+  auto device_span = Generators::ByteWrapTensor(is_cpu ? *Generators::GetDeviceInterface(Generators::DeviceType::CPU) : *model.p_device_, *v);
+
   pybind11::buffer_info bufinfo{
-      data,                                          // Pointer to memory buffer
+      device_span.CopyDeviceToCpu().data(),          // Pointer to memory buffer
       static_cast<pybind11::ssize_t>(element_size),  // Size of underlying scalar type
       ToFormatDescriptor(type),                      // Python struct-style format descriptor
       static_cast<pybind11::ssize_t>(shape.size()),  // Number of dimensions
diff --git a/src/search.cpp b/src/search.cpp
index 274a5b2dd..c4aacf333 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -2,16 +2,18 @@
 #include "softmax.h"
 #include "search.h"
 #include "beam_search_scorer.h"
+#include "cpu/interface.h"
 #include <queue>
 #include <algorithm>
 
 namespace Generators {
 
 Search_Cpu::Search_Cpu(const GeneratorParams& params)
-    : Search{params} {
+    : Search{params},
+      cpu_device_{*GetCpuInterface()} {
   auto batch_beam_size = params.BatchBeamSize();
 
-  sequence_lengths_ = params.p_device->Allocate<int32_t>(batch_beam_size);
+  sequence_lengths_ = cpu_device_.Allocate<int32_t>(batch_beam_size);
 }
 
 GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params)
@@ -26,9 +28,9 @@ GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params)
     gen_.seed(seq);
   }
 
-  next_tokens_ptr_ = params.p_device->Allocate<int32_t>(params.search.batch_size);
+  next_tokens_ptr_ = cpu_device_.Allocate<int32_t>(params.search.batch_size);
+  next_tokens_ptr_.Zero();
   next_tokens_ = cpu_span<int32_t>(next_tokens_ptr_.Span());
-  memset(next_tokens_.data(), 0, next_tokens_.size_bytes());
 
   eos_seen_buffer_ = AllocateArray<bool>(params.search.batch_size, &eos_seen_);
   memset(eos_seen_.data(), 0, eos_seen_.size_bytes());
@@ -368,7 +370,6 @@ DeviceSpan<int32_t> BeamSearch_Cpu::GetSequence(size_t index) {
   return beam_scorer_->GetBeamHypotheses(batch_id, beam_id);
 }
 
-// TODO(aciddelgado): my question is, should this return copy or reference? A: A copy, as with DeviceSpan it's like a span
 DeviceSpan<int32_t> BeamSearch_Cpu::GetSequence(size_t batch_id, size_t beam_id) {
   Finalize(params_->search.num_return_sequences);
   return beam_scorer_->GetBeamHypotheses(batch_id, beam_id);
diff --git a/src/search.h b/src/search.h
index 35fc521d8..9d456368c 100644
--- a/src/search.h
+++ b/src/search.h
@@ -52,6 +52,8 @@ struct Search_Cpu : Search {
 
   std::span<float> GetScores(int batch_beam_index);
 
+  DeviceInterface& cpu_device_;
+
   DeviceSpan<int32_t> sequence_lengths_;  // shape (beam_size*batch_size)
 
   cpu_span<int32_t> next_tokens_;  // shape (beam_size*batch_size)
@@ -82,7 +84,6 @@ struct GreedySearch_Cpu : Search_Cpu {
 
   bool PadIfAlreadyEOS(size_t batch_id);
 
-  std::unique_ptr<int32_t[]> next_tokens_buffer_;
   DeviceSpan<int32_t> next_tokens_ptr_;
   std::unique_ptr<int32_t[]> temp_topk_buffer_;
 
diff --git a/src/smartptrs.h b/src/smartptrs.h
index ac7037957..fbaf12ed2 100644
--- a/src/smartptrs.h
+++ b/src/smartptrs.h
@@ -5,6 +5,10 @@
 #include <memory>
 #include "span.h"
 
+namespace Ort {
+struct Allocator;
+}
+
 namespace Generators {
 struct Search;
 struct Sequences;
@@ -21,6 +25,7 @@ struct DeviceBuffer : std::enable_shared_from_this<DeviceBuffer> {
   virtual void CopyDeviceToCpu() = 0;  // Allocates p_cpu_ if necessary and copies p_device_ memory into it
   virtual void CopyCpuToDevice() = 0;
   virtual void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) = 0;
+  virtual void Zero() = 0;  // Zero out the device memory
 
   uint8_t* p_device_{};
   uint8_t* p_cpu_{};
@@ -38,6 +43,8 @@ struct DeviceSpan {
   bool empty() const { return length_ == 0; }
   size_t size() const { return length_; }
 
+  operator DeviceSpan<const T>() const { return DeviceSpan<const T>(*p_device_memory_, begin_, length_); }
+
   DeviceSpan<T> subspan(size_t begin, size_t length) { return DeviceSpan<T>(*p_device_memory_, begin_ + begin, length); }
 
   // Return the device accessible memory. Should only be done in device specific code, as it's not CPU accessible
@@ -58,24 +65,35 @@ struct DeviceSpan {
   // Copy CPU memory to device memory, typically used after calling CpuSpan or CopyDeviceToCpu to update the device memory with the modifications made
   void CopyCpuToDevice() { p_device_memory_->CopyCpuToDevice(); }
 
+  // Zero out the device memory
+  void Zero() { p_device_memory_->Zero(); }
+
+  void CopyFrom(const DeviceSpan<const T>& source) {
+    assert(source.size() == size());  // Spans must be the same size to copy
+    p_device_memory_->CopyFrom(begin_ * sizeof(T), *source.p_device_memory_, source.begin_ * sizeof(T), length_ * sizeof(T));
+  }
+
  private:
   DeviceSpan(DeviceBuffer& memory, size_t begin, size_t length)
       : p_device_memory_{memory.shared_from_this()}, begin_{begin}, length_{length} {}
 
   std::shared_ptr<DeviceBuffer> p_device_memory_;
   size_t begin_{}, length_{};  // Subspan of p_device_memory_, relative to original memory block
+  template <typename U>
+  friend struct DeviceSpan;  // All DeviceSpans are friends
 };
 
 struct DeviceInterface {
   virtual ~DeviceInterface() {}
+  virtual void InitAllocator(Ort::Allocator& allocator) = 0;
 
   template <typename T>
-  DeviceSpan<T> Allocate(size_t count, bool cpu_accessible = false) { return DeviceSpan<T>(AllocateBase(sizeof(T) * count, cpu_accessible)); }
-  virtual std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) = 0;
+  DeviceSpan<T> Allocate(size_t count) { return DeviceSpan<T>(AllocateBase(sizeof(T) * count)); }
+  virtual std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) = 0;
 
   // Wraps an existing memory block, useful for tensors. Use WrapTensor for OrtValue vs calling this directly
   template <typename T>
-  DeviceSpan<T> WrapMemory(std::span<T> memory) { return DeviceSpan<T>(WrapMemoryBase(memory.data(), memory.size_bytes())); }
+  DeviceSpan<T> WrapMemory(std::span<T> memory) { return DeviceSpan<T>(WrapMemoryBase(const_cast<std::remove_const_t<T>*>(memory.data()), memory.size_bytes())); }
   virtual std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* memory, size_t size) = 0;
 
   virtual std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) = 0;
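
Taken together with the simplified Allocate, the new DeviceSpan operations make the typical buffer lifecycle look roughly like the sketch below (illustration only; `device` is any DeviceInterface and `batch_beam_size` is an assumed element count):

  auto tokens = device.Allocate<int32_t>(batch_beam_size);  // device-resident buffer
  tokens.Zero();                                            // zero it on the device
  auto staging = tokens.CpuSpan();                          // host-visible staging view
  // ... fill `staging` on the CPU ...
  tokens.CopyCpuToDevice();                                 // push the edits back to the device
  auto window = tokens.subspan(0, batch_beam_size / 2);     // sub-view of the same memory block
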
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 56cd1644f..18caa0a10 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -32,6 +32,9 @@ TEST(CAPITests, Config) {
 #endif
 }
 
+#undef USE_CUDA
+#define USE_CUDA 0
+
 TEST(CAPITests, TokenizerCAPI) {
 #if TEST_PHI2
   auto config = OgaConfig::Create(PHI2_PATH);
diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp
index eb7f04cd9..c15478009 100644
--- a/test/sampling_benchmark.cpp
+++ b/test/sampling_benchmark.cpp
@@ -14,6 +14,9 @@
 #define MODEL_PATH "../../test/test_models/"
 #endif
 
+#undef USE_CUDA
+#define USE_CUDA 0
+
 // Defined in sampling_tests.cpp
 void CreateRandomLogits(float* logits, int num_large, int vocab_size, int batch_size, std::mt19937& engine);
 
diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index a71910b15..0482f704d 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -13,6 +13,9 @@
 #define MODEL_PATH "../../test/test_models/"
 #endif
 
+#undef USE_CUDA
+#define USE_CUDA 0
+
 template<typename T>
 auto AllocateFromCpuMem(Generators::DeviceInterface& device, std::span<const T> cpu_memory) {
   auto memory = device.Allocate<float>(cpu_memory.size());

From 41b462a5babaa0b0ef7b9e9107435cf7e7197b52 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 15 Jan 2025 15:34:15 -0800
Subject: [PATCH 03/31] Finish refactoring model processing. Remove as many #if
 USE_CUDA/USE_DML blocks as possible.
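
The general shape of the change, shown as an illustration rather than literal code from the diff: per-provider blocks such as

  #if USE_CUDA
    cuda::LaunchFp16ToFp32(fp16_data, fp32_data, count, stream);
  #elif USE_DML
    // read back to CPU, convert, upload
  #endif

become a single virtual call on whichever device interface is active:

  model.p_device_->Cast(fp16_value, fp32_value);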

---
 src/cpu/interface.cpp                 |  48 +++++++-
 src/cuda/beam_search_scorer_cuda.cpp  |   2 +-
 src/cuda/beam_search_scorer_cuda.cuh  |   1 +
 src/cuda/cuda_sampling.cu             |  13 +-
 src/cuda/cuda_sampling.cuh            |   2 +-
 src/cuda/interface.cpp                | 118 +++++++++---------
 src/cuda/interface.h                  |  24 ----
 src/{models => cuda}/kernels.h        |   2 -
 src/cuda/model_kernels.cu             |  30 -----
 src/cuda/search_cuda.cpp              |  34 +++---
 src/cuda/search_cuda.h                |   3 +-
 src/dml/interface.cpp                 |  60 ++++++++-
 src/generators.cpp                    |  57 ++-------
 src/generators.h                      |   9 +-
 src/models/audio_processor.cpp        |   2 +-
 src/models/captured_graph_pool.h      |   5 -
 src/models/debugging.cpp              |  34 +-----
 src/models/extra_inputs.cpp           |   1 -
 src/models/input_ids.cpp              |  76 +++---------
 src/models/input_ids.h                |   3 +-
 src/models/kv_cache.cpp               | 119 ++++++------------
 src/models/logits.cpp                 |  17 +--
 src/models/logits.h                   |   2 -
 src/models/model.cpp                  | 132 ++++----------------
 src/models/model.h                    |  31 +----
 src/models/position_inputs.cpp        | 169 +++-----------------------
 src/models/position_inputs.h          |  24 ----
 src/models/prompt_image_processor.cpp |   2 +-
 src/models/whisper.cpp                |  20 ++-
 src/smartptrs.h                       |  16 ++-
 test/model_tests.cpp                  |   8 +-
 31 files changed, 333 insertions(+), 731 deletions(-)
 rename src/{models => cuda}/kernels.h (92%)

diff --git a/src/cpu/interface.cpp b/src/cpu/interface.cpp
index b27882668..c6705fa3b 100644
--- a/src/cpu/interface.cpp
+++ b/src/cpu/interface.cpp
@@ -3,6 +3,7 @@
 
 #include "../generators.h"
 #include "../search.h"
+#include "../models/utils.h"
 #include "interface.h"
 
 namespace Generators {
@@ -13,7 +14,7 @@ const char* label_cpu = "cpu";
 struct CpuMemory final : DeviceBuffer {
   CpuMemory(size_t size) : owned_{true} {
     size_in_bytes_ = size;
-    p_cpu_ = p_device_ = new uint8_t[size_in_bytes_];
+    p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
   }
 
   CpuMemory(void* p, size_t size) : owned_{false} {
@@ -23,7 +24,7 @@ struct CpuMemory final : DeviceBuffer {
 
   ~CpuMemory() override {
     if (owned_)
-      delete[] p_device_;
+      ort_allocator_->Free(p_device_);
   }
 
   const char* GetType() const override { return label_cpu; }
@@ -42,11 +43,19 @@ struct CpuMemory final : DeviceBuffer {
 };
 
 struct CpuInterface : DeviceInterface {
-  void InitAllocator(Ort::Allocator& allocator) override {
+  CpuInterface() {
+    InitOrt(*Ort::api, Ort::Allocator::GetWithDefaultOptions());
+  }
+
+  void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
   }
 
+  Ort::Allocator& GetAllocator() override {
+    return *ort_allocator_;
+  }
+
   std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
     return std::make_shared<CpuMemory>(size);
   }
@@ -55,6 +64,39 @@ struct CpuInterface : DeviceInterface {
     return std::make_shared<CpuMemory>(p, size);
   }
 
+  bool Cast(OrtValue& input, OrtValue& output) override {
+    auto input_info = input.GetTensorTypeAndShapeInfo();
+    auto output_info = output.GetTensorTypeAndShapeInfo();
+
+    auto input_type = input_info->GetElementType();
+    auto output_type = output_info->GetElementType();
+
+    auto element_count = input_info->GetElementCount();
+    if (element_count != output_info->GetElementCount())
+      throw std::runtime_error("Cast - input and output element counts do not match");
+    if (input_type == output_type)
+      throw std::runtime_error("Cast - input and output types are the same");
+
+    if (input_type == Ort::TypeToTensorType<float> && output_type == Ort::TypeToTensorType<Ort::Float16_t>) {
+      auto* fp32 = input.GetTensorData<float>();
+      auto* fp16 = output.GetTensorMutableData<uint16_t>();
+      for (size_t i = 0; i < element_count; i++)
+        fp16[i] = FastFloat32ToFloat16(fp32[i]);
+    } else if (input_type == Ort::TypeToTensorType<Ort::Float16_t> && output_type == Ort::TypeToTensorType<float>) {
+      auto* fp16 = input.GetTensorData<uint16_t>();
+      auto* fp32 = output.GetTensorMutableData<float>();
+      for (size_t i = 0; i < element_count; i++)
+        fp32[i] = FastFloat16ToFloat32(fp16[i]);
+    } else if (input_type == Ort::TypeToTensorType<int32_t> && output_type == Ort::TypeToTensorType<int64_t>) {
+      auto* input_data = input.GetTensorData<int32_t>();
+      auto* output_data = output.GetTensorMutableData<int64_t>();
+      for (size_t i = 0; i < element_count; i++)
+        output_data[i] = input_data[i];
+    } else
+      throw std::runtime_error("Cast - Unimplemented cast");
+    return true;
+  }
+
   std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
   std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
 
diff --git a/src/cuda/beam_search_scorer_cuda.cpp b/src/cuda/beam_search_scorer_cuda.cpp
index a321af4d7..e2295e00d 100644
--- a/src/cuda/beam_search_scorer_cuda.cpp
+++ b/src/cuda/beam_search_scorer_cuda.cpp
@@ -8,7 +8,7 @@
 namespace Generators {
 
 BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
-    : stream_{parameters.cuda_stream} {
+    : stream_{GetStream()} {
   state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
   state_cpu_->batch_size_ = static_cast<size_t>(parameters.search.batch_size);
   state_cpu_->num_beams_ = static_cast<size_t>(parameters.search.num_beams);
diff --git a/src/cuda/beam_search_scorer_cuda.cuh b/src/cuda/beam_search_scorer_cuda.cuh
index 68be19fee..9e441d4bd 100644
--- a/src/cuda/beam_search_scorer_cuda.cuh
+++ b/src/cuda/beam_search_scorer_cuda.cuh
@@ -1,3 +1,4 @@
+#include "models/onnxruntime_api.h"
 #include "smartptrs.h"
 
 namespace Generators {
diff --git a/src/cuda/cuda_sampling.cu b/src/cuda/cuda_sampling.cu
index 8b5720479..f6be0b623 100644
--- a/src/cuda/cuda_sampling.cu
+++ b/src/cuda/cuda_sampling.cu
@@ -8,6 +8,7 @@
 #include "span.h"
 #include "beam_search_topk.h"
 #include "cuda_sampling.cuh"
+#include "models/onnxruntime_api.h"
 #include "smartptrs.h"
 #include <cuda_runtime.h>
 #include <cub/cub.cuh>
@@ -297,22 +298,22 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl
 }
 
 template <bool is_log_softmax>
-void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements,
+void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements,
                                      int input_stride, int output_stride, int batch_count, float temperature) {
   dim3 grid(batch_count);
   constexpr int ILP = sizeof(float4) / sizeof(float);
   dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements);
   if (is_log_softmax) {
     SoftmaxBlockForward<ILP, float, float, float, LogSoftmaxForwardEpilogue>
-        <<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
+        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
                                                             softmax_elements, input_stride, output_stride, temperature);
   } else {
     SoftmaxBlockForward<ILP, float, float, float, SoftmaxForwardEpilogue>
-        <<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
+        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
                                                             softmax_elements, input_stride, output_stride, temperature);
   }
 }
-template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t*, float*, const float*, int, int, int, int, float);
+template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t, float*, const float*, int, int, int, int, float);
 
 // Populate Kernels and Launchers
 
@@ -521,7 +522,7 @@ void LaunchSampleKernel(SamplingData* data, cudaStream_t stream, float* scores,
 void SoftmaxAndSort(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, float temperature) {
   // Softmax scores
   std::span<float> scores{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
-  DispatchBlockwiseSoftmaxForward<false>(&stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
+  DispatchBlockwiseSoftmaxForward<false>(stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
   // Sort indices by scores
   std::span<int> offsets_gpu{data->offsets.get(), static_cast<size_t>(batch_size + 1)};
   LaunchPopulateOffsets(offsets_gpu.data(), vocab_size, batch_size, stream);
@@ -550,7 +551,7 @@ void LaunchGetTopKSubsetFullSort(SamplingData* data, cudaStream_t stream, float*
 void GetTopKSubset(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, int k, float temperature) {
   // Softmax scores
   std::span<float> scores_softmaxed{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
-  DispatchBlockwiseSoftmaxForward<false>(&stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
+  DispatchBlockwiseSoftmaxForward<false>(stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
 // Get top k subset
 #define GetTopK(max_k)                                \
   LaunchGetTopKSubset<max_k>(stream,                  \
diff --git a/src/cuda/cuda_sampling.cuh b/src/cuda/cuda_sampling.cuh
index e6e0f184f..390ff92bc 100644
--- a/src/cuda/cuda_sampling.cuh
+++ b/src/cuda/cuda_sampling.cuh
@@ -25,7 +25,7 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t
 void GetSample(SamplingData* data, cudaStream_t stream, int32_t* d_next_token, float* d_scores, int vocab_size, int batch_size, int k, float p, float temperature);
 
 template <bool is_log_softmax>
-void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0);
+void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0);
 
 }  // namespace cuda
 }  // namespace Generators
\ No newline at end of file
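
With the launcher now taking the stream by value, a call site looks roughly like this (sketch only; buffer pointers and sizes are placeholders):

  cuda::DispatchBlockwiseSoftmaxForward<false>(GetStream(), scores_out, scores_in,
                                               vocab_size, vocab_size, vocab_size,
                                               batch_size, /*temperature*/ 1.0f);
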
diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp
index 0b708c94d..4c85379b0 100644
--- a/src/cuda/interface.cpp
+++ b/src/cuda/interface.cpp
@@ -6,7 +6,7 @@
 #include "interface.h"
 #include "../search.h"
 #include "search_cuda.h"
-#include "../models/kernels.h"
+#include "kernels.h"
 #include <cstdarg>
 
 namespace Generators {
@@ -15,10 +15,13 @@ GenaiInterface* gp_genai{};
 Ort::Allocator* ort_allocator_{};
 const char* label_cuda = "cuda";
 
+cuda_stream_holder g_stream;
+cudaStream_t GetStream() { return g_stream.get(); }
+
 struct GpuMemory final : DeviceBuffer {
   GpuMemory(size_t size) : owned_{true} {
     size_in_bytes_ = size;
-    ::cudaMalloc(&p_device_, size);
+    p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size));
   }
 
   GpuMemory(void* p, size_t size) : owned_{false} {
@@ -28,7 +31,7 @@ struct GpuMemory final : DeviceBuffer {
 
   ~GpuMemory() override {
     if (owned_)
-      ::cudaFree(p_device_);
+      ort_allocator_->Free(p_device_);
     if (p_cpu_)
       ::cudaFreeHost(p_cpu_);
   }
@@ -65,19 +68,23 @@ struct GpuMemory final : DeviceBuffer {
   bool owned_;  // If we own the memory, we delete it on destruction
 };
 
-struct CudaInterfaceImpl : CudaInterface {
+struct CudaInterfaceImpl final : DeviceInterface {
   CudaInterfaceImpl() {
-    cuda_stream_.Create();
   }
 
   ~CudaInterfaceImpl() {
   }
 
-  void InitAllocator(Ort::Allocator& allocator) override {
+  void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
+    Ort::api = &api;
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
   }
 
+  Ort::Allocator& GetAllocator() override {
+    return *ort_allocator_;
+  }
+
   std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
     return std::make_shared<GpuMemory>(size);
   }
@@ -95,85 +102,78 @@ struct CudaInterfaceImpl : CudaInterface {
   }
 
   void Synchronize() override {
-    ::cudaStreamSynchronize(cuda_stream_.get());
-  }
-
-  cudaStream_t GetCudaStream() override {
-    return cuda_stream_.get();
-  }
-
-  void Int32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) override {
-    cuda::LaunchInt32ToInt64(input, output, count, stream);
-  }
-
-  void Fp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) override {
-    cuda::LaunchFp16ToFp32(input, output, count, stream);
-  }
-
-  void Fp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) override {
-    cuda::LaunchFp32ToFp16(input, output, count, stream);
+    ::cudaStreamSynchronize(GetStream());
   }
 
-  void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) override {
-    cuda::LaunchExpandAndInt32ToInt64(src, dst, num_beams, batch_size, sequence_length, stream);
+  void* GetCudaStream() override {
+    return GetStream();
   }
 
-  void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) override {
-    cuda::LaunchExpand(src, dst, num_beams, batch_size, sequence_length, stream);
-  }
+  bool Cast(OrtValue& input, OrtValue& output) override {
+    auto input_info = input.GetTensorTypeAndShapeInfo();
+    auto output_info = output.GetTensorTypeAndShapeInfo();
 
-  void Launch_UpdatePositionIds(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) override {
-    cuda::Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream);
-  }
+    auto input_type = input_info->GetElementType();
+    auto output_type = output_info->GetElementType();
 
-  void Launch_UpdatePositionIds(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) override {
-    cuda::Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream);
-  }
+    auto input_data = input.GetTensorRawData();
+    auto output_data = output.GetTensorMutableRawData();
 
-  void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) override {
-    cuda::Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream);
-  }
+    auto element_count = input_info->GetElementCount();
+    if (element_count != output_info->GetElementCount())
+      throw std::runtime_error("Cast - input and output element counts do not match");
+    if (input_type == output_type)
+      throw std::runtime_error("Cast - input and output types are the same");
 
-  void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) override {
-    cuda::Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream);
+    if (input_type == Ort::TypeToTensorType<float> && output_type == Ort::TypeToTensorType<Ort::Float16_t>) {
+      cuda::LaunchFp32ToFp16(reinterpret_cast<const float*>(input_data), reinterpret_cast<uint16_t*>(output_data), static_cast<int>(element_count), GetStream());
+    } else if (input_type == Ort::TypeToTensorType<Ort::Float16_t> && output_type == Ort::TypeToTensorType<float>) {
+      cuda::LaunchFp16ToFp32(reinterpret_cast<const uint16_t*>(input_data), reinterpret_cast<float*>(output_data), static_cast<int>(element_count), GetStream());
+    } else if (input_type == Ort::TypeToTensorType<int32_t> && output_type == Ort::TypeToTensorType<int64_t>) {
+      cuda::LaunchInt32ToInt64(reinterpret_cast<const int32_t*>(input_data), reinterpret_cast<int64_t*>(output_data), static_cast<int>(element_count), GetStream());
+    } else
+      return false;
+    return true;
   }
 
-  void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) override {
-    cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream);
+  void UpdatePositionIds(void* position_ids, int batch_beam_size, int total_length, int new_kv_length, ONNXTensorElementDataType type) override {
+    if (type == Ort::TypeToTensorType<int32_t>)
+      cuda::Launch_UpdatePositionIds(static_cast<int32_t*>(position_ids), batch_beam_size, total_length, new_kv_length, GetStream());
+    else
+      cuda::Launch_UpdatePositionIds(static_cast<int64_t*>(position_ids), batch_beam_size, total_length, new_kv_length, GetStream());
   }
 
-  void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) override {
-    cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream);
+  void UpdateAttentionMask(void* mask_data, const void* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, ONNXTensorElementDataType type) override {
+    if (type == Ort::TypeToTensorType<int32_t>)
+      cuda::Launch_UpdateAttentionMask(static_cast<int32_t*>(mask_data), static_cast<const int32_t*>(old_data), batch_beam_size, new_kv_length, total_length, max_length, update_only, GetStream());
+    else
+      cuda::Launch_UpdateAttentionMask(static_cast<int64_t*>(mask_data), static_cast<const int64_t*>(old_data), batch_beam_size, new_kv_length, total_length, max_length, update_only, GetStream());
   }
 
-  void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) override {
-    cuda::ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, stream);
+  void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) override {
+    cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, GetStream());
   }
 
-  void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) override {
-    cuda::LaunchCopyCrossQKSingleDecodeStep(stream, cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length);
+  void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length) override {
+    cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, GetStream());
   }
 
-  void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override {
-    cuda::LaunchFinalizeCrossQK(stream, iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data);
+  void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size) override {
+    cuda::ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, GetStream());
   }
 
-  cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) override {
-    return ::cudaMemcpyAsync(dst, src, count, kind, stream);
+  void LaunchCopyCrossQKSingleDecodeStep(float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) override {
+    cuda::LaunchCopyCrossQKSingleDecodeStep(GetStream(), cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length);
   }
 
-  cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) override {
-    return ::cudaMemsetAsync(ptr, value, count, stream);
+  void LaunchFinalizeCrossQK(int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override {
+    cuda::LaunchFinalizeCrossQK(GetStream(), iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data);
   }
-
- private:
-  cuda_stream_holder cuda_stream_;
 };
 
-std::unique_ptr<CudaInterface> g_cuda_device;
+std::unique_ptr<DeviceInterface> g_cuda_device;
 
 DeviceInterface& GetCudaDeviceInterface() { return *g_cuda_device; }
-cudaStream_t GetStream() { return g_cuda_device->GetCudaStream(); }
 
 LogItems& GetLogItems() { return gp_genai->GetLogItems(); }
 std::ostream& operator<<(std::ostream& stream, SGR sgr_code) { return gp_genai->operator_leftshift(stream, sgr_code); }
@@ -212,7 +212,7 @@ void operator delete(void* p, size_t /*size*/) noexcept { Generators::gp_genai->
 #endif
 
 extern "C" {
-Generators::CudaInterface* GetInterface(GenaiInterface* p_genai) {
+Generators::DeviceInterface* GetInterface(GenaiInterface* p_genai) {
   Generators::gp_genai = p_genai;
   Generators::g_cuda_device = std::make_unique<Generators::CudaInterfaceImpl>();
   return Generators::g_cuda_device.get();
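
The kernel wrappers above now take an ONNX element type instead of being dispatched per dtype at the call site. A sketch of how model code is expected to invoke them (illustration only; `p_device`, `position_ids` and the counts are placeholders):

  auto type = position_ids->GetTensorTypeAndShapeInfo()->GetElementType();  // int32 or int64
  p_device->UpdatePositionIds(position_ids->GetTensorMutableRawData(),
                              batch_beam_size, total_length, new_kv_length, type);
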
diff --git a/src/cuda/interface.h b/src/cuda/interface.h
index 24e6fb94f..2236e7d31 100644
--- a/src/cuda/interface.h
+++ b/src/cuda/interface.h
@@ -22,29 +22,5 @@ struct GenaiInterface {
 
 namespace Generators {
 LogItems& GetLogItems();
-
-#if USE_CUDA
 DeviceInterface& GetCudaDeviceInterface();
-
-struct CudaInterface : DeviceInterface {
-  virtual void Int32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) = 0;
-  virtual void Fp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) = 0;
-  virtual void Fp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) = 0;
-  // TODO: This can be collapsed into a single function with a template parameter
-  virtual void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) = 0;
-  virtual void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) = 0;
-  virtual void Launch_UpdatePositionIds(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) = 0;
-  virtual void Launch_UpdatePositionIds(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) = 0;
-  virtual void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0;
-  virtual void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0;
-  virtual void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) = 0;
-  virtual void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) = 0;
-  virtual void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) = 0;
-  virtual void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) = 0;
-  virtual void LaunchFinalizeCrossQK(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) = 0;
-
-  virtual cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) = 0;
-  virtual cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) = 0;
-};
-#endif
 }  // namespace Generators
diff --git a/src/models/kernels.h b/src/cuda/kernels.h
similarity index 92%
rename from src/models/kernels.h
rename to src/cuda/kernels.h
index fece7dadf..99be3a416 100644
--- a/src/models/kernels.h
+++ b/src/cuda/kernels.h
@@ -15,8 +15,6 @@ void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_si
 void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream);
 void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream);
 void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream);
-void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream);
-void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream);
 
 template <typename T>
 void BufferExpansionKernelLauncher(const T* input, T* output, int batch_size, int beam_width, int chunk_size, cudaStream_t stream);
diff --git a/src/cuda/model_kernels.cu b/src/cuda/model_kernels.cu
index 0eb316383..59b1f5431 100644
--- a/src/cuda/model_kernels.cu
+++ b/src/cuda/model_kernels.cu
@@ -137,36 +137,6 @@ void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_
   ConvertInt32ToInt64<<<num_blocks, block_size, 0, stream>>>(src, dst, count);
 }
 
-__global__ void ExpandAndConvertInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < num_beams * batch_size * sequence_length) {
-    int batch_id = idx / (num_beams * sequence_length);
-    int seq_id = idx % sequence_length;
-    dst[idx] = (int64_t)src[batch_id * sequence_length + seq_id];
-  }
-}
-
-void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) {
-  int block_size = 256;
-  int num_blocks = (num_beams * batch_size * sequence_length + block_size - 1) / block_size;
-  ExpandAndConvertInt32ToInt64<<<num_blocks, block_size, 0, stream>>>(src, dst, num_beams, batch_size, sequence_length);
-}
-
-__global__ void Expand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < num_beams * batch_size * sequence_length) {
-    int batch_id = idx / (num_beams * sequence_length);
-    int seq_id = idx % sequence_length;
-    dst[idx] = src[batch_id * sequence_length + seq_id];
-  }
-}
-
-void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) {
-  int block_size = 256;
-  int num_blocks = (num_beams * batch_size * sequence_length + block_size - 1) / block_size;
-  Expand<<<num_blocks, block_size, 0, stream>>>(src, dst, num_beams, batch_size, sequence_length);
-}
-
 namespace {
 
 struct ReorderPastStateParams {
diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp
index 8d7bf6fe0..a53ff37dd 100644
--- a/src/cuda/search_cuda.cpp
+++ b/src/cuda/search_cuda.cpp
@@ -22,7 +22,7 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params)
   sequence_lengths_ = params.p_device->Allocate<int32_t>(batch_beam_size);
 
   eos_meet_buffer_ = CudaMallocArray<bool>(batch_beam_size, &eos_meet_);
-  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream);
+  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream());
 
   done_cpu_ = CudaMallocHostArray<bool>(1);
   *done_cpu_ = false;
@@ -39,7 +39,7 @@ GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params)
     random_seed = params_->search.random_seed;
   else
     random_seed = std::random_device{}();
-  samplingdata_ = std::make_unique<cuda::SamplingData>(random_seed, params_->search.batch_size, params_->config.model.vocab_size, params_->cuda_stream);
+  samplingdata_ = std::make_unique<cuda::SamplingData>(random_seed, params_->search.batch_size, params_->config.model.vocab_size, GetStream());
 }
 
 BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params)
@@ -58,7 +58,7 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params)
   topk_buffer_ = CudaMallocArray<float>(topk_buffer_size);
   static_assert(sizeof(float) == sizeof(int32_t));  // The topk_buffer assumes these match, fix for float16
 
-  cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), params_->cuda_stream);
+  cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), GetStream());
 }
 
 BeamSearch_Cuda::~BeamSearch_Cuda() = default;
@@ -84,20 +84,20 @@ DeviceSpan<int32_t> BeamSearch_Cuda::GetNextIndices() {
 }
 
 void BeamSearch_Cuda::SelectTop() {
-  cuda::DispatchBlockwiseSoftmaxForward<true>(const_cast<cudaStream_t*>(&params_->cuda_stream), softmax_buffer_.get(), next_token_scores_.Span().data(), params_->config.model.vocab_size,
+  cuda::DispatchBlockwiseSoftmaxForward<true>(GetStream(), softmax_buffer_.get(), next_token_scores_.Span().data(), params_->config.model.vocab_size,
                                               params_->config.model.vocab_size, params_->config.model.vocab_size, params_->BatchBeamSize());
 
   // Copy next_token_scores to CPU
   auto next_token_scores_cpu = CudaMallocHostArray<float>(params_->BatchBeamSize() * params_->config.model.vocab_size);
-  cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, params_->cuda_stream);
-  CudaCheck() == cudaStreamSynchronize(params_->cuda_stream);
+  cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, GetStream());
+  CudaCheck() == cudaStreamSynchronize(GetStream());
 
   auto beam_scores = beam_scorer_->GetNextScores();
 
   // Add beam score to next token scores. Corresponding python code is like:
   //    next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
   cuda::LaunchAddProbsKernel(softmax_buffer_.get(), beam_scores.Span().data(),
-                             params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->cuda_stream);
+                             params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, GetStream());
 
   if (params_->search.num_beams <= 32) {
     constexpr size_t max_parts_of_vocab = 128;
@@ -120,11 +120,11 @@ void BeamSearch_Cuda::SelectTop() {
                          topk_next_scores_.get(),
                          topk_next_tokens_.get(),
                          topk_next_indices_.get(),
-                         params_->cuda_stream);
+                         GetStream());
   } else
     assert(false);
 
-  CudaCheck() == cudaStreamSynchronize(params_->cuda_stream);
+  CudaCheck() == cudaStreamSynchronize(GetStream());
 
   size_t size = params_->BatchBeamSize() * 2;
   std::span<float> next_scores{topk_next_scores_.get(), size};
@@ -146,13 +146,13 @@ void BeamSearch_Cuda::SelectTop() {
 void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) {
   std::span<float> scores = next_token_scores_.Span();
   assert(scores.size() == params_->search.batch_size * params_->config.model.vocab_size);
-  cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size),
+  cuda::GetSample(samplingdata_.get(), GetStream(), next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size),
                   params_->search.batch_size, k, p, temperature);
 
   // Check for EOS
   assert(next_tokens_.size() == eos_meet_.size());
   // Don't replace EOS with pad for batch_size == 1 for continuous decoding mode
-  cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast<int>(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->search.batch_size > 1 ? params_->config.model.pad_token_id : params_->config.model.eos_token_id, done_cpu_.get(), params_->cuda_stream);
+  cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast<int>(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->search.batch_size > 1 ? params_->config.model.pad_token_id : params_->config.model.eos_token_id, done_cpu_.get(), GetStream());
 
   // Append tokens
   cuda::Launch_AppendNextTokensToSequences(next_tokens_buffer_.Span(), sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream());
@@ -207,7 +207,7 @@ std::span<float> Search_Cuda::GetScores() {
 
 // Set user input tokens (batch_beam_size, sequence_length)
 void GreedySearch_Cuda::AppendTokens(DeviceSpan<int32_t>& next_tokens) {
-  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream);
+  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream());
   *done_cpu_ = false;
 
   auto next_tokens_gpu = next_tokens.Span();
@@ -221,7 +221,7 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan<int32_t>& next_tokens) {
     return;
   }
 
-  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream);
+  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream());
   *done_cpu_ = false;
 }
 
@@ -234,12 +234,12 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan<int32_t>& next_tokens) {
 }
 
 void GreedySearch_Cuda::RewindTo(size_t index) {
-  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream);
+  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), GetStream());
   *done_cpu_ = false;
   if (index > 0)
     cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast<int>(params_->BatchBeamSize()), static_cast<int>(index), sequences_.max_length_, GetStream());
   else
-    cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), params_->cuda_stream);
+    cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), GetStream());
   sequences_.RewindTo(index);
 }
 
@@ -247,7 +247,7 @@ void Search_Cuda::ApplyMinLength(int min_length) {
   if (sequences_.GetSequenceLength() >= min_length)
     return;
 
-  cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->config.model.vocab_size, params_->config.model.eos_token_id, std::numeric_limits<float>::lowest(), params_->cuda_stream);
+  cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->config.model.vocab_size, params_->config.model.eos_token_id, std::numeric_limits<float>::lowest(), GetStream());
 }
 
 void Search_Cuda::ApplyRepetitionPenalty(float penalty) {
@@ -256,7 +256,7 @@ void Search_Cuda::ApplyRepetitionPenalty(float penalty) {
 
   cuda::LaunchRepetitionPenaltyProcessor(sequences_.GetSequences().Span().data(),
                                          GetScores().data(), params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size,
-                                         params_->search.max_length, GetSequenceLength(), penalty, params_->cuda_stream);
+                                         params_->search.max_length, GetSequenceLength(), penalty, GetStream());
 }
 
 }  // namespace Generators
\ No newline at end of file
diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h
index 2e0ec4610..bb31b4423 100644
--- a/src/cuda/search_cuda.h
+++ b/src/cuda/search_cuda.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <cuda_runtime.h>
 #include "search_cuda.cuh"
 #include "cuda_sampling.cuh"
 
@@ -12,7 +13,7 @@ struct Search_Cuda : Search {
   DeviceSpan<int32_t> GetSequenceLengths() override { return sequence_lengths_; }
 
   bool IsDone() const {
-    cudaStreamSynchronize(params_->cuda_stream);
+    cudaStreamSynchronize(GetStream());
     return *done_cpu_;
   }  // TODO: Use an event
 
diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp
index 61a57d663..459a71054 100644
--- a/src/dml/interface.cpp
+++ b/src/dml/interface.cpp
@@ -116,7 +116,8 @@ struct DmlInterfaceImpl : DeviceInterface {
     Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
   }
 
-  void InitAllocator(Ort::Allocator& allocator) override {
+  void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
+    Ort::api = &api;
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
 
@@ -131,6 +132,10 @@ struct DmlInterfaceImpl : DeviceInterface {
     dml_readback_heap_ = std::make_unique<DmlReadbackHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
   }
 
+  Ort::Allocator& GetAllocator() override {
+    return *ort_allocator_;
+  }
+
   std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
     return std::make_shared<GpuMemory>(size);
   }
@@ -147,6 +152,58 @@ struct DmlInterfaceImpl : DeviceInterface {
     return GetCpuInterface()->CreateBeam(params);
   }
 
+#if 0
+  void UpdatePositionIDs() {
+    ComPtr<ID3D12Resource> target_resource;
+    Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource));
+
+    dml_update_position_ids_kernel_ = DmlIncrementValuesKernel(
+        model_.GetD3D12Device(),
+        model_.GetDmlExecutionContext(),
+        static_cast<uint32_t>(position_ids_shape_[0]),
+        type_,
+        target_resource.Get());
+
+    // Execute the cached command list
+    ComPtr<ID3D12Fence> fence;
+    uint64_t completion_value;
+    model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value);
+  }
+
+  void UpdateAttentionMask(int total_length) {
+    ComPtr<ID3D12Resource> attention_mask_resource;
+    Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource));
+    ComPtr<ID3D12Resource> attention_mask_next_resource;
+    Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource));
+    if (is_first_mask_update_) {
+      dml_update_mask_kernel_ = DmlUpdateMaskKernel(
+          model_.GetD3D12Device(),
+          model_.GetDmlExecutionContext(),
+          static_cast<uint32_t>(attention_mask_shape_[0]),
+          static_cast<uint32_t>(attention_mask_shape_[1]),
+          type_,
+          total_length,
+          attention_mask_resource.Get(),
+          attention_mask_next_resource.Get());
+      is_second_mask_update_ = true;
+    } else if (is_second_mask_update_) {
+      dml_update_mask_kernel_ = DmlUpdateMaskKernel(
+          model_.GetD3D12Device(),
+          model_.GetDmlExecutionContext(),
+          static_cast<uint32_t>(attention_mask_shape_[0]),
+          static_cast<uint32_t>(attention_mask_shape_[1]),
+          type_,
+          1,
+          attention_mask_resource.Get(),
+          attention_mask_next_resource.Get());
+      is_second_mask_update_ = false;
+    }
+    ComPtr<ID3D12Fence> fence;
+    uint64_t completion_value;
+    model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value);
+  }
+#endif
+
   void Synchronize() override {
   }
 };
@@ -165,7 +222,6 @@ void SetDmlProvider(OrtSessionOptions& session_options) {
 }
 
 DeviceInterface* GetDmlInterface() {
-  assert(g_dml_device);
   return g_dml_device.get();
 }
 
diff --git a/src/generators.cpp b/src/generators.cpp
index ad34d1979..c47f83151 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -9,9 +9,6 @@
 #include "cpu/interface.h"
 #include "cuda/interface.h"
 #include "dml/interface.h"
-#if USE_CUDA
-#include "models/kernels.h"
-#endif
 
 #if _WIN32
 EXTERN_C IMAGE_DOS_HEADER __ImageBase;
@@ -39,11 +36,6 @@ void ThrowErrorIfSessionTerminated(bool is_session_terminated) {
 
 namespace Generators {
 
-#if USE_CUDA
-// TODO: Remove once we remove all dependencies
-void OnCudaError(cudaError_t error) { assert(false); }
-#endif
-
 static bool _ = (Ort::InitApi(), false);
 
 OrtGlobals::OrtGlobals()
@@ -93,8 +85,12 @@ OrtEnv& GetOrtEnv() {
 // Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device)
 void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
   source.CopyDeviceToCpu();
-  auto source_span = std::span<const uint8_t>(source.p_cpu_+begin_source, size_in_bytes);
-  dest.AllocateCpu();
+  auto source_span = std::span<const uint8_t>(source.p_cpu_ + begin_source, size_in_bytes);
+  // If we're overwriting the entire destination
+  if (dest.size_in_bytes_ == size_in_bytes)
+    dest.AllocateCpu();
+  else
+    dest.CopyDeviceToCpu();  // Overwriting part of destination, so copy over initial contents first
   std::copy(source_span.begin(), source_span.end(), dest.p_cpu_ + begin_dest);
   dest.CopyCpuToDevice();
 }
@@ -120,8 +116,7 @@ struct GenaiInterfaceImpl : GenaiInterface {
   void Sequences_RewindTo(Sequences* p_this, size_t new_length) override { return p_this->RewindTo(new_length); }
 } g_genai;
 
-#if USE_CUDA
-CudaInterface* GetCudaInterface() {
+DeviceInterface* GetCudaInterface() {
 // Load the shared library onnxruntime-genai-cuda.dll
 // This is a workaround to avoid linking the CUDA library to the generator library
 // The CUDA library is only needed for the CUDA allocator
@@ -137,8 +132,8 @@ CudaInterface* GetCudaInterface() {
     throw std::runtime_error("Cuda interface not available.");
   }
 
-  Generators::CudaInterface* GetInterface(GenaiInterface * p_genai);
-  static CudaInterface* cuda_interface{[] {
+  Generators::DeviceInterface* GetInterface(GenaiInterface * p_genai);
+  static DeviceInterface* cuda_interface{[] {
 #ifdef _WIN32
     auto get_cuda_fn = reinterpret_cast<decltype(&GetInterface)>(GetProcAddress(reinterpret_cast<HMODULE>(cuda_library.get()), "GetInterface"));
 #else
@@ -150,32 +145,6 @@ CudaInterface* GetCudaInterface() {
   return cuda_interface;
 }
 
-
-namespace cuda {
-void LaunchInt32ToInt64(const int32_t* input, int64_t* output, int count, cudaStream_t stream) { GetCudaInterface()->Int32ToInt64(input, output, count, stream); }
-void LaunchFp16ToFp32(const uint16_t* input, float* output, int count, cudaStream_t stream) { GetCudaInterface()->Fp16ToFp32(input, output, count, stream); }
-void LaunchFp32ToFp16(const float* input, uint16_t* output, int count, cudaStream_t stream) { GetCudaInterface()->Fp32ToFp16(input, output, count, stream); }
-void LaunchExpandAndInt32ToInt64(const int32_t* src, int64_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { GetCudaInterface()->LaunchExpandAndInt32ToInt64(src, dst, num_beams, batch_size, sequence_length, stream); }
-void LaunchExpand(const int32_t* src, int32_t* dst, int num_beams, int batch_size, int sequence_length, cudaStream_t stream) { GetCudaInterface()->LaunchExpand(src, dst, num_beams, batch_size, sequence_length, stream); }
-template <>
-void Launch_UpdatePositionIds<int32_t>(int32_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) { GetCudaInterface()->Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); }
-template <>
-void Launch_UpdatePositionIds<int64_t>(int64_t* position_ids, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream) { GetCudaInterface()->Launch_UpdatePositionIds(position_ids, batch_beam_size, total_length, new_kv_length, stream); }
-template <>
-void Launch_UpdateAttentionMask<int32_t>(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) { GetCudaInterface()->Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); }
-template <>
-void Launch_UpdateAttentionMask<int64_t>(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) { GetCudaInterface()->Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); }
-void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) { GetCudaInterface()->LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream); }
-void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) { GetCudaInterface()->UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream); }
-void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) { GetCudaInterface()->ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, stream); }
-template <>
-void LaunchCopyCrossQKSingleDecodeStep<float>(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) { GetCudaInterface()->LaunchCopyCrossQKSingleDecodeStep(stream, cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length); }
-template <>
-void LaunchFinalizeCrossQK<float>(cudaStream_t stream, int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) { GetCudaInterface()->LaunchFinalizeCrossQK(stream, iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data); }
-}  // namespace cuda
-#endif
-
-
 std::string to_string(DeviceType device_type) {
   switch (device_type) {
     case DeviceType::CPU:
@@ -195,10 +164,8 @@ DeviceInterface* GetDeviceInterface(DeviceType type) {
     default:
     case DeviceType::CPU:
       return GetCpuInterface();
-#if USE_CUDA
     case DeviceType::CUDA:
       return GetCudaInterface();
-#endif
 #if USE_DML
     case DeviceType::DML:
       return GetDmlInterface();
@@ -215,7 +182,6 @@ GeneratorParams::GeneratorParams(const Model& model)
     : config{*model.config_.get()},
       p_device{model.p_device_},
       device_type{model.device_type_},
-      cuda_stream{model.cuda_stream_},
       is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
   use_cuda_graph = is_cuda_graph_enabled_;
   if (use_cuda_graph) {
@@ -444,8 +410,3 @@ DeviceSpan<int32_t> Generator::GetSequence(size_t index) const {
 }
 
 }  // namespace Generators
-
-#if USE_CUDA
-cudaError_t cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemcpyAsync(dst, src, count, kind, stream); }
-cudaError_t cudaMemsetAsync(void* ptr, int value, size_t count, cudaStream_t stream) { return Generators::GetCudaInterface()->cudaMemsetAsync(ptr, value, count, stream); }
-#endif
diff --git a/src/generators.h b/src/generators.h
index 7275aa976..0449e198e 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -23,16 +23,10 @@
 #include <unordered_set>
 #include <variant>
 #include <vector>
-#if USE_CUDA
-#include <cuda_runtime.h>
-#else
-// If we don't include cuda_runtime.h, we define this to avoid lots of extra #ifdefs
-using cudaStream_t = void*;
-#endif
 
 #include "leakcheck.h"
-#include "smartptrs.h"
 #include "models/onnxruntime_api.h"
+#include "smartptrs.h"
 #include "models/debugging.h"
 #include "config.h"
 #include "logging.h"
@@ -94,7 +88,6 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec
 
   DeviceInterface* p_device{};
   DeviceType device_type{DeviceType::CPU};
-  cudaStream_t cuda_stream{};
 
   cpu_span<int32_t> aux_input_ids{};  // Intermediate solution to be used with SetInputs function for multimodal and whisper models
 
diff --git a/src/models/audio_processor.cpp b/src/models/audio_processor.cpp
index 3b0ac53ed..e81e357d1 100644
--- a/src/models/audio_processor.cpp
+++ b/src/models/audio_processor.cpp
@@ -30,7 +30,7 @@ std::unique_ptr<OrtValue> ProcessMel(ort_extensions::OrtxObjectPtr<OrtxTensor>&
         allocator.GetInfo(),
         std::span<float>(const_cast<float*>(mel_data), input_features_value->GetTensorTypeAndShapeInfo()->GetElementCount()),
         shape_span);
-    ConvertFp32ToFp16(allocator, *input_features_fp32, input_features_value, DeviceType::CPU, nullptr);
+    Cast(*input_features_fp32, input_features_value, *GetDeviceInterface(DeviceType::CPU), Ort::TypeToTensorType<Ort::Float16_t>);
   }
 
   return input_features_value;
diff --git a/src/models/captured_graph_pool.h b/src/models/captured_graph_pool.h
index 42e3be51d..4d405018e 100644
--- a/src/models/captured_graph_pool.h
+++ b/src/models/captured_graph_pool.h
@@ -145,11 +145,6 @@ struct CapturedGraphInfo {
   std::unique_ptr<Generators::StaticBuffer> sb_embeddings_;
   std::unique_ptr<CapturedGraphKey> key_;
 
-#if USE_DML
-  std::unique_ptr<Generators::StaticBuffer> sb_attention_mask_next_;
-  std::unique_ptr<Generators::StaticBuffer> sb_input_ids_int32_;
-#endif
-
   // Generates a unique annotation ID across different captured graph objects. This is necessary because different
   // generators could be alive at the same time and run the same batch size but with different static buffers, so
   // they need to have different annotation IDs.
diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp
index 67e0df3b1..1d127be6c 100644
--- a/src/models/debugging.cpp
+++ b/src/models/debugging.cpp
@@ -89,41 +89,9 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
       auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
       DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
       break;
-#if 0
-#if USE_CUDA
-      auto type = type_info->GetElementType();
-      size_t element_size = SizeOf(type);
-      auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
-      CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
-      DumpValues(stream, type, cpu_copy.get(), element_count);
-#elif USE_DML
-      auto type = type_info->GetElementType();
-      size_t element_size = SizeOf(type);
-      auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
-
-      if (value->GetTensorMutableRawData()) {
-        ComPtr<ID3D12Resource> gpu_resource;
-        Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
-            model.allocator_device_,
-            value->GetTensorMutableRawData(),
-            &gpu_resource));
-
-        model.GetDmlReadbackHeap()->ReadbackFromGpu(
-            std::span(cpu_copy.get(), element_size * element_count),
-            gpu_resource.Get(),
-            0,
-            D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
-      }
-
-      DumpValues(stream, type, cpu_copy.get(), element_count);
-#else
-      stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?";
-#endif
-#endif
-      break;
     }
     default:
-      stream << "Unhandled device type";
+      stream << "Unhandled device type: " << static_cast<int>(memory_info.GetDeviceType()) << "\r\n";
       break;
   }
 }
diff --git a/src/models/extra_inputs.cpp b/src/models/extra_inputs.cpp
index a4bab8ce7..b08a8c8d2 100644
--- a/src/models/extra_inputs.cpp
+++ b/src/models/extra_inputs.cpp
@@ -1,7 +1,6 @@
 #include "../generators.h"
 #include "model.h"
 #include "extra_inputs.h"
-#include "kernels.h"
 
 namespace Generators {
 
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 47233f3d8..d0be2c67c 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -1,7 +1,6 @@
 #include "../generators.h"
 #include "model.h"
 #include "input_ids.h"
-#include "kernels.h"
 
 namespace Generators {
 
@@ -49,7 +48,7 @@ void InputIDs::Add() {
   }
 }
 
-void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
+void InputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   auto new_tokens_cpu = new_tokens.CopyDeviceToCpu();
 
   const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids, int32_t pad_token_id) {
@@ -77,69 +76,32 @@ void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
   if (static_cast<size_t>(shape_[1]) != sequence_length) {
     shape_[1] = sequence_length;
     if (!sb_input_ids_) {
-      value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+      value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
     } else {
-      value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);
+      value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>);
     }
 
     state_.inputs_[input_index_] = value_.get();
   }
 
-  // Update input_ids with next tokens, converting from 32-bit to 64-bit
-  if (type_ == Ort::TypeToTensorType<int64_t>) {
-    switch (model_.device_type_) {
-      case DeviceType::CUDA: {
-#if USE_CUDA
-        auto* data = value_->GetTensorMutableData<int64_t>();
-        auto next_tokens = new_tokens.Span();
-        // For beam search
-        if (is_prompt_ && state_.params_->search.num_beams > 1)
-          cuda::LaunchExpandAndInt32ToInt64(next_tokens.data(), data, state_.params_->search.num_beams, state_.params_->search.batch_size, static_cast<int>(sequence_length), model_.cuda_stream_);
-        else
-          cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast<int>(next_tokens.size()), model_.cuda_stream_);
-#endif
-      } break;
-
-      default: {
-        // CPU, DML, WEBGPU
-        auto* data = value_->GetTensorMutableData<int64_t>();
-        auto next_tokens = new_tokens.Span();
-        for (int b = 0; b < shape_[0]; b++) {
-          for (int i = 0; i < shape_[1]; i++) {
-            // For beam search
-            int32_t next_token;
-            if (is_prompt_ && state_.params_->search.num_beams > 1)
-              next_token = next_tokens[(b / state_.params_->search.num_beams) * shape_[1] + i];
-            else
-              next_token = next_tokens[b * shape_[1] + i];
-            data[b * shape_[1] + i] = next_token;
-          }
-        }
-      }
+  // Update input_ids with next tokens
+  auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_);
+
+  // For beam search
+  if (is_prompt_ && state_.params_->search.num_beams > 1) {
+    int row_size = static_cast<int>(shape_[1]);
+    for (int b = 0; b < shape_[0]; b++) {
+      int in_offset = (b / state_.params_->search.num_beams) * row_size;
+      int out_offset = b * row_size;
+      data_span.subspan(out_offset, row_size).CopyFrom(new_tokens.subspan(in_offset, row_size));
     }
   } else {
-    auto* data = value_->GetTensorMutableData<int32_t>();
-#if USE_CUDA
-    if (model_.device_type_ == DeviceType::CUDA) {
-      if (is_prompt_ && state_.params_->search.num_beams > 1) {
-        cuda::LaunchExpand(new_tokens.Span().data(), data, state_.params_->search.num_beams, state_.params_->search.batch_size, static_cast<int>(sequence_length), model_.cuda_stream_);
-      } else {
-        cudaMemcpyAsync(data, new_tokens.Span().data(), shape_[0] * shape_[1] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-      }
-    } else
-#endif
-    {
-      // For beam search
-      if (is_prompt_ && state_.params_->search.num_beams > 1) {
-        for (int b = 0; b < shape_[0]; b++) {
-          int in_offset = (b / state_.params_->search.num_beams) * static_cast<int>(shape_[1]);
-          int out_offset = b * static_cast<int>(shape_[1]);
-          memcpy(data + out_offset, new_tokens.Span().data() + in_offset, shape_[1] * sizeof(int32_t));
-        }
-      } else {
-        memcpy(data, new_tokens.Span().data(), shape_[0] * shape_[1] * sizeof(int32_t));
-      }
-    }
+    data_span.CopyFrom(new_tokens);
+  }
+
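+  // If the model expects int64 input_ids, cast the updated int32 tensor and bind the cast result instead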
+  if (type_ == Ort::TypeToTensorType<int64_t>) {
+    Cast(*value_, cast_value_, *model_.p_device_, type_);
+    state_.inputs_[input_index_] = cast_value_.get();
   }
 
   is_prompt_ = false;
diff --git a/src/models/input_ids.h b/src/models/input_ids.h
index 02af3a98a..6c3a5fc76 100644
--- a/src/models/input_ids.h
+++ b/src/models/input_ids.h
@@ -14,7 +14,7 @@ struct InputIDs {
   void Add();
   // Resize input_ids based on size of next_tokens.
   // Update value with next_tokens.
-  void Update(DeviceSpan<int32_t>& next_tokens);
+  void Update(DeviceSpan<int32_t> next_tokens);
 
   auto& GetShape() const { return shape_; }
   const char* name_;
@@ -31,6 +31,7 @@ struct InputIDs {
   std::array<int64_t, 2> shape_{};
   ONNXTensorElementDataType type_;
   std::unique_ptr<OrtValue> value_;
+  std::unique_ptr<OrtValue> cast_value_;
 
   // Used for decoding runs with cuda graphs.
   StaticBuffer* sb_input_ids_{};
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 4fa3243e2..3f02d4db3 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -108,21 +108,13 @@ void KV_Cache_Combined::RewindPastTensorsTo(size_t index) {
   for (int i = 0; i < layer_count_; i++) {
     OrtValue& present = *presents_[i];
     std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+    auto present_span = WrapTensor<T>(*model_.p_device_, present);
+    auto past_span = WrapTensor<T>(*model_.p_device_, *past);
+
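+    // Copy the retained prefix of each key/value head from the present cache into the smaller past tensor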
     for (int j = 0; j < 2 * batch_x_num_heads; j++) {
-      auto present_data = present.GetTensorData<T>() + j * old_length_x_head_size;
-      auto past_data = past->GetTensorMutableData<T>() + j * new_length_x_head_size;
-#if USE_CUDA
-      if (model_.device_type_ == DeviceType::CUDA) {
-        cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-      } else
-#elif USE_DML
-      if (model_.device_type_ == DeviceType::DML) {
-        // TODO: Implement DML version
-      } else
-#endif
-      {
-        copy(std::span<const T>(present_data, new_length_x_head_size), std::span<T>(past_data, new_length_x_head_size));
-      }
+      auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+      auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+      past_data.CopyFrom(present_data);
     }
     pasts_[i] = std::move(past);
     state_.inputs_[input_index_ + i] = pasts_[i].get();
@@ -135,38 +127,22 @@ void KV_Cache_Combined::PickPastState(DeviceSpan<int32_t> beam_indices_device, i
   std::span<const int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu();
   auto block_size_per_beam = shape_[2] * shape_[3] * shape_[4];
   auto past_key_size = shape_[1] * block_size_per_beam;
-  auto element_count = shape_[0] * past_key_size;
 
-  const OrtValue& present = *presents_[index];
+  OrtValue& present = *presents_[index];
   std::unique_ptr<OrtValue> past = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
-  auto past_span = std::span<ScoreType>(past->GetTensorMutableData<ScoreType>(), element_count);
-  auto present_span = std::span<const ScoreType>(present.GetTensorData<ScoreType>(), element_count);
-
-#if USE_CUDA
-  if (model_.device_type_ == DeviceType::CUDA) {
-    for (size_t j = 0; j < beam_indices.size(); j++) {
-      int32_t beam_index = beam_indices[j];
-      auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
-      auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
-
-      auto past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
-      auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
-      cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-      cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-    }
-  } else
-#endif
-  {
-    for (size_t j = 0; j < beam_indices.size(); j++) {
-      int32_t const beam_index = beam_indices[j];
-      auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
-      auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
-
-      auto past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
-      auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
-      copy(present_key, past_key);
-      copy(present_value, past_value);
-    }
+
+  auto past_span = WrapTensor<ScoreType>(*model_.p_device_, *past);
+  auto present_span = WrapTensor<ScoreType>(*model_.p_device_, present);
+
+  for (size_t j = 0; j < beam_indices.size(); j++) {
+    int32_t beam_index = beam_indices[j];
+    auto present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
+    auto present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
+
+    auto past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
+    auto past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
+    past_key.CopyFrom(present_key);
+    past_value.CopyFrom(present_value);
   }
 
   pasts_[index] = std::move(past);
@@ -325,21 +301,14 @@ void KV_Cache::RewindPastTensorsTo(size_t index) {
   for (int i = 0; i < layer_count_ * 2; i++) {
     OrtValue& present = *presents_[i];
     std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+
+    auto past_span = WrapTensor<T>(*model_.p_device_, *past);
+    auto present_span = WrapTensor<T>(*model_.p_device_, present);
+
     for (int j = 0; j < batch_x_num_heads; j++) {
-      auto present_data = present.GetTensorData<T>() + j * old_length_x_head_size;
-      auto past_data = past->GetTensorMutableData<T>() + j * new_length_x_head_size;
-#if USE_CUDA
-      if (model_.device_type_ == DeviceType::CUDA) {
-        cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-      } else
-#elif USE_DML
-      if (model_.device_type_ == DeviceType::DML) {
-        // TODO: Implement DML copy
-      } else
-#endif
-      {
-        copy(std::span<const T>(present_data, new_length_x_head_size), std::span<T>(past_data, new_length_x_head_size));
-      }
+      auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+      auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+      past_data.CopyFrom(present_data);
     }
     pasts_[i] = std::move(past);
     state_.inputs_[input_index_ + i] = pasts_[i].get();
@@ -351,32 +320,20 @@ template <typename ScoreType>
 void KV_Cache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
   std::span<int32_t> beam_indices = beam_indices_device.Span();
   auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
-  auto element_count = shape_[0] * block_size_per_beam;
 
-  const OrtValue& present_value = *presents_[index];
+  OrtValue& present_value = *presents_[index];
   std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
-  auto past_span = std::span<ScoreType>(past_value->GetTensorMutableData<ScoreType>(), element_count);
-  auto present_span = std::span<const ScoreType>(present_value.GetTensorData<ScoreType>(), element_count);
-
-#if USE_CUDA
-  if (model_.device_type_ == DeviceType::CUDA) {
-    for (size_t j = 0; j < beam_indices.size(); j++) {
-      int32_t beam_index = beam_indices[j];
-      auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
-      auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
-      cudaMemcpyAsync(past.data(), present.data(), present.size_bytes(), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
-    }
-  } else
-#endif
-  {
-    for (size_t j = 0; j < beam_indices.size(); j++) {
-      int32_t const beam_index = beam_indices[j];
-      auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
-      auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
-      copy(present, past);
-    }
-  }
 
+  auto past_span = WrapTensor<ScoreType>(*model_.p_device_, *past_value);
+  auto present_span = WrapTensor<ScoreType>(*model_.p_device_, present_value);
+
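+  // Gather each beam's past state from the present state of its selected parent beam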
+  for (size_t j = 0; j < beam_indices.size(); j++) {
+    int32_t beam_index = beam_indices[j];
+    auto present = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
+    auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
+    past.CopyFrom(present);
+  }
+
   pasts_[index] = std::move(past_value);
 }
 
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index ae7d9bdcd..c270d6ccc 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -3,10 +3,6 @@
 #include "../generators.h"
 #include "model.h"
 #include "logits.h"
-#if USE_CUDA
-#include "../cuda/cuda_common.h"
-#include "kernels.h"
-#endif
 
 namespace Generators {
 
@@ -25,14 +21,12 @@ Logits::Logits(State& state)
     }
   }
 
-#if USE_CUDA
   if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
     auto& cpu_ids = model_.config_->model.eos_token_ids;
     cuda_eos_token_ids_ = state_.params_->p_device->Allocate<int32_t>(cpu_ids.size());
     copy(std::span<const int32_t>{cpu_ids}, cuda_eos_token_ids_.CpuSpan());
     cuda_eos_token_ids_.CopyCpuToDevice();
   }
-#endif
 
   input_sequence_lengths.resize(state_.params_->search.batch_size);
 }
@@ -79,26 +73,23 @@ DeviceSpan<float> Logits::Get() {
 
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-      ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32_, model_.device_type_, model_.cuda_stream_);
-      logits_of_last_token = logits_of_last_token_fp32_.get();
+    Cast(*logits_of_last_token, logits_of_last_token_fp32_, *model_.p_device_, Ort::TypeToTensorType<float>);
+    logits_of_last_token = logits_of_last_token_fp32_.get();
   }
 
   if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data())
     logits_ = WrapTensor<float>(*state_.params_->p_device, *logits_of_last_token);
 
-#if USE_CUDA
   if (model_.device_type_ == DeviceType::CUDA) {
     if (!cuda_eos_token_ids_.empty())
-      cuda::LaunchHandleEOSArray(
+      model_.p_device_->LaunchHandleEOSArray(
           logits_.Span().data(),
           static_cast<int>(shape_[0]) /* batch_beam_size*/,
           static_cast<int>(shape_[2]) /* vocab_size */,
           cuda_eos_token_ids_.Span().data(),
-          static_cast<int>(cuda_eos_token_ids_.size()),
-          model_.cuda_stream_);
+          static_cast<int>(cuda_eos_token_ids_.size()));
     return logits_;
   }
-#endif
 
   HandleEOSArray(logits_.Span());
   return logits_;
diff --git a/src/models/logits.h b/src/models/logits.h
index 187ecd8cb..3eae05875 100644
--- a/src/models/logits.h
+++ b/src/models/logits.h
@@ -43,9 +43,7 @@ struct Logits {
   StaticBuffer* sb_logits32_{};
   StaticBuffer* sb_logits16_{};
 
-#if USE_CUDA
   DeviceSpan<int32_t> cuda_eos_token_ids_;  // eos_token_ids from params, but in cuda accessible memory
-#endif
 };
 
 }  // namespace Generators
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 9270a2cc7..487874b25 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -11,7 +11,6 @@
 #include "gpt.h"
 #include "decoder_only.h"
 #include "whisper.h"
-#include "kernels.h"
 #include "multi_modal_vision_model.h"
 #include "decoder_only_pipeline.h"
 #if USE_DML
@@ -228,7 +227,7 @@ Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
     static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU", "Cuda", "DML", "WebGPU Buffer"};
     auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     device = Ort::Allocator::Create(session, *memory_info);
-    GetDeviceInterface(type)->InitAllocator(*device);
+    GetDeviceInterface(type)->InitOrt(*Ort::api, *device);
   }
   return device.get();
 }
@@ -284,9 +283,9 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
 
 Model::~Model() = default;
 
-void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {
+void Model::InitDeviceAllocator(OrtSession& session) {
   allocator_device_ = &allocator_cpu_;
-  if (device_type_== DeviceType::CUDA)
+  if (device_type_ == DeviceType::CUDA)
     allocator_device_ = GetDeviceAllocator(session, device_type_);
 
   allocator_kvcache_ = allocator_device_;
@@ -400,8 +399,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
         p_device_ = GetDeviceInterface(device_type_);
 
         // Create and set our cudaStream_t
-        cuda_stream_ = p_device_->GetCudaStream();
-        ort_provider_options->UpdateValue("user_compute_stream", cuda_stream_);
+        ort_provider_options->UpdateValue("user_compute_stream", p_device_->GetCudaStream());
       }
 
       session_options.AppendExecutionProvider_CUDA_V2(*ort_provider_options);
@@ -419,7 +417,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       session_options.AppendExecutionProvider_ROCM(ort_provider_options);
 #if USE_DML
     } else if (provider_options.name == "dml") {
-      if (!p_dml_api_) {
+      if (!GetDmlInterface()) {
         LUID device_luid{};
         LUID* p_device_luid{};
         for (const auto& [name, value] : provider_options.options) {
@@ -537,80 +535,17 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config) {
   return std::make_shared<GeneratorParams>(config);
 }
 
-void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream) {
-  auto shape_info = in.GetTensorTypeAndShapeInfo();
-  auto shape = shape_info->GetShape();
-  assert(shape_info->GetElementType() == Ort::TypeToTensorType<Ort::Float16_t>);
+void Cast(OrtValue& input, std::unique_ptr<OrtValue>& output, DeviceInterface& device, ONNXTensorElementDataType output_type) {
+  auto input_info = input.GetTensorTypeAndShapeInfo();
+  auto shape = input_info->GetShape();
 
-  bool allocate_p_out = p_out == nullptr;
-  if (p_out) {
-    auto out_shape_info = p_out->GetTensorTypeAndShapeInfo();
-    auto out_shape = out_shape_info->GetShape();
-    allocate_p_out = shape != out_shape;
-  }
-
-  if (allocate_p_out)
-    p_out = OrtValue::CreateTensor<float>(allocator, shape);
-
-  int count = static_cast<int>(shape_info->GetElementCount());
-  auto* fp16 = in.GetTensorData<uint16_t>();
-  auto* fp32 = p_out->GetTensorMutableData<float>();
-
-  switch (device_type) {
-    case DeviceType::WEBGPU:
-    case DeviceType::DML:
-      // DML, WebGpu doesn't currently support on-device scoring, so we fall back to the CPU
-    case DeviceType::CPU:
-      for (int i = 0; i < count; i++)
-        fp32[i] = FastFloat16ToFloat32(fp16[i]);
-      break;
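+  // Reallocate the output tensor if it is missing or its shape no longer matches the input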
+  if (output && shape != output->GetTensorTypeAndShapeInfo()->GetShape())
+    output = nullptr;
+  if (!output)
+    output = OrtValue::CreateTensor(device.GetAllocator(), shape, output_type);
 
-#if USE_CUDA
-    case DeviceType::CUDA:
-      cuda::LaunchFp16ToFp32(fp16, fp32, count, stream);
-      break;
-#endif
-
-    default:
-      throw std::runtime_error("ConvertFp16ToFp32 - Unsupported device type");
-  }
-}
-
-void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out,
-                       DeviceType device_type, cudaStream_t stream) {
-  auto shape_info = in.GetTensorTypeAndShapeInfo();
-  auto shape = shape_info->GetShape();
-  assert(shape_info->GetElementType() == Ort::TypeToTensorType<float>);
-
-  bool allocate_p_out = p_out == nullptr;
-  if (p_out) {
-    auto out_shape_info = p_out->GetTensorTypeAndShapeInfo();
-    auto out_shape = out_shape_info->GetShape();
-    allocate_p_out = shape != out_shape;
-  }
-
-  if (allocate_p_out)
-    p_out = OrtValue::CreateTensor<float>(allocator, shape);
-
-  int count = static_cast<int>(shape_info->GetElementCount());
-  auto* fp32 = in.GetTensorData<float>();
-  auto* fp16 = p_out->GetTensorMutableData<uint16_t>();
-
-  switch (device_type) {
-    case DeviceType::DML:
-    case DeviceType::CPU:
-      for (int i = 0; i < count; i++)
-        fp16[i] = FastFloat32ToFloat16(fp32[i]);
-      break;
-
-#if USE_CUDA
-    case DeviceType::CUDA:
-      cuda::LaunchFp32ToFp16(fp32, fp16, count, stream);
-#endif
-
-    default:
-      throw std::runtime_error("ConvertFp32ToFp16 - Unsupported device type");
-  }
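+  // Try the device's native cast first and fall back to the CPU implementation when it is unsupported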
+  if (!device.Cast(input, *output))
+    GetDeviceInterface(DeviceType::CPU)->Cast(input, *output);
 }
 
 std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const {
@@ -625,44 +560,21 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
 
   auto input_type_info = input->GetTensorTypeAndShapeInfo();
   auto element_type = input_type_info->GetElementType();
-  auto element_size = SizeOf(element_type);
   auto input_shape = input_type_info->GetShape();
   const int64_t batch_size = input_shape[0];
-  const int64_t data_size_bytes = input_type_info->GetElementCount() * element_size / batch_size;
+  const int64_t data_size_bytes = input_type_info->GetElementCount() * SizeOf(element_type) / batch_size;
 
   input_shape[0] *= num_beams;
 
   auto& allocator = device_type_ == DeviceType::DML ? allocator_cpu_ : *allocator_device_;
   auto expanded = OrtValue::CreateTensor(allocator, input_shape, element_type);
-  const auto* input_data = reinterpret_cast<const uint8_t*>(input->GetTensorRawData());
-  auto* expanded_data = reinterpret_cast<uint8_t*>(expanded->GetTensorMutableRawData());
-  auto* target = expanded_data;
-
-  switch (device_type_) {
-    case DeviceType::WEBGPU:
-    case DeviceType::DML:
-      // DML and WebGpu doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
-    case DeviceType::CPU:
-      for (int i = 0; i < batch_size; i++) {
-        for (int j = 0; j < num_beams; j++) {
-          memcpy(target, input_data + i * data_size_bytes, data_size_bytes);
-          target += data_size_bytes;
-        }
-      }
-      break;
-
-#if USE_CUDA
-    case DeviceType::CUDA:
-      for (int i = 0; i < batch_size; i++) {
-        for (int j = 0; j < num_beams; j++) {
-          cudaMemcpyAsync(target, input_data + i * data_size_bytes, data_size_bytes, cudaMemcpyHostToDevice, cuda_stream_);
-          target += data_size_bytes;
-        }
-      }
-      break;
-#endif
-    default:
-      throw std::runtime_error("ExpandInputs - Unsupported device type");
+  auto input_span = ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *input);
+  auto expanded_span = ByteWrapTensor(*p_device_, *expanded);
+
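+  // Replicate each batch entry num_beams times by copying its bytes into consecutive rows of the expanded tensor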
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes));
+    }
   }
   return expanded;
 }
diff --git a/src/models/model.h b/src/models/model.h
index 55ad960b3..63fd7749f 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -8,22 +8,11 @@
 #include "audio_processor.h"
 #include "adapters.h"
 
-#if USE_DML
-#include "dml_provider_factory.h"
-#include "../dml/dml_helpers.h"
-#include "../dml/dml_execution_context.h"
-#include "../dml/dml_pooled_upload_heap.h"
-#include "../dml/dml_readback_heap.h"
-#endif
-
 namespace Generators {
 
 struct Tokenizer;
 
-void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
-
-void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
-
+void Cast(OrtValue& input, std::unique_ptr<OrtValue>& output, DeviceInterface& device, ONNXTensorElementDataType type);
 void CheckResult(extError_t error);
 
 struct State {
@@ -145,7 +134,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<Config> config_;
   std::unique_ptr<OrtSessionOptions> session_options_;
 
-  cudaStream_t cuda_stream_{};
   mutable DeviceInterface* p_device_{};
   DeviceType device_type_{DeviceType::CPU};
   Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()};
@@ -156,15 +144,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
 
   std::shared_ptr<Model> external_owner_;  // Set to 'this' when created by the C API to preserve lifetime
 
-#if USE_DML
-  DmlExecutionContext* GetDmlExecutionContext() const { return dml_execution_context_.get(); }
-  DmlReadbackHeap* GetDmlReadbackHeap() const { return dml_readback_heap_.get(); }
-  DmlPooledUploadHeap* GetDmlUploadHeap() const { return dml_pooled_upload_heap_.get(); }
-  const OrtDmlApi* GetOrtDmlApi() const { return p_dml_api_; }
-  IDMLDevice* GetDmlDevice() const { return dml_device_.Get(); }
-  ID3D12Device* GetD3D12Device() const { return dml_objects_.d3d12_device.Get(); }
-#endif
-
  protected:
   void InitDeviceAllocator(OrtSession& session);
   void CreateSessionOptions();
@@ -174,14 +153,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
                                       bool is_primary_session_options,
                                       bool disable_graph_capture);
 
-#if USE_DML
-  mutable DmlObjects dml_objects_;
-  const OrtDmlApi* p_dml_api_{};
-  std::unique_ptr<DmlPooledUploadHeap> dml_pooled_upload_heap_;
-  std::unique_ptr<DmlExecutionContext> dml_execution_context_;
-  std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
-  ComPtr<IDMLDevice> dml_device_;
-#endif
 #if USE_WEBGPU
   std::unique_ptr<OrtIoBinding> webgpu_io_binding_;
 #endif
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 5d4c82684..4f57637a8 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -1,11 +1,6 @@
 #include "../generators.h"
 #include "model.h"
 #include "position_inputs.h"
-#include "kernels.h"
-
-#if USE_DML
-#include "../dml/dml_update_mask_kernel.h"
-#endif
 
 namespace Generators {
 
@@ -49,12 +44,6 @@ PositionInputs::PositionInputs(const Model& model, State& state, DeviceSpan<int3
     }
     if (has_mask_input_) {
       sb_attention_mask_ = state_.GetCapturedGraphInfo()->sb_attention_mask_.get();
-
-#if USE_DML
-      if (model_.device_type_ == DeviceType::DML) {
-        sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get();
-      }
-#endif
     }
   }
 }
@@ -104,9 +93,7 @@ void PositionInputs::RewindTo(size_t index) {
     // Rewind the mask input to a previous state
   } else if (has_mask_input_) {
     if (attention_mask_shape_[0] == 1) {
-#if USE_CUDA
       RewindMask(index);
-#endif
     } else
       throw std::runtime_error("PositionInputs::RewindTo - Unsupported batch size");
   }
@@ -126,29 +113,6 @@ void PositionInputs::AddPositionIDs() {
   state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str());
 }
 
-#if USE_CUDA || USE_DML
-void PositionInputs::CopyNextPositionIDsToCurrent() {
-#if USE_CUDA
-  assert(model_.device_type_ == DeviceType::CUDA);
-  cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(),
-                  position_ids_next_->GetTensorMutableRawData(),
-                  (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * position_ids_shape_[0],
-                  cudaMemcpyDeviceToDevice,
-                  model_.cuda_stream_);
-#elif USE_DML
-  assert(model_.device_type_ == DeviceType::DML);
-  ComPtr<ID3D12Resource> target_resource;
-  Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource));
-  auto source = std::span(position_ids_next_->GetTensorData<const uint8_t>(), (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * position_ids_shape_[0]);
-  model_.GetDmlUploadHeap()->BeginUploadToGpu(
-      target_resource.Get(),
-      0,
-      D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-      source);
-#endif
-}
-#endif
-
 void PositionInputs::CreateNextPositionIDsTensor() {
   if (!sb_position_ids_) {
     if (position_ids_shape_[1] == 1 && position_ids_next_) {
@@ -158,12 +122,12 @@ void PositionInputs::CreateNextPositionIDsTensor() {
       position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_);
     }
   } else {
-#if USE_CUDA || USE_DML
     position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
     if (position_ids_shape_[1] == 1) {
-      CopyNextPositionIDsToCurrent();
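+      // Copy the next position ids into the static-buffer tensor through device spans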
+      auto position_ids_span = ByteWrapTensor(*model_.p_device_, *position_ids_);
+      auto position_ids_next_span = ByteWrapTensor(*model_.p_device_, *position_ids_next_);
+      position_ids_span.CopyFrom(position_ids_next_span);
     }
-#endif
   }
 }
 
@@ -182,25 +146,16 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) {
 
   switch (model_.device_type_) {
     case DeviceType::WEBGPU:
+    case DeviceType::DML:
     case DeviceType::CPU: {
       type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
                                               : UpdatePositionIDsImpl<int64_t>(total_length, new_kv_length);
       break;
     }
-#if USE_CUDA
     case DeviceType::CUDA: {
-      if (type_ == Ort::TypeToTensorType<int32_t>)
-        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, model_.cuda_stream_);
-      else
-        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, model_.cuda_stream_);
-      break;
-    }
-#elif USE_DML
-    case DeviceType::DML: {
-      UpdatePositionIDsImplDML();
+      model_.p_device_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
       break;
     }
-#endif
     default:
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }
@@ -210,22 +165,12 @@ void PositionInputs::CreateNextAttentionMaskTensor(int total_length) {
   if (!sb_attention_mask_) {
     attention_mask_shape_[1] = total_length;
     attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
-#if USE_DML
-    if (model_.device_type_ == DeviceType::DML)
-      attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
-#endif
   } else {
-#if USE_CUDA
     attention_mask_shape_[1] = state_.params_->search.max_length;
     attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
     if (is_first_mask_update_) {
       ByteWrapTensor(*model_.p_device_, *attention_mask_next_).Zero();
     }
-#elif USE_DML
-    attention_mask_shape_[1] = state_.params_->search.max_length;
-    attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
-    attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
-#endif
   }
 }
 
@@ -240,52 +185,31 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) {
 
   switch (model_.device_type_) {
     case DeviceType::WEBGPU:
+    case DeviceType::DML:
     case DeviceType::CPU: {
       type_ == Ort::TypeToTensorType<int32_t> ? UpdateAttentionMaskImpl<int32_t>(total_length)
                                               : UpdateAttentionMaskImpl<int64_t>(total_length);
       break;
     }
-#if USE_CUDA
     case DeviceType::CUDA: {
-      int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
+      int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
       bool update_only = sb_attention_mask_ && !is_first_mask_update_;
-      if (type_ == Ort::TypeToTensorType<int32_t>) {
-        cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData<int32_t>(),
-                                         attention_mask_->GetTensorData<int32_t>(),
-                                         static_cast<int>(attention_mask_shape_[0]),
-                                         new_kv_length,
-                                         total_length,
-                                         max_length,
-                                         update_only,
-                                         model_.cuda_stream_);
-      } else {
-        cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData<int64_t>(),
-                                         attention_mask_->GetTensorData<int64_t>(),
-                                         static_cast<int>(attention_mask_shape_[0]),
-                                         new_kv_length,
-                                         total_length,
-                                         max_length,
-                                         update_only,
-                                         model_.cuda_stream_);
-      }
+      model_.p_device_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
+                                            attention_mask_->GetTensorRawData(),
+                                            static_cast<int>(attention_mask_shape_[0]),
+                                            new_kv_length,
+                                            total_length,
+                                            max_length,
+                                            update_only,
+                                            type_);
       break;
     }
-#elif USE_DML
-    case DeviceType::DML: {
-      UpdateAttentionMaskImplDML(total_length);
-      break;
-    }
-#endif
+
     default:
       throw std::runtime_error("PositionInputs::Update - Unsupported device type");
   }
-#if USE_DML
-  if (model_.device_type_ != DeviceType::DML) {
-    attention_mask_ = std::move(attention_mask_next_);
-  }
-#else
+
   attention_mask_ = std::move(attention_mask_next_);
-#endif
   state_.inputs_[mask_input_index_] = attention_mask_.get();
   is_first_mask_update_ = false;
 }
@@ -365,25 +289,6 @@ void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length)
   }
 }
 
-#if USE_DML
-void PositionInputs::UpdatePositionIDsImplDML() {
-  ComPtr<ID3D12Resource> target_resource;
-  Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource));
-
-  dml_update_position_ids_kernel_ = DmlIncrementValuesKernel(
-      model_.GetD3D12Device(),
-      model_.GetDmlExecutionContext(),
-      static_cast<uint32_t>(position_ids_shape_[0]),
-      type_,
-      target_resource.Get());
-
-  // Execute the cached command list
-  ComPtr<ID3D12Fence> fence;
-  uint64_t completion_value;
-  model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value);
-}
-#endif
-
 template <typename T>
 void PositionInputs::UpdateAttentionMaskImpl(int total_length) {
   auto* data = attention_mask_next_->GetTensorMutableData<T>();
@@ -403,44 +308,10 @@ void PositionInputs::UpdateAttentionMaskImpl(int total_length) {
   }
 }
 
-#if USE_DML
-void PositionInputs::UpdateAttentionMaskImplDML(int total_length) {
-  ComPtr<ID3D12Resource> attention_mask_resource;
-  Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource));
-  ComPtr<ID3D12Resource> attention_mask_next_resource;
-  Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource));
-  if (is_first_mask_update_) {
-    dml_update_mask_kernel_ = DmlUpdateMaskKernel(
-        model_.GetD3D12Device(),
-        model_.GetDmlExecutionContext(),
-        static_cast<uint32_t>(attention_mask_shape_[0]),
-        static_cast<uint32_t>(attention_mask_shape_[1]),
-        type_,
-        total_length,
-        attention_mask_resource.Get(),
-        attention_mask_next_resource.Get());
-    is_second_mask_update_ = true;
-  } else if (is_second_mask_update_) {
-    dml_update_mask_kernel_ = DmlUpdateMaskKernel(
-        model_.GetD3D12Device(),
-        model_.GetDmlExecutionContext(),
-        static_cast<uint32_t>(attention_mask_shape_[0]),
-        static_cast<uint32_t>(attention_mask_shape_[1]),
-        type_,
-        1,
-        attention_mask_resource.Get(),
-        attention_mask_next_resource.Get());
-    is_second_mask_update_ = false;
-  }
-  ComPtr<ID3D12Fence> fence;
-  uint64_t completion_value;
-  model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value);
-}
-#endif
-
-#if USE_CUDA
 void PositionInputs::RewindMask(size_t index) {
   if (sb_attention_mask_ && !is_first_mask_update_) {
+    throw std::runtime_error("PositionInputs::RewindMask - Static buffer is not supported for continuous decoding.");
+#if 0  // TODO: Fix implementation; cudaMemsetAsync with a value of 1 sets individual bytes to 1 rather than int32 elements to 1
     int past_length = static_cast<int>(index);
     int max_length = static_cast<int>(state_.params_->search.max_length);
     cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(),
@@ -451,8 +322,8 @@ void PositionInputs::RewindMask(size_t index) {
                     1,
                     (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * past_length,
                     model_.cuda_stream_);
+#endif
   }
 }
-#endif
 
 }  // namespace Generators
diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h
index 64a38779c..6c297cc85 100644
--- a/src/models/position_inputs.h
+++ b/src/models/position_inputs.h
@@ -1,12 +1,6 @@
 #pragma once
-
 #include "static_buffer.h"
 
-#if USE_DML
-#include "../dml/dml_update_mask_kernel.h"
-#include "../dml/dml_increment_values_kernel.h"
-#endif
-
 namespace Generators {
 
 struct PositionInputs {
@@ -39,18 +33,7 @@ struct PositionInputs {
   template <typename T>
   void UpdateAttentionMaskImpl(int total_length);
 
-#if USE_CUDA || USE_DML
-  void CopyNextPositionIDsToCurrent();
-#endif
-
-#if USE_DML
-  void UpdatePositionIDsImplDML();
-  void UpdateAttentionMaskImplDML(int total_length);
-#endif
-
-#if USE_CUDA
   void RewindMask(size_t index);
-#endif
 
   const Model& model_;
   State& state_;
@@ -77,13 +60,6 @@ struct PositionInputs {
 
   bool is_first_mask_update_{true};
   bool is_first_update_{true};
-
-#if USE_DML
-  std::optional<DmlUpdateMaskKernel> dml_update_mask_kernel_;
-  StaticBuffer* sb_attention_mask_next_{};
-  std::optional<DmlIncrementValuesKernel> dml_update_position_ids_kernel_;
-  bool is_second_mask_update_{};
-#endif
 };
 
 }  // namespace Generators
diff --git a/src/models/prompt_image_processor.cpp b/src/models/prompt_image_processor.cpp
index 249711a0b..33ed2cfa3 100644
--- a/src/models/prompt_image_processor.cpp
+++ b/src/models/prompt_image_processor.cpp
@@ -89,7 +89,7 @@ std::unique_ptr<OrtValue> ProcessPixelValues(ortc::Tensor<float>* pixel_values,
         allocator.GetInfo(),
         std::span<float>(const_cast<float*>(pixel_values->Data()), pixel_values->NumberOfElement()),
         pixel_values->Shape());
-    ConvertFp32ToFp16(allocator, *pixel_values_fp32, pixel_values_value, DeviceType::CPU, nullptr);
+    Cast(*pixel_values_fp32, pixel_values_value, *GetDeviceInterface(DeviceType::CPU), Ort::TypeToTensorType<Ort::Float16_t>);
   }
 
   return pixel_values_value;
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index 723b7aeda..ec6cf42bf 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -3,10 +3,6 @@
 #include "../generators.h"
 #include "whisper.h"
 #include <vector>
-#include "kernels.h"
-#if USE_CUDA
-#include "../cuda/cuda_common.h"
-#endif
 
 namespace Generators {
 
@@ -42,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
   }
 
   if (inputs.alignment_heads != nullptr) {
-#if USE_CUDA
+#if 0 // USE_CUDA
     auto alignment_heads_type_and_shape_info = inputs.alignment_heads->ort_tensor_->GetTensorTypeAndShapeInfo();
     auto alignment_heads_type = alignment_heads_type_and_shape_info->GetElementType();  // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
     auto alignment_heads_shape = alignment_heads_type_and_shape_info->GetShape();
@@ -101,7 +97,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
   }
 }
 
-#if USE_CUDA
+#if 0 // USE_CUDA
 template <typename T>
 void TransposeKCacheForDMMHA(T* dest_data,
                              T* temp_buffer,
@@ -147,7 +143,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
 
       const auto copy_data_size_all = src_shape_info->GetElementCount() * SizeOf(src_shape_info->GetElementType());
 
-#if USE_CUDA
+#if 0 // USE_CUDA
       const auto src_dims = src_shape_info->GetShape();
       const auto src_element_type = src_shape_info->GetElementType();
       const auto src_element_size = SizeOf(src_element_type);
@@ -187,7 +183,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
         auto dest_data = presents_[i]->GetTensorMutableRawData();
 
         switch (model_.device_type_) {
-#if USE_CUDA
+#if 0 // USE_CUDA
           case DeviceType::CUDA:
             if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
               // CUDA EP + FP16 precision == `DecoderMaskedMultiHeadAttention` op is used
@@ -228,7 +224,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
         }
       }
 
-#if USE_CUDA
+#if 0 // USE_CUDA
       if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 && model_.device_type_ == DeviceType::CUDA) {
         // Transpose cross attention K caches for `DecoderMaskedMultiHeadAttention`
 
@@ -331,7 +327,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
   }
 
   if (cache_indirection_) {
-#if USE_CUDA
+#if 0 // USE_CUDA
     auto beam_indices_gpu = gpu_span<int32_t>{beam_indices.Span()};
     if (beam_indices_gpu.empty()) {
       auto beam_indices_cpu = beam_indices.CpuSpan();
@@ -359,7 +355,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
   }
 
   if (output_cross_qk_.size() && alignment_heads_) {
-#if USE_CUDA
+#if 0 // USE_CUDA
     // Collect a GPU array of float* pointers from the vector of OrtValues to pass to the kernel
     auto output_cross_qk_ptrs = cross_qk_ptrs_gpu_.CpuSpan();
     assert(output_cross_qk_ptrs.size() == output_cross_qk_.size());
@@ -390,7 +386,7 @@ void Whisper_State::Initialize(DeviceSpan<int32_t>& next_tokens, int total_lengt
 
 void Whisper_State::Finalize() {
   if (output_cross_qk_.size() && alignment_heads_) {
-#if USE_CUDA
+#if 0 // USE_CUDA
     int decoded_length = *(past_sequence_length_->GetTensorMutableData<int32_t>()) + 1;
     auto output_cross_qk_dims = output_cross_qk_[0]->GetTensorTypeAndShapeInfo()->GetShape();
 
diff --git a/src/smartptrs.h b/src/smartptrs.h
index fbaf12ed2..a25c8c133 100644
--- a/src/smartptrs.h
+++ b/src/smartptrs.h
@@ -85,7 +85,8 @@ struct DeviceSpan {
 
 struct DeviceInterface {
   virtual ~DeviceInterface() {}
-  virtual void InitAllocator(Ort::Allocator& allocator) = 0;
+  virtual void InitOrt(const OrtApi& api, Ort::Allocator& allocator) = 0;
+  virtual Ort::Allocator& GetAllocator() = 0;
 
   template <typename T>
   DeviceSpan<T> Allocate(size_t count) { return DeviceSpan<T>(AllocateBase(sizeof(T) * count)); }
@@ -101,7 +102,18 @@ struct DeviceInterface {
 
   virtual void Synchronize() = 0;  // Synchronize the device, typically used for timing or debugging
 
-  virtual cudaStream_t GetCudaStream() {
+  virtual bool Cast(OrtValue& /*input*/, OrtValue& /*output*/) { return false; }
+
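+  // Optional device-specific kernels; the base implementations assert because not every provider supports them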
+  virtual void UpdatePositionIds(void* /*position_ids*/, int /*batch_beam_size*/, int /*total_length*/, int /*new_kv_length*/, ONNXTensorElementDataType /*type*/) { assert(false); }
+  virtual void UpdateAttentionMask(void* /*mask_data*/, const void* /*old_data*/, int /*batch_beam_size*/, int /*new_kv_length*/, int /*total_length*/, int /*max_length*/, bool /*update_only*/, ONNXTensorElementDataType /*type*/) { assert(false); }
+
+  virtual void LaunchHandleEOSArray(float* /*batch_logits*/, int /*batch_beam_size*/, int /*vocab_size*/, const int32_t* /*eos_token_ids*/, int /*eos_token_ids_count*/) { assert(false); }
+  virtual void UpdateCacheIndirectionKernelLauncher(int32_t* /*tgt_indir_cache*/, const int32_t* /*src_indir_cache*/, const int32_t* /*beam_ids*/, int /*batch_size*/, int /*beam_width*/, int /*input_seq_length*/, int /*max_seq_length*/, int /*current_length*/) { assert(false); }
+  virtual void ReorderPastStatesKernelLauncher(void* /*out_buffer*/, const void* /*in_buffer*/, int /*batch_size*/, int /*num_heads*/, int /*max_length*/, int /*head_size*/, int /*chunk_size*/) { assert(false); }
+  virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false);  }
+  virtual void LaunchFinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const float* /*cross_qk_buffer_data*/, float* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); }
+
+  virtual void* GetCudaStream() {
     assert(false);
     return nullptr;
   }  // Temporary until we fully factor out providers
diff --git a/test/model_tests.cpp b/test/model_tests.cpp
index 321d1ac46..9482a0998 100644
--- a/test/model_tests.cpp
+++ b/test/model_tests.cpp
@@ -18,12 +18,6 @@
 #define PHI2_PATH MODEL_PATH "phi-2/int4/cpu"
 #endif
 #endif
-#if USE_DML
-#include <DirectML.h>
-#include <wrl.h>
-#include <d3d12.h>
-#include <dxgi1_6.h>
-#endif
 
 // To generate this file:
 // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
@@ -35,7 +29,7 @@ static const std::pair<const char*, const char*> c_tiny_gpt2_model_paths[] = {
 
 #if USE_DML
 TEST(ModelTests, DMLAdapterSelection) {
-#if TEST_PHI2
+#if 0 // TEST_PHI2 TODO: Remove this? Can't access the device directly anymore.
   auto model = Generators::CreateModel(Generators::GetOrtEnv(), PHI2_PATH);
   auto d3d12Device = model->GetD3D12Device();
 

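Note on the DeviceInterface hunk above: provider-specific hooks (Cast, the position/attention/kernel launchers, GetCudaStream) now live on the base interface with no-op or asserting defaults, and the stream handle is type-erased to void* so the header no longer needs cudaStream_t. A minimal sketch of that pattern, using hypothetical names (DeviceInterfaceSketch, CudaBackendSketch) rather than the real headers:

    #include <cassert>

    struct DeviceInterfaceSketch {
      virtual ~DeviceInterfaceSketch() = default;

      // Optional capability: returning false lets the caller fall back to a CPU path.
      virtual bool Cast(void* /*input*/, void* /*output*/) { return false; }

      // Type-erased stream; only a CUDA backend has something real to hand back.
      virtual void* GetCudaStream() {
        assert(false);  // a non-CUDA device should never be asked for a stream
        return nullptr;
      }
    };

    struct CudaBackendSketch final : DeviceInterfaceSketch {
      void* stream_{};  // would hold a cudaStream_t in a real CUDA backend
      void* GetCudaStream() override { return stream_; }
    };
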
From bdbb09c80a9aa4a3e37461d4b1b05819a89072da Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 15 Jan 2025 16:04:50 -0800
Subject: [PATCH 04/31] Fix merge build issues

---
 src/models/input_ids.cpp       |  4 +--
 src/models/input_ids.h         |  4 +--
 src/models/model.h             |  4 ---
 src/models/position_inputs.cpp | 62 ++++++++++++----------------------
 src/ort_genai_c.cpp            |  7 ++++
 5 files changed, 32 insertions(+), 49 deletions(-)

diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 253fbc2bf..133be44f5 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -44,7 +44,7 @@ void DefaultInputIDs::Add() {
   }
 }
 
-void InputIDs::Update(DeviceSpan<int32_t> new_tokens) {
+void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   auto new_tokens_cpu = new_tokens.CopyDeviceToCpu();
 
   const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids, int32_t pad_token_id) {
@@ -130,7 +130,7 @@ void WindowedInputIDs::Add() {
   state_.input_names_.push_back(name_);
 }
 
-void WindowedInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
+void WindowedInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   if (window_index_ == 0) {
     num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_;
 
diff --git a/src/models/input_ids.h b/src/models/input_ids.h
index b81718022..cc9ab5640 100644
--- a/src/models/input_ids.h
+++ b/src/models/input_ids.h
@@ -8,7 +8,7 @@ struct InputIDs {
   virtual ~InputIDs() = default;
   virtual void Add() = 0;
   virtual std::array<int64_t, 2> GetShape() const = 0;
-  virtual void Update(DeviceSpan<int32_t>& next_tokens) = 0;
+  virtual void Update(DeviceSpan<int32_t> next_tokens) = 0;
 };
 
 struct DefaultInputIDs : InputIDs {
@@ -60,7 +60,7 @@ struct WindowedInputIDs : public InputIDs {
   WindowedInputIDs& operator=(const WindowedInputIDs&) = delete;
 
   void Add() override;
-  void Update(DeviceSpan<int32_t>& next_tokens) override;
+  void Update(DeviceSpan<int32_t> next_tokens) override;
   std::array<int64_t, 2> GetShape() const override { return shape_; }
 
  private:
diff --git a/src/models/model.h b/src/models/model.h
index 034c0defb..9635e79d2 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -155,10 +155,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
                                       bool is_primary_session_options,
                                       bool disable_graph_capture);
 
-#endif
-#if USE_DML || USE_WEBGPU
-  std::unique_ptr<OrtMemoryInfo> memory_info_device_;
-#endif
   std::shared_ptr<CapturedGraphPool> captured_graph_pool_;
   std::map<std::string, std::unique_ptr<OrtSessionOptions>> pipeline_session_options_;
 };
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 6e02d92d7..76372bd54 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -113,7 +113,7 @@ void DefaultPositionInputs::AddPositionIDs() {
   state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str());
 }
 
-void PositionInputs::CreateNextPositionIDsTensor() {
+void DefaultPositionInputs::CreateNextPositionIDsTensor() {
   if (!sb_position_ids_) {
     if (position_ids_shape_[1] == 1 && position_ids_next_) {
       position_ids_ = std::move(position_ids_next_);
@@ -142,20 +142,11 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
     state_.inputs_[posid_input_index_] = position_ids_.get();
   }
 
-  switch (model_.device_type_) {
-    case DeviceType::WEBGPU:
-    case DeviceType::DML:
-    case DeviceType::CPU: {
-      type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
-                                              : UpdatePositionIDsImpl<int64_t>(total_length, new_kv_length);
-      break;
-    }
-    case DeviceType::CUDA: {
-      model_.p_device_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
-      break;
-    }
-    default:
-      throw std::runtime_error("PositionIDs::Update - Unsupported device type");
+  if (model_.device_type_ == DeviceType::CUDA)
+    model_.p_device_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
+  else {
+    type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
+                                            : UpdatePositionIDsImpl<int64_t>(total_length, new_kv_length);
   }
 }
 
@@ -181,31 +172,20 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
   CreateNextAttentionMaskTensor(total_length);
   state_.inputs_[mask_input_index_] = attention_mask_.get();
 
-  switch (model_.device_type_) {
-    case DeviceType::WEBGPU:
-    case DeviceType::DML:
-    case DeviceType::QNN:
-    case DeviceType::CPU: {
-      type_ == Ort::TypeToTensorType<int32_t> ? UpdateAttentionMaskImpl<int32_t>(total_length)
-                                              : UpdateAttentionMaskImpl<int64_t>(total_length);
-      break;
-    }
-    case DeviceType::CUDA: {
-        int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
-      bool update_only = sb_attention_mask_ && !is_first_mask_update_;
-      model_.p_device_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
-                                            attention_mask_->GetTensorRawData(),
-                                            static_cast<int>(attention_mask_shape_[0]),
-                                            new_kv_length,
-                                            total_length,
-                                            max_length,
-                                            update_only,
-                                            type_);
-      break;
-    }
-
-    default:
-      throw std::runtime_error("DefaultPositionInputs::Update - Unsupported device type");
+  if (model_.device_type_ == DeviceType::CUDA) {
+    int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
+    bool update_only = sb_attention_mask_ && !is_first_mask_update_;
+    model_.p_device_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
+                                          attention_mask_->GetTensorRawData(),
+                                          static_cast<int>(attention_mask_shape_[0]),
+                                          new_kv_length,
+                                          total_length,
+                                          max_length,
+                                          update_only,
+                                          type_);
+  } else {
+    type_ == Ort::TypeToTensorType<int32_t> ? UpdateAttentionMaskImpl<int32_t>(total_length)
+                                            : UpdateAttentionMaskImpl<int64_t>(total_length);
   }
 
   attention_mask_ = std::move(attention_mask_next_);
@@ -307,7 +287,7 @@ void DefaultPositionInputs::UpdateAttentionMaskImpl(int total_length) {
   }
 }
 
-void PositionInputs::RewindMask(size_t index) {
+void DefaultPositionInputs::RewindMask(size_t index) {
   if (sb_attention_mask_ && !is_first_mask_update_) {
     throw std::runtime_error("PositionInputs::RewindMask - Static buffer is not supported for continuous decoding.");
 #if 0  // TODO: Fix implementation, cudaMemsetAsync of 1 is setting bytes of 1 vs int32's of 1
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 0da495420..82fac1c7d 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -339,6 +339,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
   auto copy_span = Generators::ByteWrapTensor(*Generators::GetDeviceInterface(Generators::DeviceType::CPU), *ortvalue_clone);
   copy_span.CopyFrom(output_span);
 
+  auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
+  tensor->external_owner_ = tensor;
+  *out = reinterpret_cast<OgaTensor*>(tensor.get());
+  return nullptr;
+  OGA_CATCH
+}
+
 OgaResult* OGA_API_CALL OgaGenerator_GetLogits(OgaGenerator* oga_generator, OgaTensor** out) {
   OGA_TRY
   auto generator = reinterpret_cast<Generators::Generator*>(oga_generator);

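The input_ids hunks above also unify the Update signature: the base interface now takes DeviceSpan<int32_t> by value, so a derived Update that still took a reference would declare a new function rather than override. A reduced sketch of why the override keyword surfaces the mismatch at compile time (std::span stands in for DeviceSpan here):

    #include <cstdint>
    #include <span>

    struct InputIdsSketch {
      virtual ~InputIdsSketch() = default;
      virtual void Update(std::span<const int32_t> next_tokens) = 0;  // by value
    };

    struct WindowedSketch : InputIdsSketch {
      // void Update(std::span<const int32_t>&) override;   // error: does not override
      void Update(std::span<const int32_t>) override {}     // matches the base exactly
    };
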
From 66321dd8fa84457e7839934156fabbbf13455552 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 15 Jan 2025 23:08:47 -0800
Subject: [PATCH 05/31] Formatting

---
 src/dml/interface.cpp    |  3 +--
 src/dml/interface.h      |  2 +-
 src/generators.h         |  5 +++--
 src/models/input_ids.cpp |  2 +-
 src/models/kv_cache.cpp  |  2 +-
 src/models/logits.cpp    |  2 --
 src/models/model.cpp     |  2 +-
 src/models/whisper.cpp   | 16 ++++++++--------
 src/smartptrs.h          |  2 +-
 9 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp
index 459a71054..58a3457e2 100644
--- a/src/dml/interface.cpp
+++ b/src/dml/interface.cpp
@@ -17,7 +17,7 @@
 std::string CurrentModulePath();
 
 namespace Generators {
-namespace Dml { // If this was in a shared library it wouldn't need to be in its own namespace
+namespace Dml {  // If this was in a shared library it wouldn't need to be in its own namespace
 
 Ort::Allocator* ort_allocator_{};
 const char* label_dml = "dml";
@@ -95,7 +95,6 @@ struct GpuMemory final : DeviceBuffer {
 };
 
 struct DmlInterfaceImpl : DeviceInterface {
-
   DmlInterfaceImpl(LUID* p_device_luid) {
     Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
     if (!dml_api_) {
diff --git a/src/dml/interface.h b/src/dml/interface.h
index c70faf721..6f745ce13 100644
--- a/src/dml/interface.h
+++ b/src/dml/interface.h
@@ -8,4 +8,4 @@ void SetDmlProvider(OrtSessionOptions& options);
 
 DeviceInterface* GetDmlInterface();
 
-}
\ No newline at end of file
+}  // namespace Generators
\ No newline at end of file
diff --git a/src/generators.h b/src/generators.h
index c7b2875d0..9efe85d2d 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -50,10 +50,10 @@ DeviceSpan<T> WrapTensor(DeviceInterface& device, OrtValue& value) {
 
 DeviceSpan<uint8_t> ByteWrapTensor(DeviceInterface& device, OrtValue& value);
 
-template<typename T>
+template <typename T>
 struct OrtTensor {
   OrtTensor(std::unique_ptr<OrtValue> ort_value, DeviceInterface& device)
-    : ort_value_{std::move(ort_value)}, device_span_{WrapTensor<T>(device, *ort_value_)} {}
+      : ort_value_{std::move(ort_value)}, device_span_{WrapTensor<T>(device, *ort_value_)} {}
 
   operator OrtValue*() { return ort_value_.get(); }
 
@@ -151,6 +151,7 @@ struct OrtGlobals {
 
   std::unique_ptr<OrtEnv> env_;
   std::unique_ptr<Ort::Allocator> allocator_device_[static_cast<int>(DeviceType::MAX)];
+
  private:
   OrtGlobals(const OrtGlobals&) = delete;
   void operator=(const OrtGlobals&) = delete;
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 133be44f5..38bdba445 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -81,7 +81,7 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   }
 
   // Update input_ids with next tokens
-  auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_); 
+  auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_);
 
   // For beam search
   if (is_prompt_ && state_.params_->search.num_beams > 1) {
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index a8eff8005..51210ae4e 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -341,7 +341,7 @@ void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device
     auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
     past.CopyFrom(present);
   }
-  
+
   pasts_[index] = std::move(past_value);
 }
 
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index cd28c5b31..4a7ac7f6a 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -12,7 +12,6 @@ Logits::Logits(State& state)
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
   output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
 
-
   if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
     auto& cpu_ids = model_.config_->model.eos_token_ids;
     cuda_eos_token_ids_ = state_.params_->p_device->Allocate<int32_t>(cpu_ids.size());
@@ -52,7 +51,6 @@ DeviceSpan<float> Logits::Get() {
       // Find the first non pad token from the end
       size_t token_index = input_sequence_lengths[batch_index] - 1;
       for (int beam_index = 0; beam_index < num_beams; beam_index++) {
-
         auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
         auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size);
         target.CopyFrom(source);
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 29869f7e7..fd552b7c8 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -564,7 +564,7 @@ void Cast(OrtValue& input, std::unique_ptr<OrtValue>& output, DeviceInterface& d
   if (!output)
     output = OrtValue::CreateTensor(device.GetAllocator(), shape, output_type);
 
-  if(!device.Cast(input, *output))
+  if (!device.Cast(input, *output))
     GetDeviceInterface(DeviceType::CPU)->Cast(input, *output);
 }
 
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index ec6cf42bf..6c46f218d 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -38,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
   }
 
   if (inputs.alignment_heads != nullptr) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
     auto alignment_heads_type_and_shape_info = inputs.alignment_heads->ort_tensor_->GetTensorTypeAndShapeInfo();
     auto alignment_heads_type = alignment_heads_type_and_shape_info->GetElementType();  // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
     auto alignment_heads_shape = alignment_heads_type_and_shape_info->GetShape();
@@ -97,7 +97,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
   }
 }
 
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
 template <typename T>
 void TransposeKCacheForDMMHA(T* dest_data,
                              T* temp_buffer,
@@ -143,7 +143,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
 
       const auto copy_data_size_all = src_shape_info->GetElementCount() * SizeOf(src_shape_info->GetElementType());
 
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
       const auto src_dims = src_shape_info->GetShape();
       const auto src_element_type = src_shape_info->GetElementType();
       const auto src_element_size = SizeOf(src_element_type);
@@ -183,7 +183,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
         auto dest_data = presents_[i]->GetTensorMutableRawData();
 
         switch (model_.device_type_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
           case DeviceType::CUDA:
             if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
               // CUDA EP + FP16 precision == `DecoderMaskedMultiHeadAttention` op is used
@@ -224,7 +224,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
         }
       }
 
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
       if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 && model_.device_type_ == DeviceType::CUDA) {
         // Transpose cross attention K caches for `DecoderMaskedMultiHeadAttention`
 
@@ -327,7 +327,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
   }
 
   if (cache_indirection_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
     auto beam_indices_gpu = gpu_span<int32_t>{beam_indices.Span()};
     if (beam_indices_gpu.empty()) {
       auto beam_indices_cpu = beam_indices.CpuSpan();
@@ -355,7 +355,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
   }
 
   if (output_cross_qk_.size() && alignment_heads_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
     // Collect a GPU array of float* pointers from the vector of OrtValues to pass to the kernel
     auto output_cross_qk_ptrs = cross_qk_ptrs_gpu_.CpuSpan();
     assert(output_cross_qk_ptrs.size() == output_cross_qk_.size());
@@ -386,7 +386,7 @@ void Whisper_State::Initialize(DeviceSpan<int32_t>& next_tokens, int total_lengt
 
 void Whisper_State::Finalize() {
   if (output_cross_qk_.size() && alignment_heads_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
     int decoded_length = *(past_sequence_length_->GetTensorMutableData<int32_t>()) + 1;
     auto output_cross_qk_dims = output_cross_qk_[0]->GetTensorTypeAndShapeInfo()->GetShape();
 
diff --git a/src/smartptrs.h b/src/smartptrs.h
index a25c8c133..1e2a79f47 100644
--- a/src/smartptrs.h
+++ b/src/smartptrs.h
@@ -110,7 +110,7 @@ struct DeviceInterface {
   virtual void LaunchHandleEOSArray(float* /*batch_logits*/, int /*batch_beam_size*/, int /*vocab_size*/, const int32_t* /*eos_token_ids*/, int /*eos_token_ids_count*/) { assert(false); }
   virtual void UpdateCacheIndirectionKernelLauncher(int32_t* /*tgt_indir_cache*/, const int32_t* /*src_indir_cache*/, const int32_t* /*beam_ids*/, int /*batch_size*/, int /*beam_width*/, int /*input_seq_length*/, int /*max_seq_length*/, int /*current_length*/) { assert(false); }
   virtual void ReorderPastStatesKernelLauncher(void* /*out_buffer*/, const void* /*in_buffer*/, int /*batch_size*/, int /*num_heads*/, int /*max_length*/, int /*head_size*/, int /*chunk_size*/) { assert(false); }
-  virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false);  }
+  virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false); }
   virtual void LaunchFinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const float* /*cross_qk_buffer_data*/, float* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); }
 
   virtual void* GetCudaStream() {

From 0bc39a5b7e4fb9df026b184647e20d8248810c9f Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Thu, 16 Jan 2025 20:15:19 -0800
Subject: [PATCH 06/31] Build fixes

---
 src/dml/interface.h    | 4 ++++
 src/generators.cpp     | 3 ++-
 src/models/input_ids.h | 2 +-
 src/models/model.cpp   | 5 +----
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/dml/interface.h b/src/dml/interface.h
index 6f745ce13..5c43aa11b 100644
--- a/src/dml/interface.h
+++ b/src/dml/interface.h
@@ -1,6 +1,10 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifndef _WIN32
+using LUID = void;
+#endif
+
 namespace Generators {
 
 void InitDmlInterface(LUID* p_device_luid);
diff --git a/src/generators.cpp b/src/generators.cpp
index c5ad02e45..e7a781642 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -164,8 +164,9 @@ std::string to_string(DeviceType device_type) {
       return "WebGpu";
     case DeviceType::QNN:
       return "QnnWithSharedMemory";
+    default:
+      throw std::runtime_error("Unknown device type");
   }
-  throw std::runtime_error("Unknown device type");
 }
 
 DeviceInterface* GetDeviceInterface(DeviceType type) {
diff --git a/src/models/input_ids.h b/src/models/input_ids.h
index cc9ab5640..3388f0cd8 100644
--- a/src/models/input_ids.h
+++ b/src/models/input_ids.h
@@ -21,7 +21,7 @@ struct DefaultInputIDs : InputIDs {
   void Add() override;
   // Resize input_ids based on size of next_tokens.
   // Update value with next_tokens.
-  void Update(DeviceSpan<int32_t> next_tokens);
+  void Update(DeviceSpan<int32_t> next_tokens) override;
 
   std::array<int64_t, 2> GetShape() const override { return shape_; }
   const char* name_;
diff --git a/src/models/model.cpp b/src/models/model.cpp
index fd552b7c8..4009481cc 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -15,9 +15,7 @@
 #include "whisper.h"
 #include "multi_modal_vision_model.h"
 #include "decoder_only_pipeline.h"
-#if USE_DML
 #include "../dml/interface.h"
-#endif
 
 namespace Generators {
 
@@ -425,7 +423,6 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
       Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size()));
       session_options.AppendExecutionProvider_ROCM(ort_provider_options);
-#if USE_DML
     } else if (provider_options.name == "dml") {
       if (!GetDmlInterface()) {
         LUID device_luid{};
@@ -452,7 +449,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
       if (is_primary_session_options)
         device_type_ = DeviceType::DML;  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
-#endif
+
     } else if (provider_options.name == "qnn") {
       session_options.AddConfigEntry("ep.share_ep_contexts", "1");
       std::unordered_map<std::string, std::string> opts;

From 5244049dc15dbbec7b65a85ca7abcd59501d26ca Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Thu, 16 Jan 2025 21:21:24 -0800
Subject: [PATCH 07/31] Build fix

---
 src/dml/interface.h  | 7 +++++--
 src/models/model.cpp | 3 ++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/dml/interface.h b/src/dml/interface.h
index 5c43aa11b..e3f2afdae 100644
--- a/src/dml/interface.h
+++ b/src/dml/interface.h
@@ -2,7 +2,10 @@
 // Licensed under the MIT License.
 
 #ifndef _WIN32
-using LUID = void;
+typedef struct _LUID {
+  DWORD LowPart;
+  LONG HighPart;
+} LUID, *PLUID;
 #endif
 
 namespace Generators {
@@ -12,4 +15,4 @@ void SetDmlProvider(OrtSessionOptions& options);
 
 DeviceInterface* GetDmlInterface();
 
-}  // namespace Generators
\ No newline at end of file
+}  // namespace Generators
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 4009481cc..71c876fba 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -423,6 +423,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
       Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size()));
       session_options.AppendExecutionProvider_ROCM(ort_provider_options);
+#if USE_DML
     } else if (provider_options.name == "dml") {
       if (!GetDmlInterface()) {
         LUID device_luid{};
@@ -449,7 +450,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
       if (is_primary_session_options)
         device_type_ = DeviceType::DML;  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
-
+#endif
     } else if (provider_options.name == "qnn") {
       session_options.AddConfigEntry("ep.share_ep_contexts", "1");
       std::unordered_map<std::string, std::string> opts;

From 49b51ef1431232cb0f4b0ac5f0f05f48a32026cf Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Thu, 16 Jan 2025 21:33:46 -0800
Subject: [PATCH 08/31] Build fix

---
 src/dml/interface.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/dml/interface.h b/src/dml/interface.h
index e3f2afdae..9ef2c5785 100644
--- a/src/dml/interface.h
+++ b/src/dml/interface.h
@@ -3,8 +3,8 @@
 
 #ifndef _WIN32
 typedef struct _LUID {
-  DWORD LowPart;
-  LONG HighPart;
+  uint32_t LowPart;
+  int32_t HighPart;
 } LUID, *PLUID;
 #endif
 

From d3db2f6673fc2283d383e452410a98eb756d9465 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Tue, 21 Jan 2025 14:06:51 -0800
Subject: [PATCH 09/31] Fix input_ids issue from merge

---
 src/models/input_ids.cpp | 9 ++-------
 src/models/input_ids.h   | 3 ---
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 38bdba445..a3a988608 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -7,7 +7,7 @@ namespace Generators {
 DefaultInputIDs::DefaultInputIDs(State& state)
     : state_{state} {
   name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
-  shape_ = {state_.params_->search.batch_size, 0};
+  shape_ = {state_.params_->BatchBeamSize(), 0};
   type_ = model_.session_info_->GetInputDataType(name_);
 
   if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
@@ -71,12 +71,7 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
 
   if (static_cast<size_t>(shape_[1]) != sequence_length) {
     shape_[1] = sequence_length;
-    if (!sb_input_ids_) {
-      value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
-    } else {
-      value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>);
-    }
-
+    value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
     state_.inputs_[input_index_] = value_.get();
   }
 
diff --git a/src/models/input_ids.h b/src/models/input_ids.h
index 3388f0cd8..fd4b159e6 100644
--- a/src/models/input_ids.h
+++ b/src/models/input_ids.h
@@ -40,9 +40,6 @@ struct DefaultInputIDs : InputIDs {
   std::unique_ptr<OrtValue> value_;
   std::unique_ptr<OrtValue> cast_value_;
 
-  // Used for decoding runs with cuda graphs.
-  StaticBuffer* sb_input_ids_{};
-
   std::unique_ptr<OrtValue> current_sequence_length_;
   std::unique_ptr<OrtValue> past_sequence_length_;
 };

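One detail in the input_ids fix above: the leading dimension of input_ids is now sized with BatchBeamSize() rather than search.batch_size, since with beam search each batch entry fans out into num_beams rows (assumption: BatchBeamSize() is batch_size * num_beams, as the name suggests). A tiny sketch of that shape calculation:

    // Hypothetical shape helper; only the batch*beams product mirrors the change above.
    struct SearchParamsSketch {
      int batch_size{1};
      int num_beams{1};
      int BatchBeamSize() const { return batch_size * num_beams; }
    };

    // input_ids shape becomes {params.BatchBeamSize(), sequence_length}
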
From 133d5a09698a7b2c665035abfc2f227aadde3c5d Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Tue, 21 Jan 2025 16:56:59 -0800
Subject: [PATCH 10/31] Fix C# unit tests

---
 test/csharp/TestOnnxRuntimeGenAIAPI.cs | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index ddced9e42..1a3256625 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -8,10 +8,11 @@
 using System.Runtime.CompilerServices;
 using Xunit;
 using Xunit.Abstractions;
+using Microsoft.ML.OnnxRuntimeGenAI;
 
 namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
 {
-    public class OnnxRuntimeGenAITests
+    public class OnnxRuntimeGenAITests : IDisposable
     {
         private readonly ITestOutputHelper output;
 
@@ -86,12 +87,17 @@ private static string GetDirectoryInTreeThatContains(string currentDirectory, st
         });
 
         private static string _adaptersPath => _lazyAdaptersPath.Value;
-
+        private OgaHandle ogaHandle;
 
         public OnnxRuntimeGenAITests(ITestOutputHelper o)
         {
+            this.ogaHandle = new OgaHandle();
             this.output = o;
         }
+        public void Dispose()
+        {
+            ogaHandle?.Dispose();
+        }
 
         private class IgnoreOnModelAbsenceFact : FactAttribute
         {

From b079b7407b64557b3c613eb0cd61ccd9dfc7e21e Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Tue, 21 Jan 2025 20:51:21 -0800
Subject: [PATCH 11/31] Try again to fix C# test

---
 test/csharp/TestOnnxRuntimeGenAIAPI.cs | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index 1a3256625..34baf64e7 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -12,7 +12,7 @@
 
 namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
 {
-    public class OnnxRuntimeGenAITests : IDisposable
+    public class OnnxRuntimeGenAITests
     {
         private readonly ITestOutputHelper output;
 
@@ -87,17 +87,14 @@ private static string GetDirectoryInTreeThatContains(string currentDirectory, st
         });
 
         private static string _adaptersPath => _lazyAdaptersPath.Value;
-        private OgaHandle ogaHandle;
+        private static OgaHandle ogaHandle;
 
         public OnnxRuntimeGenAITests(ITestOutputHelper o)
         {
-            this.ogaHandle = new OgaHandle();
+            ogaHandle = new OgaHandle();
+            AppDomain.CurrentDomain.ProcessExit += (sender, e) => ogaHandle.Dispose();
             this.output = o;
         }
-        public void Dispose()
-        {
-            ogaHandle?.Dispose();
-        }
 
         private class IgnoreOnModelAbsenceFact : FactAttribute
         {

From afecf1dd48b3d17f02910da0c8b7d5c735574a6c Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 22 Jan 2025 16:19:11 -0800
Subject: [PATCH 12/31] Test theory

---
 src/generators.cpp   | 2 ++
 src/models/model.cpp | 5 ++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 6fe0a8ae4..57fa43f67 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -50,6 +50,8 @@ OrtGlobals::OrtGlobals()
   auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1);
   Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
   env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config);
+
+  std::cerr << "OrtGlobals::OrtGlobals completed" << std::endl;
 }
 
 // Ensure Shutdown() has been called before process exit
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 221786307..ea2626b5a 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -225,9 +225,12 @@ Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
   auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
   if (!device) {
     static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU", "Cuda", "DML", "WebGPU Buffer"};
+    std::cerr << "GetDeviceAllocator: Creating device allocator for " << device_type_names[static_cast<int>(type)] << std::endl;
+
     auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     device = Ort::Allocator::Create(session, *memory_info);
-    GetDeviceInterface(type)->InitOrt(*Ort::api, *device);
+    GetDeviceInterface(type)->InitOrt(*Ort::api, *device); // Necessary for any shared library providers so they can access Ort::api
+    std::cerr << "GetDeviceAllocator: Device created" << std::endl;
   }
   return device.get();
 }

From 1734f5c18b2cda02411b9e392431340cfe46d299 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 03:36:49 -0800
Subject: [PATCH 13/31] Test instrumenting

---
 test/csharp/TestOnnxRuntimeGenAIAPI.cs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index e9f825ddf..58ad56fcf 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -91,6 +91,10 @@ private static string GetDirectoryInTreeThatContains(string currentDirectory, st
 
         public OnnxRuntimeGenAITests(ITestOutputHelper o)
         {
+            Console.WriteLine("**** Running OnnxRuntimeGenAITests constructor");
+            // Initialize GenAI and register a handler to dispose it on process exit
+            ogaHandle = new OgaHandle();
+            AppDomain.CurrentDomain.ProcessExit += (sender, e) => ogaHandle.Dispose();
             this.output = o;
         }
 
@@ -575,6 +579,8 @@ public IgnoreOnAdaptersAbsentFact()
         [IgnoreOnAdaptersAbsentFact(DisplayName = "TestAdapters")]
         public void TestAdapters()
         {
+            Console.WriteLine("**** Running TestAdapters");
+
             string modelPath = _adaptersPath;
             string adapterPath = Path.Combine(modelPath, "adapters.onnx_adapter");
 

From 2bc83ebb7c3f75ca1ea725678f6e69dafd84f406 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 13:46:49 -0800
Subject: [PATCH 14/31] Crash investigation

---
 test/csharp/TestOnnxRuntimeGenAIAPI.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index 58ad56fcf..747675a1a 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -8,7 +8,6 @@
 using System.Runtime.CompilerServices;
 using Xunit;
 using Xunit.Abstractions;
-using Microsoft.ML.OnnxRuntimeGenAI;
 
 namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
 {
@@ -96,6 +95,7 @@ public OnnxRuntimeGenAITests(ITestOutputHelper o)
             ogaHandle = new OgaHandle();
             AppDomain.CurrentDomain.ProcessExit += (sender, e) => ogaHandle.Dispose();
             this.output = o;
+            Console.WriteLine("**** OnnxRuntimeGenAI constructor completed");
         }
 
         private class IgnoreOnModelAbsenceFact : FactAttribute

From fd788d76aff72754c8632538592b1257d5cc0074 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 13:48:33 -0800
Subject: [PATCH 15/31] Extra debug logging

---
 src/generators.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/generators.cpp b/src/generators.cpp
index 57fa43f67..c789e5bcc 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -37,6 +37,11 @@ void ThrowErrorIfSessionTerminated(bool is_session_terminated) {
 
 namespace Generators {
 
+static bool _1 = []() {
+  std::cerr << "OrtGlobals::OrtGlobals started" << std::endl;
+  return false;
+}();
+
 static bool _ = (Ort::InitApi(), false);
 
 static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
@@ -47,6 +52,9 @@ static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
 
 OrtGlobals::OrtGlobals()
     : env_{OrtEnv::Create(GetDefaultOrtLoggingLevel())} {
+  std::cerr << "OrtGlobals::OrtGlobals started" << std::endl;
+
+
   auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1);
   Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
   env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config);

From 0303592f461721963c98c11395e624c6a5e8e07f Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 18:17:04 -0800
Subject: [PATCH 16/31] Undefined behavior fix in startup

---
 src/cpu/interface.cpp |  8 +++++---
 src/generators.cpp    | 11 ++---------
 src/models/model.cpp  | 10 ++++++----
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/src/cpu/interface.cpp b/src/cpu/interface.cpp
index c6705fa3b..0b34b80c4 100644
--- a/src/cpu/interface.cpp
+++ b/src/cpu/interface.cpp
@@ -44,7 +44,6 @@ struct CpuMemory final : DeviceBuffer {
 
 struct CpuInterface : DeviceInterface {
   CpuInterface() {
-    InitOrt(*Ort::api, Ort::Allocator::GetWithDefaultOptions());
   }
 
   void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
@@ -101,8 +100,11 @@ struct CpuInterface : DeviceInterface {
   std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
 
   void Synchronize() override {}  // Nothing to do as CPU is always in sync with itself
-} g_cpu;
+};
 
-DeviceInterface* GetCpuInterface() { return &g_cpu; }
+DeviceInterface* GetCpuInterface() {
+  static std::unique_ptr<CpuInterface> g_cpu = std::make_unique<CpuInterface>();
+  return g_cpu.get();
+}
 
 }  // namespace Generators
diff --git a/src/generators.cpp b/src/generators.cpp
index c789e5bcc..b003ac331 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -37,11 +37,6 @@ void ThrowErrorIfSessionTerminated(bool is_session_terminated) {
 
 namespace Generators {
 
-static bool _1 = []() {
-  std::cerr << "OrtGlobals::OrtGlobals started" << std::endl;
-  return false;
-}();
-
 static bool _ = (Ort::InitApi(), false);
 
 static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
@@ -52,14 +47,12 @@ static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
 
 OrtGlobals::OrtGlobals()
     : env_{OrtEnv::Create(GetDefaultOrtLoggingLevel())} {
-  std::cerr << "OrtGlobals::OrtGlobals started" << std::endl;
-
-
   auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1);
   Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
   env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config);
 
-  std::cerr << "OrtGlobals::OrtGlobals completed" << std::endl;
+  // Init the CPU device (special case because it always exists, and its allocator is special)
+  GetDeviceInterface(DeviceType::CPU)->InitOrt(*Ort::api, allocator_cpu);
 }
 
 // Ensure Shutdown() has been called before process exit
diff --git a/src/models/model.cpp b/src/models/model.cpp
index ea2626b5a..b328bed11 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -222,15 +222,17 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const {
 // has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the
 // arena already being destroyed.
 Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
+  // CPU Allocator is a special case, so don't try to create it
+  if (type == DeviceType::CPU)
+    return &GetDeviceInterface(DeviceType::CPU)->GetAllocator();
+
   auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
   if (!device) {
-    static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU", "Cuda", "DML", "WebGPU Buffer"};
-    std::cerr << "GetDeviceAllocator: Creating device allocator for " << device_type_names[static_cast<int>(type)] << std::endl;
+    static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU Buffer"};
 
     auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     device = Ort::Allocator::Create(session, *memory_info);
-    GetDeviceInterface(type)->InitOrt(*Ort::api, *device); // Necessary for any shared library providers so they can access Ort::api
-    std::cerr << "GetDeviceAllocator: Device created" << std::endl;
+    GetDeviceInterface(type)->InitOrt(*Ort::api, *device);  // Necessary for any shared library providers so they can access Ort::api
   }
   return device.get();
 }

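The CpuInterface change above is the undefined-behavior fix named in the subject: a namespace-scope global ran its constructor during static initialization (whose order across translation units is unspecified), while a function-local static is constructed on first use. A minimal sketch of the pattern, with hypothetical names:

    struct CpuInterfaceSketch {
      CpuInterfaceSketch() { /* safe: runs on the first GetCpuInterfaceSketch() call */ }
    };

    CpuInterfaceSketch* GetCpuInterfaceSketch() {
      static CpuInterfaceSketch instance;  // lazily constructed once, thread-safe since C++11
      return &instance;
    }
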
From d87807c9629b35d4fb0314fd8b9953c3b05b37a4 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 20:40:43 -0800
Subject: [PATCH 17/31] Don't load CUDA library outside of Linux & Windows

---
 src/generators.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index b003ac331..af2285336 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -11,7 +11,7 @@
 #include "cuda/interface.h"
 #include "dml/interface.h"
 
-#if _WIN32
+#if defined(_WIN32)
 EXTERN_C IMAGE_DOS_HEADER __ImageBase;
 
 std::string CurrentModulePath() {
@@ -130,12 +130,14 @@ DeviceInterface* GetCudaInterface() {
 // Load the shared library onnxruntime-genai-cuda.dll
 // This is a workaround to avoid linking the CUDA library to the generator library
 // The CUDA library is only needed for the CUDA allocator
-#ifdef _WIN32
+#if defined(_WIN32)
   static std::unique_ptr<void, void (*)(void*)> cuda_library{LoadLibrary((CurrentModulePath() + "onnxruntime-genai-cuda.dll").c_str()),
                                                              [](void* h) { FreeLibrary(reinterpret_cast<HMODULE>(h)); }};
-#else
+#elif defined(__linux__)
   static std::unique_ptr<void, void (*)(void*)> cuda_library{dlopen((Ort::GetCurrentModuleDir() + "/libonnxruntime-genai-cuda.so").c_str(), RTLD_NOW | RTLD_DEEPBIND),
                                                              [](void* h) { dlclose(h); }};
+#else
+  static std::unique_ptr<void, void (*)(void*)> cuda_library{nullptr, [](void* h) {}};
 #endif
 
   if (!cuda_library) {

From 2df5fe17390bcf0501badac172598a2fdd9fe72d Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 20:44:57 -0800
Subject: [PATCH 18/31] Fix iOS break

---
 src/generators.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index af2285336..984cd788f 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -146,10 +146,12 @@ DeviceInterface* GetCudaInterface() {
 
   Generators::DeviceInterface* GetInterface(GenaiInterface * p_genai);
   static DeviceInterface* cuda_interface{[] {
-#ifdef _WIN32
+#if defined(_WIN32)
     auto get_cuda_fn = reinterpret_cast<decltype(&GetInterface)>(GetProcAddress(reinterpret_cast<HMODULE>(cuda_library.get()), "GetInterface"));
-#else
+#elif defined(__linux__)
     auto get_cuda_fn = reinterpret_cast<decltype(&GetInterface)>(dlsym(cuda_library.get(), "GetInterface"));
+#else
+    auto get_cuda_fn = [](GenaiInterface*) { return nullptr; };
 #endif
     return get_cuda_fn(&g_genai);
   }()};

From 67365174296873ed6603e30d9ab79f69fb9baf62 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Fri, 24 Jan 2025 20:59:24 -0800
Subject: [PATCH 19/31] Android tweak

---
 src/generators.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 984cd788f..9aa561e9e 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -133,7 +133,7 @@ DeviceInterface* GetCudaInterface() {
 #if defined(_WIN32)
   static std::unique_ptr<void, void (*)(void*)> cuda_library{LoadLibrary((CurrentModulePath() + "onnxruntime-genai-cuda.dll").c_str()),
                                                              [](void* h) { FreeLibrary(reinterpret_cast<HMODULE>(h)); }};
-#elif defined(__linux__)
+#elif defined(__linux__) && !defined(__ANDROID__)
   static std::unique_ptr<void, void (*)(void*)> cuda_library{dlopen((Ort::GetCurrentModuleDir() + "/libonnxruntime-genai-cuda.so").c_str(), RTLD_NOW | RTLD_DEEPBIND),
                                                              [](void* h) { dlclose(h); }};
 #else
@@ -148,7 +148,7 @@ DeviceInterface* GetCudaInterface() {
   static DeviceInterface* cuda_interface{[] {
 #if defined(_WIN32)
     auto get_cuda_fn = reinterpret_cast<decltype(&GetInterface)>(GetProcAddress(reinterpret_cast<HMODULE>(cuda_library.get()), "GetInterface"));
-#elif defined(__linux__)
+#elif defined(__linux__) && !defined(__ANDROID__)
     auto get_cuda_fn = reinterpret_cast<decltype(&GetInterface)>(dlsym(cuda_library.get(), "GetInterface"));
 #else
     auto get_cuda_fn = [](GenaiInterface*) { return nullptr; };

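Patches 17-19 converge on loading the CUDA provider dynamically only where it can exist: desktop Windows and non-Android Linux. A condensed sketch of the resulting guards (library names taken from the diffs; error handling and the module-path helpers omitted):

    #include <string>
    #if defined(_WIN32)
    #include <windows.h>
    #elif defined(__linux__) && !defined(__ANDROID__)
    #include <dlfcn.h>
    #endif

    void* LoadCudaProviderSketch(const std::string& dir) {
    #if defined(_WIN32)
      return ::LoadLibraryA((dir + "onnxruntime-genai-cuda.dll").c_str());
    #elif defined(__linux__) && !defined(__ANDROID__)
      return ::dlopen((dir + "/libonnxruntime-genai-cuda.so").c_str(), RTLD_NOW);
    #else
      return nullptr;  // iOS, Android, and other targets never load the CUDA provider
    #endif
    }
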
From a011fe0dc6f5fc0078343c75237dca27e500bda3 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Sun, 26 Jan 2025 18:06:18 -0800
Subject: [PATCH 20/31] Leftover #ifdef fix

---
 test/c_api_tests.cpp        | 4 ----
 test/sampling_benchmark.cpp | 3 ---
 test/sampling_tests.cpp     | 3 ---
 test/tests_helper.cu        | 4 ++--
 test/tests_helper.cuh       | 4 ++--
 5 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index c91c69733..8d9d07522 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -31,10 +31,6 @@ TEST(CAPITests, Config) {
   config->AppendProvider("cuda");
 #endif
 }
-
-#undef USE_CUDA
-#define USE_CUDA 0
-
 TEST(CAPITests, TokenizerCAPI) {
 #if TEST_PHI2
   auto config = OgaConfig::Create(PHI2_PATH);
diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp
index c15478009..eb7f04cd9 100644
--- a/test/sampling_benchmark.cpp
+++ b/test/sampling_benchmark.cpp
@@ -14,9 +14,6 @@
 #define MODEL_PATH "../../test/test_models/"
 #endif
 
-#undef USE_CUDA
-#define USE_CUDA 0
-
 // Defined in sampling_tests.cpp
 void CreateRandomLogits(float* logits, int num_large, int vocab_size, int batch_size, std::mt19937& engine);
 
diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 188367088..6ca0080eb 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -13,9 +13,6 @@
 #define MODEL_PATH "../../test/test_models/"
 #endif
 
-#undef USE_CUDA
-#define USE_CUDA 0
-
 template<typename T>
 auto AllocateFromCpuMem(Generators::DeviceInterface& device, std::span<const T> cpu_memory) {
   auto memory = device.Allocate<float>(cpu_memory.size());
diff --git a/test/tests_helper.cu b/test/tests_helper.cu
index c15539d02..b6a3c9d5d 100644
--- a/test/tests_helper.cu
+++ b/test/tests_helper.cu
@@ -21,10 +21,10 @@ __global__ void GeometricDecayKernel(float* logits, int vocab_size, int num_larg
   }
 }
 
-void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, cudaStream_t stream) {
+void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, void* stream) {
   int num_threads = 256;
   int num_blocks = batch_size;
-  GeometricDecayKernel<<<num_blocks, num_threads, 0, stream>>>(logits, vocab_size, num_large, large_val);
+  GeometricDecayKernel<<<num_blocks, num_threads, 0, static_cast<cudaStream_t>(stream)>>>(logits, vocab_size, num_large, large_val);
 }
 
 __global__ void FisherYatesKernel(float* logits, int* indices, int vocab_size, curandState* states) {
diff --git a/test/tests_helper.cuh b/test/tests_helper.cuh
index ff0f3f319..d4dd48a92 100644
--- a/test/tests_helper.cuh
+++ b/test/tests_helper.cuh
@@ -1,5 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, cudaStream_t stream);
-void LaunchFisherYatesKernel(float* logits, int* indices, int vocab_size, int batch_size, cudaStream_t stream);
\ No newline at end of file
+void LaunchGeometricDecayKernel(float* logits, int vocab_size, int batch_size, int num_large, float large_val, void* stream);
+void LaunchFisherYatesKernel(float* logits, int* indices, int vocab_size, int batch_size, void* stream);
\ No newline at end of file

From c11704f35cbd6232e9662a124e58d0587e5dd3da Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Sun, 26 Jan 2025 18:13:35 -0800
Subject: [PATCH 21/31] Type tweak

---
 test/sampling_tests.cpp | 12 ++++++------
 test/tests_helper.cu    |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index 6ca0080eb..a42609306 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -417,8 +417,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     generator->SetLogits(logits_gpu); 
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
@@ -522,8 +522,8 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
     auto generator = Generators::CreateGenerator(*model, *params);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
     auto next_tokens = generator->search_->GetNextTokens().CopyDeviceToCpu();
@@ -558,8 +558,8 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) {
   int num_iter = 100;
   for (int i = 0; i < num_iter; i++) {
     int num_large = dist(engine);
-    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream);
-    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->cuda_stream);
+    LaunchGeometricDecayKernel(logits_gpu.Span().data(), config.model.vocab_size, batch_size, num_large, 20.0f, params->p_device->GetCudaStream());
+    LaunchFisherYatesKernel(logits_gpu.Span().data(), indices_buffer.Span().data(), config.model.vocab_size, batch_size, params->p_device->GetCudaStream());
     auto generator = Generators::CreateGenerator(*model, *params);
     generator->SetLogits(logits_gpu);
     generator->GenerateNextToken();
diff --git a/test/tests_helper.cu b/test/tests_helper.cu
index b6a3c9d5d..28f97b0e3 100644
--- a/test/tests_helper.cu
+++ b/test/tests_helper.cu
@@ -74,13 +74,13 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t
   PopulateIndices<<<grid, block, 0, stream>>>(indices, size, batch_size);
 }
 
-void LaunchFisherYatesKernel(float* logits, int* indices_buffer, int vocab_size, int batch_size, cudaStream_t stream) {
+void LaunchFisherYatesKernel(float* logits, int* indices_buffer, int vocab_size, int batch_size, void* stream) {
   int num_threads = 256;
   int num_blocks = batch_size;
   curandState *random_states;
   cudaMalloc((void **)&random_states, num_threads * sizeof(curandState));
   std::span<float> logits_span{logits, static_cast<size_t>(vocab_size * batch_size)};
   std::span<int32_t> indices{indices_buffer, static_cast<size_t>(vocab_size * batch_size)};
-  LaunchPopulateIndices(indices.data(), vocab_size, batch_size, stream);
-  FisherYatesKernel<<<num_blocks, num_threads, 0, stream>>>(logits_span.data(), indices.data(), vocab_size, random_states);
+  LaunchPopulateIndices(indices.data(), vocab_size, batch_size, static_cast<cudaStream_t>(stream));
+  FisherYatesKernel<<<num_blocks, num_threads, 0, static_cast<cudaStream_t>(stream)>>>(logits_span.data(), indices.data(), vocab_size, random_states);
 }

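The test-helper hunks above erase cudaStream_t to void* at the header boundary so CPU-only translation units can include the declarations; the .cu side casts the handle back at the launch site. A trimmed, CPU-only illustration of that boundary (hypothetical function name; only the cast direction mirrors the diff):

    #include <cstddef>

    // Header visible to non-CUDA code: the stream parameter is an opaque handle.
    void FillBufferSketch(float* data, std::size_t n, void* stream);

    // CPU reference implementation; a CUDA build would instead cast the handle back with
    //   static_cast<cudaStream_t>(stream)
    // and launch a kernel on that stream.
    void FillBufferSketch(float* data, std::size_t n, void* /*stream*/) {
      for (std::size_t i = 0; i < n; ++i) data[i] = 0.0f;
    }
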
From 45dad2b97ba68867c64dc2ce5c382e1aa655d15a Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Tue, 28 Jan 2025 11:51:15 -0800
Subject: [PATCH 22/31] Review feedback

---
 src/cuda/beam_search_scorer_cuda.cpp | 3 +++
 src/cuda/beam_search_scorer_cuda.cu  | 3 +++
 src/cuda/beam_search_scorer_cuda.cuh | 3 +++
 src/cuda/beam_search_scorer_cuda.h   | 3 +++
 src/cuda/beam_search_topk.cu         | 3 +++
 src/cuda/cuda_sampling.cuh           | 1 +
 src/cuda/kernels.h                   | 1 +
 src/cuda/model_kernels.cu            | 1 +
 src/cuda/search_cuda.cpp             | 3 +++
 src/cuda/search_cuda.cu              | 3 +++
 src/cuda/search_cuda.cuh             | 3 +++
 src/cuda/search_cuda.h               | 3 +++
 src/generators.h                     | 1 +
 test/c_api_tests.cpp                 | 4 ++++
 14 files changed, 35 insertions(+)

diff --git a/src/cuda/beam_search_scorer_cuda.cpp b/src/cuda/beam_search_scorer_cuda.cpp
index e2295e00d..a59bc80bc 100644
--- a/src/cuda/beam_search_scorer_cuda.cpp
+++ b/src/cuda/beam_search_scorer_cuda.cpp
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include "generators.h"
 #include "search.h"
 #include "search_cuda.h"
diff --git a/src/cuda/beam_search_scorer_cuda.cu b/src/cuda/beam_search_scorer_cuda.cu
index fa9148dcc..3add21bc1 100644
--- a/src/cuda/beam_search_scorer_cuda.cu
+++ b/src/cuda/beam_search_scorer_cuda.cu
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include <cuda_runtime.h>
 #include <assert.h>
 #include <algorithm>
diff --git a/src/cuda/beam_search_scorer_cuda.cuh b/src/cuda/beam_search_scorer_cuda.cuh
index 9e441d4bd..7a8834b69 100644
--- a/src/cuda/beam_search_scorer_cuda.cuh
+++ b/src/cuda/beam_search_scorer_cuda.cuh
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include "models/onnxruntime_api.h"
 #include "smartptrs.h"
 
diff --git a/src/cuda/beam_search_scorer_cuda.h b/src/cuda/beam_search_scorer_cuda.h
index 8b23a4225..7ec208485 100644
--- a/src/cuda/beam_search_scorer_cuda.h
+++ b/src/cuda/beam_search_scorer_cuda.h
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 namespace Generators {
 
 struct BeamSearchScorer_Cuda {
diff --git a/src/cuda/beam_search_topk.cu b/src/cuda/beam_search_topk.cu
index 222561ce8..32da76fa2 100644
--- a/src/cuda/beam_search_topk.cu
+++ b/src/cuda/beam_search_topk.cu
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include <cuda_runtime.h>
 #include <cub/cub.cuh>
 #include <limits>
diff --git a/src/cuda/cuda_sampling.cuh b/src/cuda/cuda_sampling.cuh
index 390ff92bc..529d2e65a 100644
--- a/src/cuda/cuda_sampling.cuh
+++ b/src/cuda/cuda_sampling.cuh
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+
 #include <assert.h>
 #include "cuda_common.h"
 #include <curand_kernel.h>
diff --git a/src/cuda/kernels.h b/src/cuda/kernels.h
index 99be3a416..860af48c8 100644
--- a/src/cuda/kernels.h
+++ b/src/cuda/kernels.h
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+
 #pragma once
 namespace Generators {
 
diff --git a/src/cuda/model_kernels.cu b/src/cuda/model_kernels.cu
index 59b1f5431..23d9037e0 100644
--- a/src/cuda/model_kernels.cu
+++ b/src/cuda/model_kernels.cu
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <stdint.h>
diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp
index a53ff37dd..5bfbc6ac8 100644
--- a/src/cuda/search_cuda.cpp
+++ b/src/cuda/search_cuda.cpp
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include "generators.h"
 #include "interface.h"
 #include "search.h"
diff --git a/src/cuda/search_cuda.cu b/src/cuda/search_cuda.cu
index f8c9ed3bf..fcb21200d 100644
--- a/src/cuda/search_cuda.cu
+++ b/src/cuda/search_cuda.cu
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include <cuda_runtime.h>
 #include <cub/cub.cuh>
 #include <algorithm>
diff --git a/src/cuda/search_cuda.cuh b/src/cuda/search_cuda.cuh
index f21237c41..a0a07fb9f 100644
--- a/src/cuda/search_cuda.cuh
+++ b/src/cuda/search_cuda.cuh
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 namespace Generators {
 
 namespace cuda {
diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h
index bb31b4423..acdd0525f 100644
--- a/src/cuda/search_cuda.h
+++ b/src/cuda/search_cuda.h
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #pragma once
 #include <cuda_runtime.h>
 #include "search_cuda.cuh"
diff --git a/src/generators.h b/src/generators.h
index 13b3fe617..6d2865a37 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -167,6 +167,7 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);  // For benchmarking purposes only
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
 
+// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device)
 void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes);
 
 float Float16ToFloat32(uint16_t v);  // v is an IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
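
The comment on CopyThroughCpu above is its whole contract; here is a plausible sketch of what the generators.cpp definition does, using only DeviceBuffer members that appear elsewhere in this series (an illustration, not the actual implementation):

    #include <cstring>

    void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source,
                        size_t begin_source, size_t size_in_bytes) {
      source.AllocateCpu();        // ensure a CPU-visible view of the source exists
      source.CopyDeviceToCpu();    // pull the source bytes down into that view
      dest.AllocateCpu();          // stage the destination on the CPU as well
      std::memcpy(dest.p_cpu_ + begin_dest, source.p_cpu_ + begin_source, size_in_bytes);
      dest.CopyCpuToDevice();      // push the staged bytes back to the destination device
    }

On the CPU device p_cpu_ and p_device_ presumably alias the same memory, so every step but the memcpy is a no-op, which is why the comment calls this slow only for real device-to-device copies.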
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 8d9d07522..56e40a2c1 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
 #include <gtest/gtest.h>
 #include <generators.h>
 #include <search.h>
@@ -31,6 +34,7 @@ TEST(CAPITests, Config) {
   config->AppendProvider("cuda");
 #endif
 }
+
 TEST(CAPITests, TokenizerCAPI) {
 #if TEST_PHI2
   auto config = OgaConfig::Create(PHI2_PATH);

From 53c666c1013223b731c18d32e7bc0f7562ca1ab7 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 29 Jan 2025 12:51:08 -0800
Subject: [PATCH 23/31] Harden device allocator creation (ideas from Edward)

---
 src/models/model.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/models/model.cpp b/src/models/model.cpp
index b328bed11..e3387f9cb 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -228,10 +228,14 @@ Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
 
   auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
   if (!device) {
-    static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU Buffer"};
+    static const char* device_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QNN (Not used, uses CPU memory)"};
+    static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));
 
-    auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+    auto name = device_type_names[static_cast<int>(type)];
+    auto memory_info = OrtMemoryInfo::Create(name, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     device = Ort::Allocator::Create(session, *memory_info);
+    if (!device)
+      throw std::runtime_error("Unexpected failure to create device memory allocator for " + std::string(name));
     GetDeviceInterface(type)->InitOrt(*Ort::api, *device);  // Necessary for any shared library providers so they can access Ort::api
   }
   return device.get();
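
The added static_assert and allocator check are small but meaningful: the name table is indexed by DeviceType, so a new enum value without a matching string would otherwise read past the end of the array at runtime. The pattern in isolation (enum values and strings illustrative, not the project's real definitions):

    #include <cstddef>
    #include <iterator>

    enum class DeviceType { CPU, CUDA, DML, WEBGPU, QNN, MAX };

    static const char* device_type_names[] = {"CPU", "Cuda", "DML", "WebGPU_Buffer", "QNN"};
    static_assert(std::size(device_type_names) == static_cast<std::size_t>(DeviceType::MAX),
                  "device_type_names must have one entry per DeviceType value");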

From e8046976c71bbd19e27b5b53d8b231963e345e63 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 29 Jan 2025 18:14:28 -0800
Subject: [PATCH 24/31] Clean up allocators, now everything is through
 p_device_* interfaces.

---
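Note on the shape of this change: call sites no longer pick between allocator_cpu_, allocator_device_, and allocator_kvcache_; they ask the relevant DeviceInterface for its allocator. A representative fragment (assumes a Model& model plus shape/type in scope; the concrete call sites are in the diffs below):

    // Before this patch a call site chose one of several raw allocator pointers, e.g.:
    //   auto value = OrtValue::CreateTensor(*model.allocator_device_, shape, type);
    // After it, allocation always goes through the owning device interface:
    auto value = OrtValue::CreateTensor(model.p_device_->GetAllocator(), shape, type);          // general device memory
    auto input = OrtValue::CreateTensor(model.p_device_inputs_->GetAllocator(), shape, type);   // model inputs (CPU for all but CUDA)
    auto cache = OrtValue::CreateTensor(model.p_device_kvcache_->GetAllocator(), shape, type);  // KV cache, always device memory
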
 cmake/global_variables.cmake         |  4 ++
 src/csharp/NativeMethods.cs          |  3 +-
 src/cuda/interface.cpp               |  6 +--
 src/dml/interface.cpp                | 14 ++---
 src/generators.cpp                   |  6 +++
 src/models/adapters.cpp              |  6 +--
 src/models/decoder_only_pipeline.cpp |  4 +-
 src/models/embeddings.cpp            |  4 +-
 src/models/image_features.cpp        |  4 +-
 src/models/input_ids.cpp             |  6 +--
 src/models/kv_cache.cpp              | 58 ++++++++++----------
 src/models/kv_cache.h                | 13 +++++
 src/models/logits.cpp                | 18 +++----
 src/models/model.cpp                 | 55 +++++++++----------
 src/models/model.h                   |  9 ++--
 src/models/position_inputs.cpp       | 14 ++---
 src/models/whisper.cpp               |  8 +--
 src/qnn/interface.cpp                | 79 ++++++++++++++++++++++++++++
 src/qnn/interface.h                  |  8 +++
 src/webgpu/interface.cpp             | 79 ++++++++++++++++++++++++++++
 src/webgpu/interface.h               |  8 +++
 21 files changed, 300 insertions(+), 106 deletions(-)
 create mode 100644 src/qnn/interface.cpp
 create mode 100644 src/qnn/interface.h
 create mode 100644 src/webgpu/interface.cpp
 create mode 100644 src/webgpu/interface.h

diff --git a/cmake/global_variables.cmake b/cmake/global_variables.cmake
index 9145bdb3c..b05731ee5 100644
--- a/cmake/global_variables.cmake
+++ b/cmake/global_variables.cmake
@@ -66,6 +66,10 @@ file(GLOB generator_srcs CONFIGURE_DEPENDS
   "${GENERATORS_ROOT}/*.cpp"
   "${GENERATORS_ROOT}/cpu/*.h"
   "${GENERATORS_ROOT}/cpu/*.cpp"
+  "${GENERATORS_ROOT}/qnn/*.h"
+  "${GENERATORS_ROOT}/qnn/*.cpp"
+  "${GENERATORS_ROOT}/webgpu/*.h"
+  "${GENERATORS_ROOT}/webgpu/*.cpp"
   "${MODELS_ROOT}/*.h"
   "${MODELS_ROOT}/*.cpp"
 )
diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs
index 4dd11169e..e53da7421 100644
--- a/src/csharp/NativeMethods.cs
+++ b/src/csharp/NativeMethods.cs
@@ -17,7 +17,8 @@ internal class NativeLib
             // define the library name required for iOS
             internal const string DllName = "__Internal";
 #else
-            internal const string DllName = "onnxruntime-genai";
+//          internal const string DllName = "C:\\code\\onnxruntime-genai2\\build\\Windows\\Debug\\Debug\\onnxruntime-genai";
+            internal const string DllName = "C:\\code\\onnxruntime-genai2\\build\\Windows\\RelWithDebInfo\\RelWithDebInfo\\onnxruntime-genai";
 #endif
         }
 
diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp
index 4c85379b0..2638d0e4b 100644
--- a/src/cuda/interface.cpp
+++ b/src/cuda/interface.cpp
@@ -13,7 +13,7 @@ namespace Generators {
 
 GenaiInterface* gp_genai{};
 Ort::Allocator* ort_allocator_{};
-const char* label_cuda = "cuda";
+const char* device_label = "cuda";
 
 cuda_stream_holder g_stream;
 cudaStream_t GetStream() { return g_stream.get(); }
@@ -36,7 +36,7 @@ struct GpuMemory final : DeviceBuffer {
       ::cudaFreeHost(p_cpu_);
   }
 
-  const char* GetType() const override { return label_cuda; }
+  const char* GetType() const override { return device_label; }
 
   void AllocateCpu() override {
     if (!p_cpu_)
@@ -55,7 +55,7 @@ struct GpuMemory final : DeviceBuffer {
   }
 
   void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
-    if (source.GetType() == label_cuda)
+    if (source.GetType() == device_label)
       ::cudaMemcpyAsync(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes, ::cudaMemcpyDeviceToDevice, GetStream());
     else
       gp_genai->CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp
index 58a3457e2..a16d770d5 100644
--- a/src/dml/interface.cpp
+++ b/src/dml/interface.cpp
@@ -20,7 +20,7 @@ namespace Generators {
 namespace Dml {  // If this was in a shared library it wouldn't need to be in its own namespace
 
 Ort::Allocator* ort_allocator_{};
-const char* label_dml = "dml";
+const char* device_label = "dml";
 
 wil::unique_hmodule smart_directml_dll_;
 DmlObjects dml_objects_;
@@ -50,7 +50,7 @@ struct GpuMemory final : DeviceBuffer {
       free(p_cpu_);
   }
 
-  const char* GetType() const override { return label_dml; }
+  const char* GetType() const override { return device_label; }
 
   void AllocateCpu() override {
     if (!p_cpu_)
@@ -69,7 +69,7 @@ struct GpuMemory final : DeviceBuffer {
   }
 
   void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
-    if (source.GetType() == label_dml) {
+    if (source.GetType() == device_label) {
       auto& source_gpu = dynamic_cast<GpuMemory&>(source);
       dml_execution_context_->CopyBufferRegion(
           gpu_resource_.Get(),
@@ -94,8 +94,8 @@ struct GpuMemory final : DeviceBuffer {
   bool owned_;  // If we own the memory, we delete it on destruction
 };
 
-struct DmlInterfaceImpl : DeviceInterface {
-  DmlInterfaceImpl(LUID* p_device_luid) {
+struct InterfaceImpl : DeviceInterface {
+  InterfaceImpl(LUID* p_device_luid) {
     Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
     if (!dml_api_) {
       throw std::runtime_error("Unexpected nullptr getting OrtDmlApi");
@@ -209,11 +209,11 @@ struct DmlInterfaceImpl : DeviceInterface {
 
 }  // namespace Dml
 
-std::unique_ptr<Dml::DmlInterfaceImpl> g_dml_device;
+std::unique_ptr<Dml::InterfaceImpl> g_dml_device;
 
 void InitDmlInterface(LUID* p_device_luid) {
   if (!g_dml_device)
-    g_dml_device = std::make_unique<Dml::DmlInterfaceImpl>(p_device_luid);
+    g_dml_device = std::make_unique<Dml::InterfaceImpl>(p_device_luid);
 }
 
 void SetDmlProvider(OrtSessionOptions& session_options) {
diff --git a/src/generators.cpp b/src/generators.cpp
index 9aa561e9e..66f49ee64 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -10,6 +10,8 @@
 #include "cpu/interface.h"
 #include "cuda/interface.h"
 #include "dml/interface.h"
+#include "qnn/interface.h"
+#include "webgpu/interface.h"
 
 #if defined(_WIN32)
 EXTERN_C IMAGE_DOS_HEADER __ImageBase;
@@ -187,6 +189,10 @@ DeviceInterface* GetDeviceInterface(DeviceType type) {
     case DeviceType::DML:
       return GetDmlInterface();
 #endif
+    case DeviceType::WEBGPU:
+      return GetWebGPUInterface();
+    case DeviceType::QNN:
+      return GetQNNInterface();
   }
 }
 
diff --git a/src/models/adapters.cpp b/src/models/adapters.cpp
index 3840c9b88..5a95ebdd9 100644
--- a/src/models/adapters.cpp
+++ b/src/models/adapters.cpp
@@ -34,9 +34,9 @@ void Adapters::LoadAdapter(const char* adapter_file_path, const std::string& ada
   }
 
   adapters_.emplace(adapter_name, std::make_unique<Adapter>(adapter_file_path,
-                                                            model_->allocator_device_ == &model_->allocator_cpu_
-                                                                ? nullptr
-                                                                : model_->allocator_device_));
+                                                            model_->device_type_ == DeviceType::CUDA
+                                                                ? &model_->p_device_->GetAllocator()
+                                                                : nullptr));
 }
 
 void Adapters::UnloadAdapter(const std::string& adapter_name) {
diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp
index b18fea179..c8c43c2b1 100644
--- a/src/models/decoder_only_pipeline.cpp
+++ b/src/models/decoder_only_pipeline.cpp
@@ -12,7 +12,7 @@ DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr<Config> confi
     sessions_.emplace_back(OrtSession::Create(ort_env, (config_->config_path / fs::path(model.filename)).c_str(),
                                               GetSessionOptions(model.model_id)));
 
-    if (!allocator_device_ && model.session_options.has_value()) {
+    if (!p_device_inputs_ && model.session_options.has_value()) {
       const auto& provider_options = (*model.session_options).provider_options;
       if (std::any_of(provider_options.begin(), provider_options.end(),
                       [](const auto& elem) { return !elem.name.empty(); })) {
@@ -21,7 +21,7 @@ DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr<Config> confi
     }
   }
 
-  if (!allocator_device_) {
+  if (!p_device_inputs_) {
     // If the device allocator has not been created, it implies all
     // sessions are configured to run on CPU.
     // Pick any session to create the device allocator.
diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp
index 01b494078..4b38db207 100644
--- a/src/models/embeddings.cpp
+++ b/src/models/embeddings.cpp
@@ -25,7 +25,7 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n
       sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get();
     }
 
-    embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    embeddings_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_);
   }
 }
 
@@ -54,7 +54,7 @@ void Embeddings::UpdateSequenceLength(size_t new_length) {
 
     if (mode_ == Embeddings::Mode::Input) {
       if (!sb_embeddings_) {
-        embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+        embeddings_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_);
       } else {
         embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_);
       }
diff --git a/src/models/image_features.cpp b/src/models/image_features.cpp
index 71dee99fa..6f51abd0f 100644
--- a/src/models/image_features.cpp
+++ b/src/models/image_features.cpp
@@ -26,7 +26,7 @@ ImageFeatures::ImageFeatures(State& state, ImageFeatures::Mode mode, const std::
   // 4) Created as an input for embedding model (num_image_tokens = 0)
   //    The tensor does not need to be pre-allocated because it will be created during (2).
   if (mode == ImageFeatures::Mode::Output) {
-    image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    image_features_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_);
   }
 }
 
@@ -50,7 +50,7 @@ void ImageFeatures::Update(bool is_prompt) {
   // num_image_tokens will be 0 when no image is provided
   if (!is_prompt && shape_[0] > 0) {  // if num_image_tokens > 0
     shape_[0] = 0;
-    image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    image_features_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_);
     state_.inputs_[index_] = image_features_.get();
   }
 }
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index a3a988608..6a32a6233 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -71,12 +71,12 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
 
   if (static_cast<size_t>(shape_[1]) != sequence_length) {
     shape_[1] = sequence_length;
-    value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
+    value_ = OrtValue::CreateTensor<int32_t>(model_.p_device_inputs_->GetAllocator(), shape_);
     state_.inputs_[input_index_] = value_.get();
   }
 
   // Update input_ids with next tokens
-  auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_);
+  auto data_span = WrapTensor<int32_t>(*model_.p_device_inputs_, *value_);
 
   // For beam search
   if (is_prompt_ && state_.params_->search.num_beams > 1) {
@@ -91,7 +91,7 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   }
 
   if (type_ == Ort::TypeToTensorType<int64_t>) {
-    Cast(*value_, cast_value_, *model_.p_device_, type_);
+    Cast(*value_, cast_value_, *model_.p_device_inputs_, type_);
     state_.inputs_[input_index_] = cast_value_.get();
   }
 
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 51210ae4e..994e883ce 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -46,11 +46,11 @@ CombinedKeyValueCache::CombinedKeyValueCache(State& state)
   // Derive the KV data type from the KV input 0
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);
 
-  empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+  empty_past_ = OrtValue::CreateTensor(Allocator(), shape_, type_);
   shape_[3] = 0;
 
   for (int i = 0; i < layer_count_; ++i) {
-    presents_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
+    presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
   }
 }
 
@@ -82,7 +82,7 @@ void CombinedKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_l
 
   shape_[3] = total_length;
   for (int i = 0; i < layer_count_; i++) {
-    presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+    presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
     state_.outputs_[output_index_ + i] = presents_[i].get();
   }
 
@@ -119,9 +119,9 @@ void CombinedKeyValueCache::RewindPastTensorsTo(size_t index) {
 
   for (int i = 0; i < layer_count_; i++) {
     OrtValue& present = *presents_[i];
-    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
-    auto present_span = WrapTensor<T>(*model_.p_device_, present);
-    auto past_span = WrapTensor<T>(*model_.p_device_, *past);
+    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
+    auto present_span = WrapTensor<T>(Device(), present);
+    auto past_span = WrapTensor<T>(Device(), *past);
 
     for (int j = 0; j < 2 * batch_x_num_heads; j++) {
       auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
@@ -141,10 +141,10 @@ void CombinedKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_devic
   auto past_key_size = shape_[1] * block_size_per_beam;
 
   OrtValue& present = *presents_[index];
-  std::unique_ptr<OrtValue> past = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
+  std::unique_ptr<OrtValue> past = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_);
 
-  auto past_span = WrapTensor<ScoreType>(*model_.p_device_, *past);
-  auto present_span = WrapTensor<ScoreType>(*model_.p_device_, present);
+  auto past_span = WrapTensor<ScoreType>(Device(), *past);
+  auto present_span = WrapTensor<ScoreType>(Device(), present);
 
   for (size_t j = 0; j < beam_indices.size(); j++) {
     int32_t beam_index = beam_indices[j];
@@ -190,7 +190,7 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   // Derive the KV data type from the KV input 0
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);
 
-  empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+  empty_past_ = OrtValue::CreateTensor(Allocator(), shape_, type_);
 
   // Set the size after empty_past_ has been created with 0 for this field
   if (past_present_share_buffer_) {
@@ -207,10 +207,10 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   try {
     for (int i = 0; i < layer_count_ * 2; ++i) {
       presents_.push_back(
-          sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)
+          sb_kv_caches_.empty() ? OrtValue::CreateTensor(Allocator(), shape_, type_)
                                 : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
       // Zero the memory so we don't leak any data from the previous run
-      ByteWrapTensor(*model_.p_device_, *presents_.back()).Zero();
+      ByteWrapTensor(Device(), *presents_.back()).Zero();
     }
   } catch (const Ort::Exception&) {
     std::ostringstream oss;
@@ -269,7 +269,7 @@ void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_le
 
   shape_[2] = total_length;
   for (int i = 0; i < layer_count_ * 2; i++) {
-    presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+    presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
     state_.outputs_[output_index_ + i] = presents_[i].get();
   }
 
@@ -308,10 +308,10 @@ void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) {
 
   for (int i = 0; i < layer_count_ * 2; i++) {
     OrtValue& present = *presents_[i];
-    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
+    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
 
-    auto past_span = WrapTensor<T>(*model_.p_device_, *past);
-    auto present_span = WrapTensor<T>(*model_.p_device_, present);
+    auto past_span = WrapTensor<T>(Device(), *past);
+    auto present_span = WrapTensor<T>(Device(), present);
 
     for (int j = 0; j < batch_x_num_heads; j++) {
       auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
@@ -330,10 +330,10 @@ void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device
   auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
 
   OrtValue& present_value = *presents_[index];
-  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
+  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_);
 
-  auto past_span = WrapTensor<ScoreType>(*model_.p_device_, *past_value);
-  auto present_span = WrapTensor<ScoreType>(*model_.p_device_, present_value);
+  auto past_span = WrapTensor<ScoreType>(Device(), *past_value);
+  auto present_span = WrapTensor<ScoreType>(Device(), present_value);
 
   for (size_t j = 0; j < beam_indices.size(); j++) {
     int32_t beam_index = beam_indices[j];
@@ -371,8 +371,8 @@ CrossCache::CrossCache(State& state)
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);
 
   for (int i = 0; i < layer_count_; ++i) {
-    values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
-    values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
+    values_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
+    values_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
   }
 }
 
@@ -422,21 +422,21 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state)
 
   for (int i = 0; i < layer_count_; ++i) {
     key_caches_in_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_));
+        OrtValue::CreateTensor(Allocator(), key_cache_shape_in_, type_));
     std::fill_n(key_caches_in_[i]->GetTensorMutableData<uint8_t>(),
                 ElementCountFromShape(key_cache_shape_in_),
                 static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
 
     value_caches_in_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_));
+        OrtValue::CreateTensor(Allocator(), value_cache_shape_in_, type_));
     std::fill_n(value_caches_in_[i]->GetTensorMutableData<uint8_t>(),
                 ElementCountFromShape(value_cache_shape_in_),
                 static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
 
     key_caches_out_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_));
+        OrtValue::CreateTensor(Allocator(), key_cache_shape_out_, type_));
     value_caches_out_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_out_, type_));
+        OrtValue::CreateTensor(Allocator(), value_cache_shape_out_, type_));
   }
 }
 
@@ -548,7 +548,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int current
 
   ThreadPool thread_pool{static_cast<size_t>(layer_count_)};
   thread_pool.Compute([&](size_t layer_idx) {
-    std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_);
+    std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_in, type_);
 
     uint8_t* key_cache_data = key_cache->GetTensorMutableData<uint8_t>();
     uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -574,9 +574,9 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int current
     }
 
     key_caches_in_[layer_idx] = std::move(key_cache);
-    key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_);
+    key_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_out, type_);
 
-    std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_);
+    std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_in, type_);
 
     uint8_t* value_cache_data = value_cache->GetTensorMutableData<uint8_t>();
     uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -602,7 +602,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int current
     }
 
     value_caches_in_[layer_idx] = std::move(value_cache);
-    value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_);
+    value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);
   });
 
   window_size_ = 1;
diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h
index 0e871d938..588e52045 100644
--- a/src/models/kv_cache.h
+++ b/src/models/kv_cache.h
@@ -30,6 +30,9 @@ struct CombinedKeyValueCache : KeyValueCache {
   template <typename T>
   void RewindPastTensorsTo(size_t index);
 
+  DeviceInterface& Device() { return *model_.p_device_kvcache_; }
+  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
+
   State& state_;
   const Model& model_{state_.model_};
   int layer_count_;
@@ -64,6 +67,9 @@ struct DefaultKeyValueCache : KeyValueCache {
   template <typename T>
   void RewindPastTensorsTo(size_t index);
 
+  DeviceInterface& Device() { return *model_.p_device_kvcache_; }
+  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
+
   State& state_;
   const Model& model_{state_.model_};
   int layer_count_;
@@ -89,6 +95,9 @@ struct CrossCache {
   void AddInputs();
 
  private:
+  DeviceInterface& Device() { return *model_.p_device_kvcache_; }
+  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
+
   State& state_;
   const Model& model_{state_.model_};
   int layer_count_;
@@ -113,6 +122,10 @@ struct WindowedKeyValueCache : KeyValueCache {
   }
 
  private:
+
+  DeviceInterface& Device() { return *model_.p_device_kvcache_; }
+  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
+
   void Slide();
 
   State& state_;
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 4a7ac7f6a..48bdc4f32 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -10,11 +10,11 @@ Logits::Logits(State& state)
     : state_{state},
       shape_{static_cast<int64_t>(state_.params_->BatchBeamSize()), 0, model_.config_->model.vocab_size},
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
-  output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+  output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
 
   if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
     auto& cpu_ids = model_.config_->model.eos_token_ids;
-    cuda_eos_token_ids_ = state_.params_->p_device->Allocate<int32_t>(cpu_ids.size());
+    cuda_eos_token_ids_ = model_.p_device_->Allocate<int32_t>(cpu_ids.size());
     copy(std::span<const int32_t>{cpu_ids}, cuda_eos_token_ids_.CpuSpan());
     cuda_eos_token_ids_.CopyCpuToDevice();
   }
@@ -34,18 +34,18 @@ DeviceSpan<float> Logits::Get() {
     const size_t num_beams = state_.params_->search.num_beams;
 
     // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it
-    output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_);
+    output_last_tokens_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_last, type_);
 
     if (type_ == Ort::TypeToTensorType<Ort::Float16_t>)
-      logits_of_last_token_fp32_ = OrtValue::CreateTensor<float>(*model_.allocator_device_, shape_);
+      logits_of_last_token_fp32_ = OrtValue::CreateTensor<float>(model_.p_device_inputs_->GetAllocator(), shape_);
 
     logits_of_last_token = output_last_tokens_.get();
 
     size_t element_size = SizeOf(type_);
     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
 
-    auto logits_raw = ByteWrapTensor(*state_.params_->p_device, *output_raw_);
-    auto logits_last_tokens = ByteWrapTensor(*state_.params_->p_device, *logits_of_last_token);
+    auto logits_raw = ByteWrapTensor(*model_.p_device_inputs_, *output_raw_);
+    auto logits_last_tokens = ByteWrapTensor(*model_.p_device_inputs_, *logits_of_last_token);
 
     for (int batch_index = 0; batch_index < state_.params_->search.batch_size; batch_index++) {
       // Find the first non pad token from the end
@@ -63,12 +63,12 @@ DeviceSpan<float> Logits::Get() {
 
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-    Cast(*logits_of_last_token, logits_of_last_token_fp32_, *model_.p_device_, Ort::TypeToTensorType<float>);
+    Cast(*logits_of_last_token, logits_of_last_token_fp32_, *model_.p_device_inputs_, Ort::TypeToTensorType<float>);
     logits_of_last_token = logits_of_last_token_fp32_.get();
   }
 
   if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data())
-    logits_ = WrapTensor<float>(*state_.params_->p_device, *logits_of_last_token);
+    logits_ = WrapTensor<float>(*model_.p_device_inputs_, *logits_of_last_token);
 
   if (model_.device_type_ == DeviceType::CUDA) {
     if (!cuda_eos_token_ids_.empty())
@@ -108,7 +108,7 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
 
   shape_[1] = new_kv_length;
   StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
-  output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
+  output_raw_ = !sb_logits ? OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_)
                            : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
 
   if (state_.GetCapturedGraphInfo()) {
diff --git a/src/models/model.cpp b/src/models/model.cpp
index e3387f9cb..ad454f6c4 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -221,24 +221,24 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const {
 // the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory
 // has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the
 // arena already being destroyed.
-Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
-  // CPU Allocator is a special case, so don't try to create it
+void EnsureDeviceOrtInit(OrtSession& session, DeviceType type) {
+  // The CPU allocator is a special case: it's not in the owned 'allocator_device_' table below, so we handle it separately
   if (type == DeviceType::CPU)
-    return &GetDeviceInterface(DeviceType::CPU)->GetAllocator();
+    return;
 
   auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
-  if (!device) {
-    static const char* device_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QNN (Not used, uses CPU memory)"};
-    static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));
+  if (device)
+    return;
 
-    auto name = device_type_names[static_cast<int>(type)];
-    auto memory_info = OrtMemoryInfo::Create(name, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
-    device = Ort::Allocator::Create(session, *memory_info);
-    if (!device)
-      throw std::runtime_error("Unexpected failure to create device memory allocator for " + std::string(name));
-    GetDeviceInterface(type)->InitOrt(*Ort::api, *device);  // Necessary for any shared library providers so they can access Ort::api
-  }
-  return device.get();
+  static const char* device_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QnnHtpShared"};
+  static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));
+
+  auto name = device_type_names[static_cast<int>(type)];
+  auto memory_info = OrtMemoryInfo::Create(name, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+  device = Ort::Allocator::Create(session, *memory_info);
+  if (!device)
+    throw std::runtime_error("Unexpected failure to create device memory allocator for " + std::string(name));
+  GetDeviceInterface(type)->InitOrt(*Ort::api, *device);  // Necessary for any shared library providers so they can access Ort::api
 }
 
 SessionInfo::SessionInfo(OrtSession& session) {
@@ -301,18 +301,19 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
 Model::~Model() = default;
 
 void Model::InitDeviceAllocator(OrtSession& session) {
-  allocator_device_ = &allocator_cpu_;
+  EnsureDeviceOrtInit(session, device_type_);
+
+  // Only CUDA does every input on the device
   if (device_type_ == DeviceType::CUDA)
-    allocator_device_ = GetDeviceAllocator(session, device_type_);
+    p_device_inputs_ = p_device_;
+  else
+    p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);
 
-  allocator_kvcache_ = allocator_device_;
-  if (device_type_ == DeviceType::WEBGPU || device_type_ == DeviceType::DML) {
-    // for dml and webgpu we only use device memory for kv_cache
-    allocator_kvcache_ = GetDeviceAllocator(session, device_type_);
-  }
+  // The kvcache is always allocated in device memory
+  p_device_kvcache_ = p_device_;
 
   session_info_ = std::make_unique<SessionInfo>(session);
-  captured_graph_pool_ = std::make_shared<CapturedGraphPool>(config_.get(), session_info_.get(), allocator_device_);
+  captured_graph_pool_ = std::make_shared<CapturedGraphPool>(config_.get(), session_info_.get(), &p_device_->GetAllocator());
 }
 
 void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_session_options,
@@ -487,7 +488,6 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       throw std::runtime_error("Unknown provider type: " + provider_options.name);
   }
 
-  // If no device is set, create it, default to CPU
   if (!p_device_) {
     p_device_ = GetDeviceInterface(device_type_);
   }
@@ -495,10 +495,6 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
 
 void Model::CreateSessionOptions() {
   session_options_ = OrtSessionOptions::Create();
-#if 0
-  ClearProviders(*config_);
-  SetProviderOption(*config_, "dml", {}, {});
-#endif
 
   CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *session_options_, true, false);
 
@@ -593,10 +589,9 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
 
   input_shape[0] *= num_beams;
 
-  auto& allocator = device_type_ == DeviceType::DML ? allocator_cpu_ : *allocator_device_;
-  auto expanded = OrtValue::CreateTensor(allocator, input_shape, element_type);
   auto input_span = ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *input);
-  auto expanded_span = ByteWrapTensor(*p_device_, *expanded);
+  auto expanded = OrtValue::CreateTensor(p_device_inputs_->GetAllocator(), input_shape, element_type);
+  auto expanded_span = ByteWrapTensor(*p_device_inputs_, *expanded);
 
   for (int i = 0; i < batch_size; i++) {
     for (int j = 0; j < num_beams; j++) {
diff --git a/src/models/model.h b/src/models/model.h
index 9635e79d2..ead0648a7 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -136,11 +136,12 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<Config> config_;
   std::unique_ptr<OrtSessionOptions> session_options_;
 
-  mutable DeviceInterface* p_device_{};
   DeviceType device_type_{DeviceType::CPU};
-  Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()};
-  Ort::Allocator* allocator_device_{};   // Can be CUDA or CPU based on the DeviceType in the model
-  Ort::Allocator* allocator_kvcache_{};  // keep allocator for kv_cache seperate to allow that only kv_cache is on device
+  mutable DeviceInterface* p_device_{};  // The device we're running on (matches device_type_) used for things that work the same on all devices
+  mutable DeviceInterface* p_device_inputs_{};   // For some model inputs, the device might be the CPU device (all but KV cache currently)
+  mutable DeviceInterface* p_device_kvcache_{};  // The kvcache is always allocated in device memory  (TODO: Remove in favor of just p_device_?)
+
+  Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()};
 
   std::unique_ptr<SessionInfo> session_info_;
 
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 76372bd54..8546cfb8f 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -119,13 +119,13 @@ void DefaultPositionInputs::CreateNextPositionIDsTensor() {
       position_ids_ = std::move(position_ids_next_);
       position_ids_next_ = nullptr;
     } else {
-      position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_);
+      position_ids_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), position_ids_shape_, type_);
     }
   } else {
     position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
     if (position_ids_shape_[1] == 1) {
-      auto position_ids_span = ByteWrapTensor(*model_.p_device_, *position_ids_);
-      auto position_ids_next_span = ByteWrapTensor(*model_.p_device_, *position_ids_next_);
+      auto position_ids_span = ByteWrapTensor(*model_.p_device_inputs_, *position_ids_);
+      auto position_ids_next_span = ByteWrapTensor(*model_.p_device_inputs_, *position_ids_next_);
       position_ids_span.CopyFrom(position_ids_next_span);
     }
   }
@@ -143,7 +143,7 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
   }
 
   if (model_.device_type_ == DeviceType::CUDA)
-    model_.p_device_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
+    model_.p_device_inputs_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
   else {
     type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
                                             : UpdatePositionIDsImpl<int64_t>(total_length, new_kv_length);
@@ -153,12 +153,12 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
 void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) {
   if (!sb_attention_mask_) {
     attention_mask_shape_[1] = total_length;
-    attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+    attention_mask_next_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), attention_mask_shape_, type_);
   } else {
     attention_mask_shape_[1] = state_.params_->search.max_length;
     attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
     if (is_first_mask_update_) {
-      ByteWrapTensor(*model_.p_device_, *attention_mask_next_).Zero();
+      ByteWrapTensor(*model_.p_device_inputs_, *attention_mask_next_).Zero();
     }
   }
 }
@@ -175,7 +175,7 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
   if (model_.device_type_ == DeviceType::CUDA) {
     int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
     bool update_only = sb_attention_mask_ && !is_first_mask_update_;
-    model_.p_device_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
+    model_.p_device_inputs_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
                                           attention_mask_->GetTensorRawData(),
                                           static_cast<int>(attention_mask_shape_[0]),
                                           new_kv_length,
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index 6c46f218d..6bb64df65 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -63,7 +63,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
 
   auto hidden_states_type = model_.session_info_->GetOutputDataType("encoder_hidden_states");
   auto encoder_hidden_states_shape = std::array<int64_t, 3>{decoder_input_ids_.GetShape()[0], 1500, static_cast<int64_t>(model_.config_->model.decoder.num_attention_heads) * model_.config_->model.decoder.head_size};
-  encoder_hidden_states_ = OrtValue::CreateTensor(*model_.allocator_device_, encoder_hidden_states_shape, hidden_states_type);
+  encoder_hidden_states_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), encoder_hidden_states_shape, hidden_states_type);
 
   auto sequence_lengths = sequence_lengths_unk.CpuSpan();
   for (int i = 0; i < decoder_input_ids_.GetShape()[0]; i++) {
@@ -90,7 +90,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
     auto type = model_.session_info_->GetOutputDataType(output_names_[kv_cache_indices]);
 
     for (int i = 0; i < layer_count * 2; i++) {
-      init_presents_.emplace_back(OrtValue::CreateTensor(*model_.allocator_device_, shape, type));
+      init_presents_.emplace_back(OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape, type));
       presents_.emplace_back(outputs_[kv_cache_indices + i]);
       outputs_[kv_cache_indices + i] = init_presents_.back().get();
     }
@@ -268,7 +268,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
       }
 
       if (model_.session_info_->HasInput("cache_indirection")) {
-        cache_indirection_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, std::array<int64_t, 3>{params_->search.batch_size, params_->search.num_beams, params_->search.max_length});
+        cache_indirection_ = OrtValue::CreateTensor<int32_t>(model_.p_device_->GetAllocator(), std::array<int64_t, 3>{params_->search.batch_size, params_->search.num_beams, params_->search.max_length});
         cache_indirection_index_ = inputs_.size();
         input_names_.push_back("cache_indirection");
         inputs_.push_back(cache_indirection_.get());
@@ -284,7 +284,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
           char string[64];
           snprintf(string, std::size(string), "output_cross_qk_%d", i);
           output_cross_qk_names_.emplace_back(string);
-          output_cross_qk_.emplace_back(OrtValue::CreateTensor(*model_.allocator_device_, shape, type));
+          output_cross_qk_.emplace_back(OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape, type));
 
           output_names_.emplace_back(output_cross_qk_names_.back().c_str());
           outputs_.emplace_back(output_cross_qk_.back().get());
diff --git a/src/qnn/interface.cpp b/src/qnn/interface.cpp
new file mode 100644
index 000000000..3acc3b58d
--- /dev/null
+++ b/src/qnn/interface.cpp
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "../generators.h"
+#include "../search.h"
+#include "interface.h"
+
+namespace Generators {
+namespace QNN {
+
+static Ort::Allocator* ort_allocator_{};
+const char* device_label = "qnn";
+
+struct QnnMemory final : DeviceBuffer {
+  QnnMemory(size_t size) : owned_{true} {
+    size_in_bytes_ = size;
+    p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
+  }
+
+  QnnMemory(void* p, size_t size) : owned_{false} {
+    size_in_bytes_ = size;
+    p_cpu_ = p_device_ = static_cast<uint8_t*>(p);
+  }
+
+  ~QnnMemory() override {
+    if (owned_)
+      ort_allocator_->Free(p_device_);
+  }
+
+  const char* GetType() const override { return device_label; }
+  void AllocateCpu() override {}      // Nothing to do, device memory is CPU accessible
+  void CopyDeviceToCpu() override {}  // Nothing to do, device memory is CPU accessible
+  void CopyCpuToDevice() override {}  // Nothing to do, device memory is CPU accessible
+  void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
+    CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
+  }
+
+  void Zero() override {
+    memset(p_device_, 0, size_in_bytes_);
+  }
+
+  bool owned_;
+};
+
+struct InterfaceImpl : DeviceInterface {
+  InterfaceImpl() {
+  }
+
+  void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
+    assert(!ort_allocator_);
+    ort_allocator_ = &allocator;
+  }
+
+  Ort::Allocator& GetAllocator() override {
+    return *ort_allocator_;
+  }
+
+  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
+    return std::make_shared<QnnMemory>(size);
+  }
+
+  std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override {
+    return std::make_shared<QnnMemory>(p, size);
+  }
+
+  std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
+  std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
+
+  void Synchronize() override {}  // Nothing to do
+};
+
+}  // namespace QNN
+
+DeviceInterface* GetQNNInterface() {
+  static std::unique_ptr<DeviceInterface> g_device = std::make_unique<QNN::InterfaceImpl>();
+  return g_device.get();
+}
+
+}  // namespace Generators
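
A minimal usage sketch for an interface like this one (fragment; assumes a DeviceInterface& device obtained from GetQNNInterface()). Because QnnMemory is CPU-visible, the two copy calls are no-ops, but the calling pattern is the same on every backend:

    auto ids = device.Allocate<int32_t>(4);   // DeviceSpan<int32_t> backed by QnnMemory
    auto cpu = ids.CpuSpan();                 // CPU-visible staging view of the same buffer
    for (int32_t i = 0; i < 4; ++i)
      cpu[i] = i;
    ids.CopyCpuToDevice();                    // publish the staged values to the device (no-op here)
    auto readback = ids.CopyDeviceToCpu();    // read them back on the CPU side (also a no-op copy)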
diff --git a/src/qnn/interface.h b/src/qnn/interface.h
new file mode 100644
index 000000000..fcbfe1f64
--- /dev/null
+++ b/src/qnn/interface.h
@@ -0,0 +1,8 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace Generators {
+
+DeviceInterface* GetQNNInterface();
+
+}  // namespace Generators
\ No newline at end of file
diff --git a/src/webgpu/interface.cpp b/src/webgpu/interface.cpp
new file mode 100644
index 000000000..246cfed50
--- /dev/null
+++ b/src/webgpu/interface.cpp
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "../generators.h"
+#include "../search.h"
+#include "interface.h"
+
+namespace Generators {
+namespace WebGPU {
+
+static Ort::Allocator* ort_allocator_{};
+const char* device_label = "WebGPU";
+
+struct WebGPUMemory final : DeviceBuffer {
+  WebGPUMemory(size_t size) : owned_{true} {
+    size_in_bytes_ = size;
+    p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
+  }
+
+  WebGPUMemory(void* p, size_t size) : owned_{false} {
+    size_in_bytes_ = size;
+    p_cpu_ = p_device_ = static_cast<uint8_t*>(p);
+  }
+
+  ~WebGPUMemory() override {
+    if (owned_)
+      ort_allocator_->Free(p_device_);
+  }
+
+  const char* GetType() const override { return device_label; }
+  void AllocateCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
+  void CopyDeviceToCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
+  void CopyCpuToDevice() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
+  void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
+    throw std::runtime_error("CPU can't access WebGPU memory");
+  }
+
+  void Zero() override {
+    throw std::runtime_error("Zeroing not implemented for WebGPU memory");
+  }
+
+  bool owned_;
+};
+
+struct InterfaceImpl : DeviceInterface {
+  InterfaceImpl() {
+  }
+
+  void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
+    assert(!ort_allocator_);
+    ort_allocator_ = &allocator;
+  }
+
+  Ort::Allocator& GetAllocator() override {
+    return *ort_allocator_;
+  }
+
+  std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
+    return std::make_shared<WebGPUMemory>(size);
+  }
+
+  std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override {
+    return std::make_shared<WebGPUMemory>(p, size);
+  }
+
+  std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
+  std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
+
+  void Synchronize() override {}  // Nothing to do?
+};
+
+}  // namespace WebGPU
+
+DeviceInterface* GetWebGPUInterface() {
+  static std::unique_ptr<DeviceInterface> g_device = std::make_unique<WebGPU::InterfaceImpl>();
+  return g_device.get();
+}
+
+}  // namespace Generators
diff --git a/src/webgpu/interface.h b/src/webgpu/interface.h
new file mode 100644
index 000000000..204b4dfed
--- /dev/null
+++ b/src/webgpu/interface.h
@@ -0,0 +1,8 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace Generators {
+
+DeviceInterface* GetWebGPUInterface();
+
+}  // namespace Generators
\ No newline at end of file
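
Taken together, the two new backends illustrate both ends of the DeviceBuffer contract: QnnMemory is ordinary CPU-visible memory, so the CPU/device copy hooks are no-ops and cross-device copies fall back to CopyThroughCpu, while WebGPUMemory is device-only and throws on any CPU access. That split is consistent with the InitDeviceAllocator change earlier in this patch, where only the KV cache must live in device memory and the other inputs stay on the CPU device for non-CUDA backends.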

From f8ed9ce52a9ac409b689eee69d42f851edccca26 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 29 Jan 2025 18:16:08 -0800
Subject: [PATCH 25/31] Lint fixes; the previous change also added device interfaces
 for WebGPU & QNN

---
 src/models/kv_cache.h          |  1 -
 src/models/model.h             |  2 +-
 src/models/position_inputs.cpp | 14 +++++++-------
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h
index 588e52045..9b8fe8e83 100644
--- a/src/models/kv_cache.h
+++ b/src/models/kv_cache.h
@@ -122,7 +122,6 @@ struct WindowedKeyValueCache : KeyValueCache {
   }
 
  private:
-
   DeviceInterface& Device() { return *model_.p_device_kvcache_; }
   Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
 
diff --git a/src/models/model.h b/src/models/model.h
index ead0648a7..91b202c42 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -137,7 +137,7 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<OrtSessionOptions> session_options_;
 
   DeviceType device_type_{DeviceType::CPU};
-  mutable DeviceInterface* p_device_{};  // The device we're running on (matches device_type_) used for things that work the same on all devices
+  mutable DeviceInterface* p_device_{};          // The device we're running on (matches device_type_) used for things that work the same on all devices
   mutable DeviceInterface* p_device_inputs_{};   // For some model inputs, the device might be the CPU device (all but KV cache currently)
   mutable DeviceInterface* p_device_kvcache_{};  // The kvcache is always allocated in device memory  (TODO: Remove in favor of just p_device_?)
 
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 8546cfb8f..13ab0f0f0 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -176,13 +176,13 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
     int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
     bool update_only = sb_attention_mask_ && !is_first_mask_update_;
     model_.p_device_inputs_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
-                                          attention_mask_->GetTensorRawData(),
-                                          static_cast<int>(attention_mask_shape_[0]),
-                                          new_kv_length,
-                                          total_length,
-                                          max_length,
-                                          update_only,
-                                          type_);
+                                                 attention_mask_->GetTensorRawData(),
+                                                 static_cast<int>(attention_mask_shape_[0]),
+                                                 new_kv_length,
+                                                 total_length,
+                                                 max_length,
+                                                 update_only,
+                                                 type_);
   } else {
     type_ == Ort::TypeToTensorType<int32_t> ? UpdateAttentionMaskImpl<int32_t>(total_length)
                                             : UpdateAttentionMaskImpl<int64_t>(total_length);

From 198e8f8258f8083b9b8e443db5de9ec57f12549a Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 29 Jan 2025 22:40:37 -0800
Subject: [PATCH 26/31] Remove accidental change

---
 src/csharp/NativeMethods.cs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs
index e53da7421..4dd11169e 100644
--- a/src/csharp/NativeMethods.cs
+++ b/src/csharp/NativeMethods.cs
@@ -17,8 +17,7 @@ internal class NativeLib
             // define the library name required for iOS
             internal const string DllName = "__Internal";
 #else
-//          internal const string DllName = "C:\\code\\onnxruntime-genai2\\build\\Windows\\Debug\\Debug\\onnxruntime-genai";
-            internal const string DllName = "C:\\code\\onnxruntime-genai2\\build\\Windows\\RelWithDebInfo\\RelWithDebInfo\\onnxruntime-genai";
+            internal const string DllName = "onnxruntime-genai";
 #endif
         }
 

From e6b77f22967cb28848e5f9816b9d6d8e29391151 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Thu, 30 Jan 2025 01:24:37 -0800
Subject: [PATCH 27/31] Device check simplifications

---
 src/models/model.cpp           | 7 +++----
 src/models/position_inputs.cpp | 2 --
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/models/model.cpp b/src/models/model.cpp
index ad454f6c4..e245c9786 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -575,11 +575,10 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
   // Input shape (batch_size, sequence_length). The input is required with data type T.
   // Output shape (batch_size * num_beams, sequence_length)
 
-  // If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later.
-  // DML doesn't currently support on-device scoring, so we go the same route as the CPU
-  if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML || device_type_ == DeviceType::WEBGPU)) {
+  // When num_beams == 1 we don't need to expand the input, but expanding has the side effect of copying the data
+  // from CPU memory to device memory, so we can only skip the expand when p_device_inputs_ is the CPU device.
+  if (num_beams == 1 && p_device_inputs_ == GetDeviceInterface(DeviceType::CPU))
     return std::move(input);
-  }
 
   auto input_type_info = input->GetTensorTypeAndShapeInfo();
   auto element_type = input_type_info->GetElementType();
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 13ab0f0f0..06dca1d46 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -166,8 +166,6 @@ void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) {
 void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) {
   if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1))
     throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding.");
-  if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1))
-    throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - DML does not support continuous decoding.");
 
   CreateNextAttentionMaskTensor(total_length);
   state_.inputs_[mask_input_index_] = attention_mask_.get();

From 4f2f0844fce0def8ead7bd876463f39c252e03d6 Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Thu, 30 Jan 2025 16:02:48 -0800
Subject: [PATCH 28/31] Refactor device_type
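
The Model and GeneratorParams no longer carry a separate device_type_ member;
the DeviceType enum moves next to DeviceInterface and each interface reports
its own type through a new GetType() virtual. A minimal standalone sketch of
that pattern follows; it uses simplified stand-in classes, not the repository's
full DeviceInterface.

// Hedged sketch of the GetType() refactor: callers query the interface instead
// of storing a parallel DeviceType field.
#include <iostream>

enum struct DeviceType { CPU, CUDA, DML, WEBGPU, QNN, MAX };

struct DeviceInterface {
  virtual ~DeviceInterface() = default;
  virtual DeviceType GetType() const = 0;
};

struct CpuInterface : DeviceInterface {
  DeviceType GetType() const override { return DeviceType::CPU; }
};

struct CudaInterface : DeviceInterface {
  DeviceType GetType() const override { return DeviceType::CUDA; }
};

// Example caller (illustrative): no DeviceType member needed, just ask the device.
bool SupportsGraphCapture(const DeviceInterface& device) {
  return device.GetType() == DeviceType::CUDA || device.GetType() == DeviceType::DML;
}

int main() {
  CpuInterface cpu;
  CudaInterface cuda;
  std::cout << SupportsGraphCapture(cpu) << ' ' << SupportsGraphCapture(cuda) << '\n';  // prints: 0 1
}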

---
 src/cpu/interface.cpp                |  2 ++
 src/cuda/interface.cpp               |  2 ++
 src/dml/interface.cpp                |  2 ++
 src/generators.cpp                   | 11 +++++------
 src/generators.h                     | 12 +-----------
 src/models/adapters.cpp              |  2 +-
 src/models/captured_graph_pool.cpp   |  2 +-
 src/models/decoder_only_pipeline.cpp | 10 +++++-----
 src/models/logits.cpp                | 22 ++++------------------
 src/models/logits.h                  |  4 ----
 src/models/model.cpp                 | 24 +++++++++---------------
 src/models/model.h                   |  1 -
 src/models/position_inputs.cpp       |  4 ++--
 src/models/whisper.cpp               |  2 +-
 src/python/python.cpp                |  2 +-
 src/qnn/interface.cpp                |  2 ++
 src/smartptrs.h                      | 11 +++++++++++
 src/webgpu/interface.cpp             |  2 ++
 test/sampling_benchmark.cpp          |  1 -
 test/sampling_tests.cpp              | 13 -------------
 20 files changed, 51 insertions(+), 80 deletions(-)

diff --git a/src/cpu/interface.cpp b/src/cpu/interface.cpp
index 0b34b80c4..93bbcc6f1 100644
--- a/src/cpu/interface.cpp
+++ b/src/cpu/interface.cpp
@@ -46,6 +46,8 @@ struct CpuInterface : DeviceInterface {
   CpuInterface() {
   }
 
+  DeviceType GetType() const override { return DeviceType::CPU; }
+
   void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp
index 2638d0e4b..bf21d48c3 100644
--- a/src/cuda/interface.cpp
+++ b/src/cuda/interface.cpp
@@ -75,6 +75,8 @@ struct CudaInterfaceImpl final : DeviceInterface {
   ~CudaInterfaceImpl() {
   }
 
+  DeviceType GetType() const override { return DeviceType::CUDA; }
+
   void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
     Ort::api = &api;
     assert(!ort_allocator_);
diff --git a/src/dml/interface.cpp b/src/dml/interface.cpp
index a16d770d5..5f6c40c6a 100644
--- a/src/dml/interface.cpp
+++ b/src/dml/interface.cpp
@@ -115,6 +115,8 @@ struct InterfaceImpl : DeviceInterface {
     Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
   }
 
+  DeviceType GetType() const override { return DeviceType::DML; }
+
   void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
     Ort::api = &api;
     assert(!ort_allocator_);
diff --git a/src/generators.cpp b/src/generators.cpp
index 66f49ee64..fd550421b 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -203,8 +203,7 @@ GeneratorParams::GeneratorParams(const Config& config)
 
 GeneratorParams::GeneratorParams(const Model& model)
     : config{*model.config_.get()},
-      p_device{model.p_device_},
-      device_type{model.device_type_},
+      p_device{model.p_device_inputs_},
       is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
   use_cuda_graph = is_cuda_graph_enabled_;
   if (use_cuda_graph) {
@@ -213,12 +212,12 @@ GeneratorParams::GeneratorParams(const Model& model)
 }
 
 void GeneratorParams::TryGraphCapture(int max_bs) {
-  if (!is_cuda_graph_enabled_ || device_type == DeviceType::CPU) {
+  if (!is_cuda_graph_enabled_ || p_device->GetType() == DeviceType::CPU) {
     // no-op
     return;
   }
 
-  if (DeviceType::CUDA == device_type || DeviceType::DML == device_type) {
+  if (DeviceType::CUDA == p_device->GetType() || DeviceType::DML == p_device->GetType()) {
     if (max_bs == 0) {
       throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set.");
     }
@@ -323,8 +322,8 @@ void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
   constexpr std::array<DeviceType, 3> devices_supporting_continuous_decoding{DeviceType::CPU, DeviceType::CUDA, DeviceType::WEBGPU};
   if (search_->GetSequenceLength() != 0 &&
       std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(),
-                   [this](DeviceType device_type) { return device_type == state_->params_->device_type; }))
-    throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->device_type) +
+                   [this](DeviceType device_type) { return device_type == state_->params_->p_device->GetType(); }))
+    throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->p_device->GetType()) +
                              "). Please recreate the generator instance to avoid using continuous decoding.");
 
   if (last_action_ == Action::generated) {
diff --git a/src/generators.h b/src/generators.h
index 6d2865a37..50962b744 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -64,15 +64,6 @@ struct OrtTensor {
 // OgaSequences are a vector of int32 vectors
 using TokenSequences = std::vector<std::vector<int32_t>>;
 
-enum struct DeviceType {
-  CPU,
-  CUDA,
-  DML,
-  WEBGPU,
-  QNN,
-  MAX
-};
-
 std::string to_string(DeviceType device_type);
 DeviceInterface* GetDeviceInterface(DeviceType type);
 
@@ -87,8 +78,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec
   bool use_cuda_graph{};
   int BatchBeamSize() const { return search.num_beams * search.batch_size; }
 
-  DeviceInterface* p_device{};
-  DeviceType device_type{DeviceType::CPU};
+  DeviceInterface* p_device{};  // Scoring device (usually CPU, but can be CUDA)
 
   cpu_span<int32_t> aux_input_ids{};  // Intermediate solution to be used with SetInputs function for multimodal and whisper models
 
diff --git a/src/models/adapters.cpp b/src/models/adapters.cpp
index 5a95ebdd9..13791386c 100644
--- a/src/models/adapters.cpp
+++ b/src/models/adapters.cpp
@@ -34,7 +34,7 @@ void Adapters::LoadAdapter(const char* adapter_file_path, const std::string& ada
   }
 
   adapters_.emplace(adapter_name, std::make_unique<Adapter>(adapter_file_path,
-                                                            model_->device_type_ == DeviceType::CUDA
+                                                            model_->p_device_->GetType() == DeviceType::CUDA
                                                                 ? &model_->p_device_->GetAllocator()
                                                                 : nullptr));
 }
diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp
index baaa9243e..a5cea3701 100644
--- a/src/models/captured_graph_pool.cpp
+++ b/src/models/captured_graph_pool.cpp
@@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf
 }
 
 CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
-  if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA)) {
+  if (!params.use_cuda_graph || (model.p_device_->GetType() != DeviceType::CUDA)) {
     return nullptr;
   }
 
diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp
index c8c43c2b1..8ae9246f9 100644
--- a/src/models/decoder_only_pipeline.cpp
+++ b/src/models/decoder_only_pipeline.cpp
@@ -58,9 +58,9 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const {
 }
 
 bool IntermediatePipelineState::SupportsPrimaryDevice() const {
-  if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN) {
+  if (model_.p_device_->GetType() == DeviceType::CPU || model_.p_device_->GetType() == DeviceType::QNN) {
     return true;
-  } else if (model_.device_type_ == DeviceType::CUDA) {
+  } else if (model_.p_device_->GetType() == DeviceType::CUDA) {
     if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) {
       // No session options, so this session uses the default session options.
       // Default session options supports the cuda device type.
@@ -134,7 +134,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
         if (!pipeline_state->SupportsPrimaryDevice()) {
           std::ostringstream oss;
           oss << "Managed input " << input_name << " resides on the primary device type ("
-              << to_string(model_.device_type_) << "). "
+              << to_string(model_.p_device_->GetType()) << "). "
               << "But the pipeline model "
               << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
               << " is expecting it to reside elsewhere.";
@@ -159,7 +159,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
         if (!pipeline_state->SupportsPrimaryDevice()) {
           std::ostringstream oss;
           oss << "Managed output " << output_name << " resides on the primary device type ("
-              << to_string(model_.device_type_) << "). "
+              << to_string(model_.p_device_->GetType()) << "). "
               << "But the pipeline model "
               << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
               << " is expecting it to reside elsewhere.";
@@ -178,7 +178,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
         if (!pipeline_state->SupportsPrimaryDevice()) {
           std::ostringstream oss;
           oss << "Managed input " << input_name << " resides on the primary device type ("
-              << to_string(model_.device_type_) << "). "
+              << to_string(model_.p_device_->GetType()) << "). "
               << "But the pipeline model "
               << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
               << " is expecting it to reside elsewhere.";
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 48bdc4f32..369bd8e39 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -12,7 +12,7 @@ Logits::Logits(State& state)
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
   output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
 
-  if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
+  if (model_.p_device_inputs_->GetType() == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
     auto& cpu_ids = model_.config_->model.eos_token_ids;
     cuda_eos_token_ids_ = model_.p_device_->Allocate<int32_t>(cpu_ids.size());
     copy(std::span<const int32_t>{cpu_ids}, cuda_eos_token_ids_.CpuSpan());
@@ -70,9 +70,9 @@ DeviceSpan<float> Logits::Get() {
   if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data())
     logits_ = WrapTensor<float>(*model_.p_device_inputs_, *logits_of_last_token);
 
-  if (model_.device_type_ == DeviceType::CUDA) {
+  if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) {
     if (!cuda_eos_token_ids_.empty())
-      model_.p_device_->LaunchHandleEOSArray(
+      model_.p_device_inputs_->LaunchHandleEOSArray(
           logits_.Span().data(),
           static_cast<int>(shape_[0]) /* batch_beam_size*/,
           static_cast<int>(shape_[2]) /* vocab_size */,
@@ -107,21 +107,7 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
   }
 
   shape_[1] = new_kv_length;
-  StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
-  output_raw_ = !sb_logits ? OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_)
-                           : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
-
-  if (state_.GetCapturedGraphInfo()) {
-    if (!sb_logits16_ && !sb_logits32_) {
-      if (type_ == Ort::TypeToTensorType<float>) {
-        sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
-      }
-      if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-        sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
-      }
-    }
-  }
-
+  output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
   state_.outputs_[output_index_] = output_raw_.get();
 }
 
diff --git a/src/models/logits.h b/src/models/logits.h
index 3eae05875..a9c658835 100644
--- a/src/models/logits.h
+++ b/src/models/logits.h
@@ -39,10 +39,6 @@ struct Logits {
   // OrtValue wrapped in a DeviceMemory object to make it universal
   DeviceSpan<float> logits_;
 
-  // Used for decoding runs with cuda graphs.
-  StaticBuffer* sb_logits32_{};
-  StaticBuffer* sb_logits16_{};
-
   DeviceSpan<int32_t> cuda_eos_token_ids_;  // eos_token_ids from params, but in cuda accessible memory
 };
 
diff --git a/src/models/model.cpp b/src/models/model.cpp
index e245c9786..266569419 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -301,10 +301,10 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
 Model::~Model() = default;
 
 void Model::InitDeviceAllocator(OrtSession& session) {
-  EnsureDeviceOrtInit(session, device_type_);
+  EnsureDeviceOrtInit(session, p_device_->GetType());
 
   // Only CUDA does every input on the device
-  if (device_type_ == DeviceType::CUDA)
+  if (p_device_->GetType() == DeviceType::CUDA)
     p_device_inputs_ = p_device_;
   else
     p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);
@@ -413,8 +413,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       // Device type determines the scoring device.
       // Only use the primary session options to determine the device type
       if (is_primary_session_options) {
-        device_type_ = DeviceType::CUDA;  // Scoring will use CUDA
-        p_device_ = GetDeviceInterface(device_type_);
+        p_device_ = GetDeviceInterface(DeviceType::CUDA);
 
         // Create and set our cudaStream_t
         ort_provider_options->UpdateValue("user_compute_stream", p_device_->GetCudaStream());
@@ -451,15 +450,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
         InitDmlInterface(p_device_luid);
       }
 
-      if (!disable_graph_capture) {
-        session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1");
-        session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1");
-      }
-
       SetDmlProvider(session_options);
 
       if (is_primary_session_options)
-        device_type_ = DeviceType::DML;  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
+        p_device_ = GetDeviceInterface(DeviceType::DML);  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
 #endif
     } else if (provider_options.name == "qnn") {
       session_options.AddConfigEntry("ep.share_ep_contexts", "1");
@@ -473,12 +467,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       // on the other hand, not sure if is_primary_session_options is the right thing to check here.
       if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator");
           opt_it != opts.end() && opt_it->second == "1") {
-        device_type_ = DeviceType::QNN;
+        p_device_ = GetDeviceInterface(DeviceType::QNN);
       }
 
       session_options.AppendExecutionProvider("QNN", opts);
     } else if (provider_options.name == "webgpu") {
-      device_type_ = DeviceType::WEBGPU;
+      p_device_ = GetDeviceInterface(DeviceType::WEBGPU);
       std::unordered_map<std::string, std::string> opts;
       for (auto& option : provider_options.options) {
         opts.emplace(option.first, option.second);
@@ -488,9 +482,9 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       throw std::runtime_error("Unknown provider type: " + provider_options.name);
   }
 
-  if (!p_device_) {
-    p_device_ = GetDeviceInterface(device_type_);
-  }
+  // Fall back to the CPU device if no provider-specific interface was set
+  if (!p_device_)
+    p_device_ = GetDeviceInterface(DeviceType::CPU);
 }
 
 void Model::CreateSessionOptions() {
diff --git a/src/models/model.h b/src/models/model.h
index 91b202c42..c17736b73 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -136,7 +136,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<Config> config_;
   std::unique_ptr<OrtSessionOptions> session_options_;
 
-  DeviceType device_type_{DeviceType::CPU};
   mutable DeviceInterface* p_device_{};          // The device we're running on (matches device_type_) used for things that work the same on all devices
   mutable DeviceInterface* p_device_inputs_{};   // For some model inputs, the device might be the CPU device (all but KV cache currently)
   mutable DeviceInterface* p_device_kvcache_{};  // The kvcache is always allocated in device memory  (TODO: Remove in favor of just p_device_?)
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 06dca1d46..ce81bae66 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -142,7 +142,7 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
     state_.inputs_[posid_input_index_] = position_ids_.get();
   }
 
-  if (model_.device_type_ == DeviceType::CUDA)
+  if (model_.p_device_inputs_->GetType() == DeviceType::CUDA)
     model_.p_device_inputs_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
   else {
     type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
@@ -170,7 +170,7 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
   CreateNextAttentionMaskTensor(total_length);
   state_.inputs_[mask_input_index_] = attention_mask_.get();
 
-  if (model_.device_type_ == DeviceType::CUDA) {
+  if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) {
     int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
     bool update_only = sb_attention_mask_ && !is_first_mask_update_;
     model_.p_device_inputs_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index 6bb64df65..05fc20171 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -182,7 +182,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
         auto src_data = init_presents_[i]->GetTensorRawData();
         auto dest_data = presents_[i]->GetTensorMutableRawData();
 
-        switch (model_.device_type_) {
+        switch (model_.p_device_inputs_->GetType()) {
 #if 0  // USE_CUDA
           case DeviceType::CUDA:
             if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
diff --git a/src/python/python.cpp b/src/python/python.cpp
index deae726da..0ae6e6236 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -412,7 +412,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       }))
       .def_property_readonly("type", [](const Model& model) { return model.config_->model.type; })
       .def_property_readonly(
-          "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on")
+          "device_type", [](const Model& model) { return to_string(model.p_device_->GetType()); }, "The device type the model is running on")
       .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); });
 
   pybind11::class_<PyGenerator>(m, "Generator")
diff --git a/src/qnn/interface.cpp b/src/qnn/interface.cpp
index 3acc3b58d..0fc746437 100644
--- a/src/qnn/interface.cpp
+++ b/src/qnn/interface.cpp
@@ -46,6 +46,8 @@ struct InterfaceImpl : DeviceInterface {
   InterfaceImpl() {
   }
 
+  DeviceType GetType() const override { return DeviceType::QNN; }
+
   void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
diff --git a/src/smartptrs.h b/src/smartptrs.h
index 1e2a79f47..642196ca3 100644
--- a/src/smartptrs.h
+++ b/src/smartptrs.h
@@ -83,8 +83,19 @@ struct DeviceSpan {
   friend struct DeviceSpan;  // All DeviceSpans are friends
 };
 
+enum struct DeviceType {
+  CPU,
+  CUDA,
+  DML,
+  WEBGPU,
+  QNN,
+  MAX
+};
+
 struct DeviceInterface {
   virtual ~DeviceInterface() {}
+
+  virtual DeviceType GetType() const = 0;
   virtual void InitOrt(const OrtApi& api, Ort::Allocator& allocator) = 0;
   virtual Ort::Allocator& GetAllocator() = 0;
 
diff --git a/src/webgpu/interface.cpp b/src/webgpu/interface.cpp
index 246cfed50..3b1fd9c25 100644
--- a/src/webgpu/interface.cpp
+++ b/src/webgpu/interface.cpp
@@ -46,6 +46,8 @@ struct InterfaceImpl : DeviceInterface {
   InterfaceImpl() {
   }
 
+  DeviceType GetType() const override { return DeviceType::WEBGPU; }
+
   void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
     assert(!ort_allocator_);
     ort_allocator_ = &allocator;
diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp
index eb7f04cd9..28137fdc9 100644
--- a/test/sampling_benchmark.cpp
+++ b/test/sampling_benchmark.cpp
@@ -32,7 +32,6 @@ struct SamplingBenchmark {
     params->search.max_length = 10;
     params->search.batch_size = batch_size_;
     params->p_device = Generators::GetDeviceInterface(device_type_);
-    params->device_type = device_type_;
 
     std::random_device rd;
     std::mt19937 engine(rd());
diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp
index a42609306..6fa206873 100644
--- a/test/sampling_tests.cpp
+++ b/test/sampling_tests.cpp
@@ -39,7 +39,6 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) {
   params->search.top_p = 0.25f;
   params->search.batch_size = 4;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits = params->p_device->WrapMemory<float>(logits_cpu);
   generator->SetLogits(logits);
@@ -66,7 +65,6 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) {
   params->search.top_k = 2;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits_copy = logits_cpu;
   auto logits = params->p_device->WrapMemory<float>(logits_copy);
@@ -101,7 +99,6 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   auto generator = Generators::CreateGenerator(*model, *params);
   auto logits_copy = logits_cpu;
   auto logits = params->p_device->WrapMemory<float>(logits_copy);
@@ -152,7 +149,6 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) {
   params->search.top_p = 0.95f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
   std::random_device rd;
   std::mt19937 engine(rd());
@@ -205,7 +201,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) {
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
 
   // Create data structures for testing
   std::random_device rd;
@@ -270,7 +265,6 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) {
   params->search.top_p = p;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CPU);
-  params->device_type = Generators::DeviceType::CPU;
   std::vector<float> logits_cpu(config.model.vocab_size * batch_size);
   std::random_device rd;
   std::mt19937 engine(rd());
@@ -317,7 +311,6 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -345,7 +338,6 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) {
   params->search.top_k = 2;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -378,7 +370,6 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) {
   params->search.top_p = 0.25f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits = AllocateFromCpuMem<float>(*params->p_device, logits_cpu);
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->SetLogits(logits);
@@ -406,7 +397,6 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) {
   params->search.top_p = 0.95f;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
 
@@ -448,7 +438,6 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   params->search.top_k = k;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
 
   // Create data structures for testing
   std::random_device rd;
@@ -512,7 +501,6 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   params->search.top_p = p;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
   std::random_device rd;
@@ -549,7 +537,6 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) {
   params->search.max_length = 10;
   params->search.batch_size = batch_size;
   params->p_device = Generators::GetDeviceInterface(Generators::DeviceType::CUDA);
-  params->device_type = Generators::DeviceType::CUDA;
   auto logits_gpu = params->p_device->Allocate<float>(config.model.vocab_size * batch_size);
   auto indices_buffer = params->p_device->Allocate<int>(config.model.vocab_size * batch_size);
   std::random_device rd;

From acba52cb41d50465def66dd2d42291b1d6c2d588 Mon Sep 17 00:00:00 2001
From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com>
Date: Mon, 3 Feb 2025 15:24:33 -0800
Subject: [PATCH 29/31] Update src/models/model.h

Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
---
 src/models/model.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models/model.h b/src/models/model.h
index c17736b73..9b04e589d 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -137,7 +137,7 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   std::unique_ptr<OrtSessionOptions> session_options_;
 
   mutable DeviceInterface* p_device_{};          // The device we're running on (matches device_type_) used for things that work the same on all devices
-  mutable DeviceInterface* p_device_inputs_{};   // For some model inputs, the device might be the CPU device (all but KV cache currently)
+  mutable DeviceInterface* p_device_inputs_{};   // For some model inputs the device might be the CPU device (currently everything except the KV cache when running on WebGPU or DML)
   mutable DeviceInterface* p_device_kvcache_{};  // The kvcache is always allocated in device memory  (TODO: Remove in favor of just p_device_?)
 
   Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()};

From 68a6ea7d8c4162bb42d6a8507e334a298add223e Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 12 Feb 2025 18:26:47 -0800
Subject: [PATCH 30/31] Fix merge conflicts
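
The windowed KV cache now goes through small Device()/Allocator() accessors
instead of referencing model_.allocator_device_ directly, so every cache tensor
is created against the kv-cache device in one place. A self-contained sketch of
that accessor pattern is below; the types are simplified stand-ins, not the
repository's real OrtValue/Allocator classes.

// Hedged sketch: centralize "where do kv-cache tensors live?" behind one accessor.
#include <memory>
#include <vector>

struct Allocator {};                        // stand-in for Ort::Allocator
struct Device { Allocator allocator; };     // stand-in for DeviceInterface
struct Model { Device* kvcache_device{}; }; // stand-in for Model::p_device_kvcache_
struct Tensor { Allocator* allocator{}; };  // stand-in for OrtValue

struct WindowedCache {
  explicit WindowedCache(Model& model) : model_{model} {}

  // Single point of truth for the kv-cache allocator; swapping the kv-cache
  // device later only touches this accessor.
  Allocator& CacheAllocator() { return model_.kvcache_device->allocator; }

  void AddLayer() {
    key_caches_.push_back(std::make_unique<Tensor>(Tensor{&CacheAllocator()}));
    value_caches_.push_back(std::make_unique<Tensor>(Tensor{&CacheAllocator()}));
  }

  Model& model_;
  std::vector<std::unique_ptr<Tensor>> key_caches_;
  std::vector<std::unique_ptr<Tensor>> value_caches_;
};

int main() {
  Device device;
  Model model{&device};
  WindowedCache cache{model};
  cache.AddLayer();
}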

---
 src/beam_search_scorer.cpp       |  2 +-
 src/generators.h                 |  1 -
 src/models/windowed_kv_cache.cpp | 16 ++++++++--------
 src/models/windowed_kv_cache.h   |  3 +++
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp
index b9cbceffe..22765b6c4 100644
--- a/src/beam_search_scorer.cpp
+++ b/src/beam_search_scorer.cpp
@@ -27,7 +27,7 @@ void BeamHypotheses::Add(std::span<int32_t> hypothesis, float sum_logprobs) {
       return;
     }
   } else {
-    beams_used_++;
+    beams_used_++; 
   }
 
   // Rotate existing elements over while the new element scores higher
diff --git a/src/generators.h b/src/generators.h
index c84358b0e..185c04583 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -26,7 +26,6 @@
 
 #include "leakcheck.h"
 #include "make_string.h"
-#include "smartptrs.h"
 #include "models/onnxruntime_api.h"
 #include "smartptrs.h"
 #include "models/debugging.h"
diff --git a/src/models/windowed_kv_cache.cpp b/src/models/windowed_kv_cache.cpp
index a143d86b3..8b147f65d 100644
--- a/src/models/windowed_kv_cache.cpp
+++ b/src/models/windowed_kv_cache.cpp
@@ -47,21 +47,21 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state)
 
   for (int i = 0; i < layer_count_; ++i) {
     key_caches_in_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_));
+        OrtValue::CreateTensor(Allocator(), key_cache_shape_in_, type_));
     std::fill_n(key_caches_in_[i]->GetTensorMutableData<uint8_t>(),
                 ElementCountFromShape(key_cache_shape_in_),
                 static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
 
     value_caches_in_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_));
+        OrtValue::CreateTensor(Allocator(), value_cache_shape_in_, type_));
     std::fill_n(value_caches_in_[i]->GetTensorMutableData<uint8_t>(),
                 ElementCountFromShape(value_cache_shape_in_),
                 static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
 
     key_caches_out_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_));
+        OrtValue::CreateTensor(Allocator(), key_cache_shape_out_, type_));
     value_caches_out_.push_back(
-        OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_out_, type_));
+        OrtValue::CreateTensor(Allocator(), value_cache_shape_out_, type_));
   }
 }
 
@@ -187,7 +187,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
 
   ThreadPool thread_pool{static_cast<size_t>(layer_count_)};
   thread_pool.Compute([&](size_t layer_idx) {
-    std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_);
+    std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_in, type_);
 
     uint8_t* key_cache_data = key_cache->GetTensorMutableData<uint8_t>();
     uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -213,9 +213,9 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
     }
 
     key_caches_in_[layer_idx] = std::move(key_cache);
-    key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_);
+    key_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_out, type_);
 
-    std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_);
+    std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_in, type_);
 
     uint8_t* value_cache_data = value_cache->GetTensorMutableData<uint8_t>();
     uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -241,7 +241,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
     }
 
     value_caches_in_[layer_idx] = std::move(value_cache);
-    value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_);
+    value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);
   });
 
   window_size_ = 1;
diff --git a/src/models/windowed_kv_cache.h b/src/models/windowed_kv_cache.h
index a842eb04d..4e7de10ea 100644
--- a/src/models/windowed_kv_cache.h
+++ b/src/models/windowed_kv_cache.h
@@ -31,6 +31,9 @@ struct WindowedKeyValueCache : KeyValueCache {
   void SlideAllLayers();
   void SlideLayers(std::span<const size_t> layer_indices);
 
+  DeviceInterface& Device() { return *model_.p_device_kvcache_; }
+  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
+
   State& state_;
   const Model& model_{state_.model_};
   int layer_count_{};

From 12e2f76c97d2585d7d6230798ac56985a1b9e1bf Mon Sep 17 00:00:00 2001
From: Ryan Hill <ryanhill@microsoft.com>
Date: Wed, 12 Feb 2025 18:27:45 -0800
Subject: [PATCH 31/31] Formatting

---
 src/beam_search_scorer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp
index 22765b6c4..b9cbceffe 100644
--- a/src/beam_search_scorer.cpp
+++ b/src/beam_search_scorer.cpp
@@ -27,7 +27,7 @@ void BeamHypotheses::Add(std::span<int32_t> hypothesis, float sum_logprobs) {
       return;
     }
   } else {
-    beams_used_++; 
+    beams_used_++;
   }
 
   // Rotate existing elements over while the new element scores higher