Skip to content

Commit

Permalink
chore: address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 17, 2024
1 parent a83ba8c commit c272b78
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
5 changes: 2 additions & 3 deletions examples/dynamo/torch_compile_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from diffusers import DiffusionPipeline

import torch_tensorrt
from diffusers import DiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda:0"
Expand All @@ -39,7 +38,7 @@
backend=backend,
options={
"truncate_long_and_double": True,
"precision": torch.float16,
"enabled_precisions": {torch.float32, torch.float16},
},
dynamic=False,
)
Expand Down
6 changes: 4 additions & 2 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ def compile(
Returns:
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
"""
input_list = inputs if inputs else []
input_list = inputs if inputs is not None else []
enabled_precisions_set: Set[dtype | torch.dtype] = (
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
enabled_precisions
if enabled_precisions is not None
else _defaults.ENABLED_PRECISIONS
)

module_type = _parse_module_type(module)
Expand Down
39 changes: 31 additions & 8 deletions py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def try_from(
casted_format = dtype._from(t, use_default=use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"Conversion from {t} to torch_tensorrt.dtype failed", exc_info=True
)
return None

def to(
Expand Down Expand Up @@ -301,7 +303,10 @@ def try_to(
casted_format = self.to(t, use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"torch_tensorrt.dtype conversion to target type {t} failed",
exc_info=True,
)
return None

def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool:
Expand Down Expand Up @@ -413,7 +418,10 @@ def try_from(
casted_format = memory_format._from(f)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"Conversion from {f} to torch_tensorrt.memory_format failed",
exc_info=True,
)
return None

def to(
Expand Down Expand Up @@ -492,7 +500,10 @@ def try_to(
casted_format = self.to(t)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"torch_tensorrt.memory_format conversion to target type {t} failed",
exc_info=True,
)
return None

def __eq__(
Expand Down Expand Up @@ -546,7 +557,10 @@ def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]:
casted_format = DeviceType._from(d)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"Conversion from {d} to torch_tensorrt.DeviceType failed",
exc_info=True,
)
return None

def to(
Expand Down Expand Up @@ -595,7 +609,10 @@ def try_to(
casted_format = self.to(t, use_default=use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"torch_tensorrt.DeviceType conversion to target type {t} failed",
exc_info=True,
)
return None

def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool:
Expand Down Expand Up @@ -653,7 +670,10 @@ def try_from(
casted_format = EngineCapability._from(c)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"Conversion from {c} to torch_tensorrt.EngineCapablity failed",
exc_info=True,
)
return None

def to(
Expand Down Expand Up @@ -696,7 +716,10 @@ def try_to(
casted_format = self.to(t)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(e)
logging.debug(
f"torch_tensorrt.EngineCapablity conversion to target type {t} failed",
exc_info=True,
)
return None

def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def create_constant(
A TensorRT ITensor that represents the given value.
"""
numpy_value = to_numpy(
value, _enums.dtype._from(dtype).to(np.dtype) if dtype else None
value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
)
constant = ctx.net.add_constant(
(1,) if isinstance(value, (int, float, bool)) else value.shape,
Expand Down

0 comments on commit c272b78

Please sign in to comment.