diff --git a/docker/Dockerfile b/docker/Dockerfile index 60b213b110..16b92bbd17 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -47,7 +47,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/ RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" RUN apt-get update -RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.* +RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-plugin-dev=${TENSORRT_VERSION}.* # Setup Bazel via Bazelisk RUN wget -q https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 -O /usr/bin/bazel &&\ diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0dd153d0aa..9acb750aed 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1136,6 +1136,23 @@ def aten_ops_exp( ) +@dynamo_tensorrt_converter(torch.ops.aten.expm1.default) +def aten_ops_expm1( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.expm1( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.log.default) def aten_ops_log( ctx: ConversionContext, @@ -1391,6 +1408,30 @@ def aten_ops_atanh( ) +@dynamo_tensorrt_converter(torch.ops.aten.atan2.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (TRTTensor,), + } +) +def aten_ops_atan2( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.atan2( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.ceil.default) def aten_ops_ceil( ctx: ConversionContext, @@ -1493,6 +1534,23 @@ def aten_ops_isinf( ) +@dynamo_tensorrt_converter(torch.ops.aten.isnan.default) +def aten_ops_isnan( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.isnan( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) def aten_ops_add( @@ -2185,7 +2243,12 @@ def aten_ops_avg_pool( @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default) -def aten_ops_adaptive_avg_pool( +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_adaptive_avg_pool1d( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -2202,6 +2265,32 @@ def aten_ops_adaptive_avg_pool( ) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default) 
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default) +@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_adaptive_avg_poolNd( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pool.adaptive_avg_poolNd( + ctx, + target, + source_ir=SourceIR.ATEN, + name=name, + input=args[0], + output_size=args[1], + ) + + def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) @@ -2319,6 +2408,29 @@ def aten_ops_pixel_shuffle( ) +@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pixel_unshuffle( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.shuffle.pixel_unshuffle( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @enforce_tensor_types({0: (TRTTensor,)}) @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) def aten_ops_argmax( @@ -2782,3 +2894,28 @@ def aten_ops_roll( args[1], args_bounds_check(args, 2, []), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.index_select.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 2: (TRTTensor,), + } +) +def aten_ops_index_select( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.index_select( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index a66a082d30..fabb037ed0 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -9,13 +9,15 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_int_int_div_trt_tensor, cast_int_or_float_to_bool, + cast_trt_tensor, get_trt_tensor, ) from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.dynamo.conversion.impl.unary import sign +from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor import tensorrt as trt @@ -214,6 +216,180 @@ def remainder( return fmod2_value +def atan2( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform atan2 operation on Tensor, calculating the arctangent of the quotient of input tensors. + atan2(x,y) = atan(x/y) if y > 0, + = atan(x/y) + π if x ≥ 0 and y < 0, + = atan(x/y) - π if x < 0 and y < 0, + = π/2 if x > 0 and y = 0, + = -π/2 if x < 0 and y = 0, + = 0 if x = 0 and y = 0 + + Args: + ctx: ConversionContext. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: Tensor or constant representing the dividend. 
+ other: Tensor or constant representing the divisor. + + Returns: + A TensorRT tensor representing the result of the atan2 operation. + """ + pi_value = 3.141592653589793 + pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi") + + if isinstance(input, TRTTensor): + input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input") + if isinstance(other, TRTTensor): + other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other") + + input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other") + + # Calculate x_zero, y_zero (whether inputs are zero) + x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0) + y_zero = eq(ctx, target, source_ir, f"{name}_y_zero", other, 0) + + # Get sign of inputs + x_positive = gt(ctx, target, source_ir, f"{name}_x_positive", input, 0) + x_zero_positive = ge(ctx, target, source_ir, f"{name}_x_zero_positive", input, 0) + x_negative = lt(ctx, target, source_ir, f"{name}_x_negative", input, 0) + y_positive = gt(ctx, target, source_ir, f"{name}_y_positive", other, 0) + y_negative = lt(ctx, target, source_ir, f"{name}_y_negative", other, 0) + + # Calculate atan(x/y) + input_div_other = div( + ctx, target, source_ir, f"{name}_input_div_other", input, other + ) + atan_val = atan(ctx, target, source_ir, f"{name}_atan", input_div_other) + + # atan(x/y)+π if x≥0 and y<0, + atan_add_pi = add( + ctx, target, source_ir, f"{name}_atan_add_pi", atan_val, pi_tensor + ) + + # atan(x/y)-π if x<0 and y<0, + atan_sub_pi = sub( + ctx, target, source_ir, f"{name}_atan_sub_pi", atan_val, pi_tensor + ) + + # atan(x/y)+π if x≥0 and y<0, + atan_corrected = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_corrected", + atan_add_pi, + atan_val, + logical_and( + ctx, + target, + source_ir, + f"{name}_x_zero_positive_and_y_negative", + x_zero_positive, + y_negative, + ), + ) + + # atan(x/y)-π if x<0 and y<0, + atan_corrected_2 = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_corrected_2", + atan_sub_pi, + atan_corrected, + logical_and( + ctx, + target, + source_ir, + f"{name}_x_negative_and_y_negative", + x_negative, + y_negative, + ), + ) + + # atan(x/y) if y>0 + atan_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_output", + atan_val, + atan_corrected_2, + y_positive, + ) + + # on x or y-axis + pi_over_2_tensor = get_trt_tensor( + ctx, + (pi_value / 2) * np.ones(input.shape, dtype=np.float32), + f"{name}_pi_over_2_tensor", + dtype=trt.float32, + ) + minus_pi_over_2_tensor = get_trt_tensor( + ctx, + (-pi_value / 2) * np.ones(input.shape, dtype=np.float32), + f"{name}_minus_pi_over_2_tensor", + dtype=trt.float32, + ) + zero_tensor = get_trt_tensor( + ctx, + np.zeros(input.shape, dtype=np.float32), + f"{name}_zero_tensor", + dtype=trt.float32, + ) + + # π/2 if x>0 and y=0, + pi_over_2_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_pi_over_2_output", + pi_over_2_tensor, + atan_output, + logical_and( + ctx, target, source_ir, f"{name}_x_zero_and_y_positive", x_positive, y_zero + ), + ) + + # -π/2 if x<0 and y=0, + minus_pi_over_2_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_minus_pi_over_2_output", + minus_pi_over_2_tensor, + pi_over_2_output, + logical_and( + ctx, target, source_ir, f"{name}_x_zero_and_y_negative", x_negative, y_zero + ), + ) + + # 0 if x=0 and y=0, + zero_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_zero_output", + zero_tensor, + minus_pi_over_2_output, + logical_and( + ctx, target, 
source_ir, f"{name}_x_zero_and_y_zero", y_zero, x_zero + ), + ) + + return zero_output + + def clamp( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 8c16f59030..c21ccc1c59 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -6,7 +6,10 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple +from torch_tensorrt.dynamo.conversion.converter_utils import ( + extend_attr_to_tuple, + get_positive_dim, +) from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, @@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int: output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1) return output + + +def adaptive_avg_poolNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + output_size: Sequence[int], +) -> TRTTensor: + input_shape = input.shape + input_rank = len(input_shape) + output_rank = len(output_size) + need_reshape_back = False + + if input_rank == output_rank + 1: # reshape to 4D/5D for TRT pooling + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape) + ) + need_reshape_back = True + input_shape = input.shape + input_rank = len(input_shape) + + extend_len = len(output_size) + output_size = list(output_size) + original_input = input + + # repeat_interleave the input if the dim of output is larger than input + insert_axises = [] + for axis in range(1, extend_len + 1): + axis = -axis + positive_axis = get_positive_dim( + axis, input_rank + ) # convert to positive axis, which is for calculating new shapes below + input_dim = input_shape[axis] + output_dim = output_size[axis] + diff = output_dim - input_dim + if diff > 0: # the dim of output is larger than input + times = output_dim // input_dim + remainder = output_dim % input_dim + if ( + diff == 2 and remainder == 2 + ): # case 1: output_dim - input_dim == 2 and is not an integral multiple + insert_axises.append(axis) + remainder -= 1 + output_size[axis] -= 1 + + if ( + remainder + 1 == input_dim + ): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input + remainder = 0 + times += 1 + + flags = [] # record the axis that needs to be repeated + concat_list = [] + for j in range( + input_dim + ): # iterate the input dim to see which dim needs to be repeated or not + single_elem = impl.select.select( + ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j + ) + new_shape = list(single_elem.shape) + new_shape.insert(positive_axis, 1) + single_elem = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_{axis}_{j}", + single_elem, + new_shape, + ) + if remainder > 0 or j in flags: + concat_list.extend([single_elem] * (times + 1)) + remainder -= 2 + flags.append(input_dim - j - 1) + else: + concat_list.extend([single_elem] * times) + out = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis + ) + input = out + + stride = tuple( + input.shape[-extend_len + i] // output_size[i] for i in range(extend_len) + ) + kernel_size = tuple( + input.shape[-extend_len + i] - (output_size[i] - 1) * 
stride[i] + for i in range(extend_len) + ) + + # Don't have to pool, directly return + if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size): + if need_reshape_back: # reshape back + input = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_back", + input, + (*input.shape[1:],), + ) + return input + + layer = ctx.net.add_pooling_nd( + input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size + ) + layer.stride_nd = stride + set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir) + + output = layer.get_output(0) + + # For case 1, we need to split the output and insert the mid of input + for axis in insert_axises: + positive_axis = get_positive_dim(axis, input_rank) + input_dim = input_shape[axis] + output_dim = output_size[axis] + if input_dim % 2 == 1: + prev_one = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_prev_one_{axis}", + output, + axis, + output_dim // 2 - 1, + ) + extend_shape = list(prev_one.shape) + extend_shape.insert(positive_axis, 1) + prev_one = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_extend_shape_{axis}", + prev_one, + extend_shape, + ) + prev_two = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_prev_two_{axis}", + output, + axis, + output_dim // 2 - 2, + ) + prev_two = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_two_shape_reshape_{axis}", + prev_two, + extend_shape, + ) + prev_one_two_diff = impl.elementwise.sub( + ctx, + target, + source_ir, + f"{name}_prev_one_two_diff_{axis}", + prev_one, + prev_two, + ) + + mid = impl.elementwise.add( + ctx, + target, + source_ir, + f"{name}_mid_{axis}", + prev_one, + prev_one_two_diff, + ) + split_output = impl.split.split( + ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis + ) + split_output.insert(1, mid) + output = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis + ) + else: + mid1 = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_{axis}", + original_input, + axis, + input_dim // 2 - 1, + ) + new_shape = list(mid1.shape) + new_shape.insert(positive_axis, 1) + mid1 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape + ) + mid2 = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_{axis}", + original_input, + axis, + input_dim // 2, + ) + mid2 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape + ) + split_output = impl.split.split( + ctx, + target, + source_ir, + f"{name}_split_{axis}", + output, + [output_dim // 2, 1, output_dim // 2], + axis, + ) + split_output[1] = mid1 + split_output.insert(2, mid2) + output = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis + ) + + if need_reshape_back: # reshape back + output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],) + ) + + return output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 0a9efd8485..a4507ece3e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -373,3 +373,21 @@ def index( reshape_output = reshape_layer.get_output(0) return reshape_output + + +def index_select( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, +) -> TRTTensor: + # The axis 
parameter specifies the dimension along which to index. + dim = get_positive_dim(dim, len(input.shape)) + gather_layer = ctx.net.add_gather(input, index, axis=dim) + + set_layer_name(gather_layer, target, f"{name}_gather", source_ir) + + return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 49ddb76e2c..1d6dd7396f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -60,3 +60,47 @@ def pixel_shuffle( permuted_tensor, shape[:-3] + (out_channels, out_height, out_width), ) + + +def pixel_unshuffle( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + downscale_factor: int, +) -> TRTTensor: + shape = input.shape + in_channels, in_height, in_width = shape[-3:] + out_channels = in_channels * (downscale_factor**2) + out_height = in_height // downscale_factor + out_width = in_width // downscale_factor + new_shape = shape[:-3] + ( + in_channels, + out_height, + downscale_factor, + out_width, + downscale_factor, + ) + reshaped_tensor = reshape( + ctx, target, source_ir, f"{name}_reshape1", input, new_shape + ) + rank = len(new_shape) + permute_shape = tuple(range(rank - 5)) + ( + rank - 5, # in_channels + rank - 3, # downscale_factor + rank - 1, # downscale_factor + rank - 4, # out_height + rank - 2, # out_width + ) + permuted_tensor = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape + ) + return reshape( + ctx, + target, + source_ir, + f"{name}_reshape2", + permuted_tensor, + shape[:-3] + (out_channels, out_height, out_width), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 554640ea5a..9f2ad07612 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -44,6 +44,32 @@ def exp( ) +def expm1( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Computes e^x - 1 for each element of the input tensor. + + Args: + ctx (ConversionContext): TensorRT ConversionContext object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + TRTTensor: A TensorRT tensor represent the result of expm1 operator. + """ + # Compute e^x for each element of the input tensor + exp_result = exp(ctx, target, source_ir, f"{name}_exp", input_val) + + return impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1) + + def log( ctx: ConversionContext, target: Target, @@ -508,3 +534,23 @@ def scalar_tensor( identity_layer = ctx.net.add_identity(tensor) set_layer_name(identity_layer, target, name, source_ir) return identity_layer.get_output(0) + + +def isnan( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # False for NaN elements since NaN is not equal to anything, including itself. + equality_result = impl.elementwise.eq( + ctx, target, source_ir, f"{name}_eq_nan", input, input + ) + + # Invert equality_result to get a mask where NaN values are marked as True. 
+ nan_values_mask = logical_not( + ctx, target, source_ir, f"{name}_logical_not", equality_result + ) + + return nan_values_mask diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 802976cb27..3be9b7a4c2 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -9,6 +9,8 @@ from torch_tensorrt._Input import Input from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device +from torch_tensorrt import _enums + def compile( module: torch.jit.ScriptModule, @@ -137,6 +139,9 @@ def compile( "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels "num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels "workspace_size": workspace_size, # Maximum size of workspace given to TensorRT + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, "calibrator": calibrator, "truncate_long_and_double": truncate_long_and_double, "torch_fallback": { diff --git a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py index 3d48409631..b8dc1e1968 100644 --- a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py +++ b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py @@ -76,10 +76,225 @@ def forward(self, x): self.run_test( TestModule(), inputs, - # use_dynamo_tracer=True, enable_passes=True, ) + @parameterized.expand( + [ + # 3d input + ( + (1, 2, 3), + (1, 2), + ), + ( + (1, 2, 3), + (2, 3), + ), + ( + (1, 2, 8), + (4, 4), + ), + ( + (2, 3, 2), + (5, 3), + ), + ( + (2, 8, 16), + (4, 8), + ), + ( + (2, 8, 16), + (8, 8), + ), + # 4d input + ( + (1, 1, 4, 3), + (4, 8), + ), + ( + (3, 2, 3, 2), + (1, 5), + ), + ( + (4, 2, 2, 8), + (5, 2), + ), + ( + (3, 2, 3, 3), + (6, 4), + ), + ( + (1, 2, 3, 2), + (2, 2), + ), + ( + (2, 2, 32, 16), + (8, 8), + ), + ( + (2, 2, 32, 32), + (31, 16), + ), + ( + (1, 1, 64, 64), + (64, 16), + ), + ] + ) + def test_adaptive_avg_pool2d( + self, + input_shape, + output_size, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size) + + inputs = [torch.randn(input_shape)] + self.run_test( + TestModule(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ((1, 2),), + ] + ) + def test_adaptive_avg_pool2d_dynamic(self, output_size): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size) + return out + + input_specs = [ + Input( + shape=(-1, 2, 3, 2), + dtype=torch.float32, + shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + + @parameterized.expand( + [ + # 4d input + ( + (1, 1, 4, 3), + (4, 8, 2), + ), + ( + (1, 2, 3, 1), + (1, 5, 2), + ), + ( + (1, 2, 3, 2), + (1, 5, 3), + ), + ( + (4, 2, 2, 8), + (8, 5, 2), + ), + ( + (3, 2, 3, 3), + (6, 4, 1), + ), + ( + (1, 2, 3, 2), + (2, 2, 2), + ), + ( + (2, 2, 32, 16), + (8, 8, 8), + ), + ( + (2, 2, 32, 32), + (31, 16, 64), + ), + ( + (1, 1, 64, 64), + (64, 16, 1), + ), + # 5d input + ( + (1, 1, 1, 4, 3), + (4, 8, 2), + ), + ( + (4, 3, 1, 2, 3), + (2, 4, 6), + ), + ( + (1, 4, 2, 2, 2), + (5, 2, 4), + ), + ( + (3, 2, 3, 3, 2), + (6, 4, 1), + ), + ( + (2, 2, 32, 16, 8), + (8, 8, 8), + ), + ( + (2, 2, 32, 32, 32), + (31, 16, 64), + ), + ( + 
(1, 1, 64, 64, 64), + (64, 16, 1), + ), + ] + ) + def test_adaptive_avgpool3d( + self, + input_shape, + output_size, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size) + + inputs = [torch.randn(input_shape)] + self.run_test( + TestModule(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ((1, 2, 3),), + ] + ) + def test_adaptive_avg_pool3d_dynamic(self, output_size): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size) + return out + + input_specs = [ + Input( + shape=(-1, 2, 3, 1, 4), + dtype=torch.float32, + shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py new file mode 100644 index 0000000000..550ade2970 --- /dev/null +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestAtan2Converter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_lhs_const(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.rand(1), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_rhs_const(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.rand(1), + torch.randn(input_shape, dtype=dtype), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_float(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.randn(input_shape, dtype=dtype), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((50,), torch.int, -5, 5), + ((1, 20), torch.int32, -5, 5), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_atan2_int(self, input_shape, dtype, low, high): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randint(low, high, input_shape, dtype=dtype), + torch.randint(low, high, input_shape, dtype=dtype), + ] + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + (torch.float, 0.0, 0.0), + (torch.float, 0.0, torch.rand(1)), + (torch.float, torch.rand(1), 0.0), + (torch.int, 0, 0), + (torch.int, 0, torch.randint(-5, 5, (1,))), + (torch.int, torch.randint(1, 10, (1,)), 0), + ] + ) + def test_atan2_zero(self, dtype, x_val, y_val): + class 
Atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + if isinstance(x_val, torch.Tensor): + x_val = x_val.item() + if isinstance(y_val, torch.Tensor): + y_val = y_val.item() + + inputs = [ + torch.tensor([x_val], dtype=dtype), + torch.tensor([y_val], dtype=dtype), + ] + + self.run_test( + Atan2(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_expm1_aten.py b/tests/py/dynamo/conversion/test_expm1_aten.py new file mode 100644 index 0000000000..e695a27475 --- /dev/null +++ b/tests/py/dynamo/conversion/test_expm1_aten.py @@ -0,0 +1,69 @@ +from math import exp + +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestExpConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_expm1_float(self, input_shape, dtype): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randn(input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + (torch.full((1, 20), exp(1), dtype=torch.float),), + (torch.full((2, 3, 4), exp(2), dtype=torch.float),), + (torch.full((2, 3, 4, 5), exp(3), dtype=torch.float),), + ] + ) + def test_expm1_exp_const_float(self, data): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [data] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int32, -10, 10), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_exp_int(self, input_shape, dtype, low, high): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py new file mode 100644 index 0000000000..83eaedb944 --- /dev/null +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestIndexSelectConverter(DispatchTestCase): + @parameterized.expand( + [ + ("1d_input", (10,), 0, (1,)), + ("2d_input_dim_0", (10, 3), 0, (0, 2)), + ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), + ("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)), + ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), + ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), + ("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)), + ("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)), + ] + ) + def test_index_select(self, _, source_shape, dim, indices_val): + class TestIndexSelect(torch.nn.Module): + def forward(self, source_tensor, indices_tensor): + return torch.ops.aten.index_select.default( + source_tensor, dim, indices_tensor + ) + + input = [ + torch.randn(*source_shape, dtype=torch.float32), + torch.tensor([*indices_val], dtype=torch.int32), + ] + + self.run_test( + TestIndexSelect(), + input, + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py new file mode 100644 index 0000000000..5651b0ca25 --- /dev/null +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestIsNanConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + torch.tensor( + [ + 1.23, + float("nan"), + -4.56, + float("inf"), + float("-inf"), + -100.0, + float("nan"), + 0.13, + -0.13, + 3.14159265, + ] + ), + ), + ] + ) + def test_isnan_float(self, data): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [data] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + (torch.full((2, 2), float("nan"), dtype=torch.float32),), + (torch.full((3, 10, 5), float("nan"), dtype=torch.float32),), + (torch.randn((5, 10, 5), dtype=torch.float32),), + ] + ) + def test_isnan_dim(self, data): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [data] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int32, -10, 10), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_isnan_int(self, input_shape, dtype, low, high): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py new file mode 100644 index 0000000000..fb93e68499 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py @@ -0,0 +1,29 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestPixelUnshuffleConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 1, 1), 1), + ((1, 1, 12, 12), 3), + ((2, 3, 4, 25, 30), 5), + ] + ) + def test_pixel_unshuffle(self, shape, downscale_factor): + class PixelUnshuffle(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor) + + inputs = [torch.randn(shape)] + self.run_test( + PixelUnshuffle(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()
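

Editor's note (illustration only, not part of the patch): a minimal sketch of exercising one of the newly registered converters end to end through the dynamo frontend. It assumes the public torch_tensorrt.compile API with ir="dynamo" and a CUDA device; the module name and tensor shapes below are hypothetical and chosen only to hit the new aten.atan2.default converter.

import torch
import torch_tensorrt


class Atan2Module(torch.nn.Module):
    def forward(self, x, y):
        # Traces to torch.ops.aten.atan2.default, which this diff registers
        # via aten_ops_atan2 in the dynamo converter registry.
        return torch.atan2(x, y)


model = Atan2Module().eval().cuda()
inputs = [torch.randn(2, 3, 4).cuda(), torch.randn(2, 3, 4).cuda()]

# ir="dynamo" routes compilation through the aten converter registry
# extended in this change; other new ops (expm1, isnan, pixel_unshuffle,
# index_select, adaptive_avg_pool2d/3d) can be exercised the same way.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))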