diff --git a/src/generators.cpp b/src/generators.cpp
index cbb00b46c..5744f1653 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -28,6 +28,18 @@ OrtEnv& GetOrtEnv() {
   return *GetOrtGlobals()->env_;
 }
 
+std::string to_string(DeviceType device_type) {
+  switch (device_type) {
+    case DeviceType::CPU:
+      return "CPU";
+    case DeviceType::CUDA:
+      return "CUDA";
+    case DeviceType::DML:
+      return "DirectML";
+  }
+  throw std::runtime_error("Unknown device type");
+}
+
 GeneratorParams::GeneratorParams(const Model& model)
     : search{model.config_->search},
       pad_token_id{model.config_->model.pad_token_id},
diff --git a/src/generators.h b/src/generators.h
index bbf024004..7a8f08951 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -53,6 +53,8 @@ enum struct DeviceType {
   DML,
 };
 
+std::string to_string(DeviceType device_type);
+
 struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
   GeneratorParams() = default;  // This constructor is only used if doing a custom model handler vs built-in
   GeneratorParams(const Model& model);
diff --git a/src/python/python.cpp b/src/python/python.cpp
index e65fb3758..02903632e 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -418,7 +418,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
         return CreateModel(GetOrtEnv(), config_path.c_str());
       }))
       .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); })
-      .def_property_readonly("device_type", [](const Model& s) { return s.device_type_; })
+      .def_property_readonly("device_type", [](const Model& s) { return to_string(s.device_type_); })
       .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); });
 
   pybind11::class_(m, "Generator")
diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py
index a92055434..e3e72c2c9 100644
--- a/test/python/test_onnxruntime_genai_api.py
+++ b/test/python/test_onnxruntime_genai_api.py
@@ -149,3 +149,23 @@ def test_logging():
     og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True)
     og.set_log_options(model_input_values=False, model_output_shapes=False)
     og.set_log_options(enabled=False)
+
+
+@pytest.mark.parametrize(
+    "relative_model_path",
+    (
+        [
+            (Path("hf-internal-testing") / "tiny-random-gpt2-fp32", "CPU"),
+            (Path("hf-internal-testing") / "tiny-random-gpt2-fp32-cuda", "CUDA"),
+            (Path("hf-internal-testing") / "tiny-random-gpt2-fp16-cuda", "CUDA"),
+        ]
+        if og.is_cuda_available()
+        else [(Path("hf-internal-testing") / "tiny-random-gpt2-fp32", "CPU")]
+    ),
+)
+def test_model_device_type(test_data_path, relative_model_path):
+    model_path = os.fspath(Path(test_data_path) / relative_model_path[0])
+
+    model = og.Model(model_path)
+
+    assert model.device_type == relative_model_path[1]
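
A minimal sketch of how the string-valued device_type property reads from Python after this change; the model path below is a placeholder, not one taken from the diff:

    import onnxruntime_genai as og

    # Load a model; "path/to/model" stands in for any generated model directory.
    model = og.Model("path/to/model")

    # device_type is now a plain string ("CPU", "CUDA", or "DirectML") rather than
    # the raw DeviceType enum, so it can be compared and logged directly.
    if model.device_type == "CUDA":
        print("running on CUDA")
    else:
        print(f"running on {model.device_type}")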