diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index d5213fcd02..8be612b61c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2610,7 +2610,6 @@ def aten_ops_pixel_unshuffle(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.resize.default)
 @dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
 @enforce_tensor_types(
     {
@@ -2624,7 +2623,7 @@ def aten_ops_resize(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.shuffle.resize_(
+    return impl.shuffle.resize(
         ctx,
         target,
         SourceIR.ATEN,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 726adc2573..bb0271b3fc 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -132,7 +132,7 @@ def pixel_unshuffle(
     )


-def resize_(
+def resize(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
@@ -140,9 +140,7 @@ def resize_(
     input: TRTTensor,
     sizes: Sequence[int],
 ) -> TRTTensor:
-    input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
-
     input_val = get_trt_tensor(ctx, input, f"{name}_input")

     # Calculate the total number of elements for new and current shape
@@ -158,31 +156,34 @@ def resize_(

         # Flatten input tensor to 1D for concatenation
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )

         # Concatenate the flattened input tensor and padding tensor
-        concat_layer = ctx.net.add_concatenation([flattened_input, padding_tensor])
-        concat_layer.axis = 0
-        reshaped_tensor = concat_layer.get_output(0)
-
+        reshaped_tensor = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_cat",
+            [flattened_input, padding_tensor],
+            dim=0,
+        )
     elif new_num_elements < current_num_elements:
         # Flatten input tensor to 1D for slicing
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )

         # Slice the flattened input tensor to the desired number of elements
         slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
         reshaped_tensor = slice_layer.get_output(0)
-
     else:
         reshaped_tensor = input_val

     # Reshape the final output tensor to the target sizes
-    resized_output = impl.shuffle.reshape(
+    resized_output = reshape(
         ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
     )

diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py
index 25f1f31fac..12e6cb66a1 100644
--- a/tests/py/dynamo/conversion/test_resize_aten.py
+++ b/tests/py/dynamo/conversion/test_resize_aten.py
@@ -20,9 +20,6 @@ class TestResizeConverter(DispatchTestCase):
     )
     def test_resize_1d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -46,9 +43,6 @@ def forward(self, x):
     )
     def test_resize_1d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -73,9 +67,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -100,9 +91,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)
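Note: the converter modified above pads with zeros when the requested size is larger than the input, truncates when it is smaller, and then reshapes to the target sizes. A rough NumPy sketch of that semantics, for reference only; the function name resize_reference and the use of NumPy are illustrative and not part of this patch:

import numpy as np

def resize_reference(x: np.ndarray, sizes) -> np.ndarray:
    # Illustrative model of the converter's behavior; not part of the PR.
    new_n = int(np.prod(sizes))
    flat = x.reshape(-1)
    if new_n > flat.size:
        # Grow: zero-pad the flattened tensor (mirrors the impl.cat.cat path).
        flat = np.concatenate([flat, np.zeros(new_n - flat.size, dtype=x.dtype)])
    elif new_n < flat.size:
        # Shrink: keep the first new_n elements (mirrors the add_slice path).
        flat = flat[:new_n]
    # Reshape to the requested sizes (mirrors the final reshape).
    return flat.reshape(sizes)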