Skip to content

Commit

Permalink
Query device type as a string (#593)
Browse files Browse the repository at this point in the history
Addresses #592
  • Loading branch information
baijumeswani committed Jun 12, 2024
1 parent 50aeb88 commit eb5262a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ OrtEnv& GetOrtEnv() {
return *GetOrtGlobals()->env_;
}

// Maps a DeviceType enumerator to its display name.
// Throws std::runtime_error when the value matches none of the handled
// enumerators (CPU, CUDA, DML).
std::string to_string(DeviceType device_type) {
  if (device_type == DeviceType::CPU) {
    return "CPU";
  }
  if (device_type == DeviceType::CUDA) {
    return "CUDA";
  }
  if (device_type == DeviceType::DML) {
    return "DirectML";
  }
  // Reached only for a value outside the enumerators checked above.
  throw std::runtime_error("Unknown device type");
}

GeneratorParams::GeneratorParams(const Model& model)
: search{model.config_->search},
pad_token_id{model.config_->model.pad_token_id},
Expand Down
2 changes: 2 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ enum struct DeviceType {
DML,
};

// Returns the display name of a device type: "CPU", "CUDA", or "DirectML".
// Throws std::runtime_error for a value that is none of the handled enumerators.
std::string to_string(DeviceType device_type);

struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in
GeneratorParams(const Model& model);
Expand Down
2 changes: 1 addition & 1 deletion src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
return CreateModel(GetOrtEnv(), config_path.c_str());
}))
.def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); })
.def_property_readonly("device_type", [](const Model& s) { return s.device_type_; })
.def_property_readonly("device_type", [](const Model& s) { return to_string(s.device_type_); })
.def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); });

pybind11::class_<PyGenerator>(m, "Generator")
Expand Down
20 changes: 20 additions & 0 deletions test/python/test_onnxruntime_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,23 @@ def test_logging():
og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True)
og.set_log_options(model_input_values=False, model_output_shapes=False)
og.set_log_options(enabled=False)


@pytest.mark.parametrize(
    "relative_model_path",
    (
        [(Path("hf-internal-testing") / "tiny-random-gpt2-fp32", "CPU")]
        if not og.is_cuda_available()
        else [
            (Path("hf-internal-testing") / "tiny-random-gpt2-fp32", "CPU"),
            (Path("hf-internal-testing") / "tiny-random-gpt2-fp32-cuda", "CUDA"),
            (Path("hf-internal-testing") / "tiny-random-gpt2-fp16-cuda", "CUDA"),
        ]
    ),
)
def test_model_device_type(test_data_path, relative_model_path):
    """Check that a loaded model reports the expected device type string.

    Each parametrized case is a ``(relative model path, expected device type)``
    pair; the CUDA cases are only included when ``og.is_cuda_available()`` is true.
    """
    model_subdir, expected_device_type = relative_model_path
    full_model_path = os.fspath(Path(test_data_path) / model_subdir)

    loaded_model = og.Model(full_model_path)

    assert loaded_model.device_type == expected_device_type

0 comments on commit eb5262a

Please sign in to comment.