From 01d525d6c3b0f3b5812cc45908b974cf254b0fd9 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Tue, 31 Aug 2021 13:41:17 -0700
Subject: [PATCH] feat(//py): Allow example tensors from torch to set shape

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 py/trtorch/Input.py         |  7 +++++++
 py/trtorch/_compile_spec.py |  6 +++++-
 tests/py/test_api.py        | 24 ++++++++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/py/trtorch/Input.py b/py/trtorch/Input.py
index d36d1eb3b6..d8dc08fc09 100644
--- a/py/trtorch/Input.py
+++ b/py/trtorch/Input.py
@@ -196,3 +196,10 @@ def _parse_format(format: Any) -> _types.TensorFormat:
         else:
             raise TypeError(
                 "Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
+
+    @classmethod
+    def _from_tensor(cls, t: torch.Tensor):
+        if not any([t.is_contiguous(memory_format=torch.contiguous_format), t.is_contiguous(memory_format=torch.channels_last)]):
+            raise ValueError("Tensor does not have a supported contiguous memory format, supported formats are contiguous or channels_last")
+        frmt = torch.contiguous_format if t.is_contiguous(memory_format=torch.contiguous_format) else torch.channels_last
+        return cls(shape=t.shape, dtype=t.dtype, format=frmt)
\ No newline at end of file
diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index dc2e7095fa..d1b984df9e 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -174,7 +174,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
 
     if "inputs" in compile_spec:
-        info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
+        if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]):
+            raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format([type(i) for i in compile_spec["inputs"]]))
+
+        inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
+        info.inputs = [i._to_internal() for i in inputs]
 
     if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
         raise KeyError(
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index 94239b1475..486eae12f5 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -73,6 +73,30 @@ def test_compile_script(self):
         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
         self.assertTrue(same < 2e-2)
 
+    def test_from_torch_tensor(self):
+        compile_spec = {
+            "inputs": [self.input],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+            },
+            "enabled_precisions": {torch.float}
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-2)
+
+    def test_device(self):
+        compile_spec = {
+            "inputs": [self.input],
+            "device": trtorch.Device("gpu:0"),
+            "enabled_precisions": {torch.float}
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-2)
 
 class TestCompileHalf(ModelTestCase):
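
A minimal usage sketch of the behavior this patch enables, not part of the patch itself: it assumes a TorchScript module named scripted_model already on the GPU, and the input shape is illustrative. The shape, dtype, and memory format of the example tensor are inferred via the new trtorch.Input._from_tensor() during _parse_compile_spec().

    import torch
    import trtorch

    # example tensor; its shape, dtype and memory format define the input spec
    example = torch.randn(1, 3, 224, 224, device="cuda")

    trt_mod = trtorch.compile(scripted_model, {
        "inputs": [example],                  # torch.Tensor now accepted alongside trtorch.Input
        "enabled_precisions": {torch.float},
    })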