-
Notifications
You must be signed in to change notification settings - Fork 360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
slice_scatter decomposition #2519
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2023-12-06 09:08:13.895012+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2023-12-06 09:11:58.776404+00:00
@@ -186,21 +186,22 @@
src_dim = list(src_tensor.shape())
src_dim[dim] = torch.floor_divide(end - start, step)
src = torch.expand(src, src_dim)
- if (start == 0 and end == dim_size and step == 0):
+ if start == 0 and end == dim_size and step == 0:
return input_tensor
mask = []
if start != 0:
mask.append(torch.ge(input_tensor_shape, start))
if end != dim_size:
mask.append(torch.ge(input_tensor_shape, end))
if step != 1:
mask.append(torch.eq(src_dim, 0))
src_val = torch.masked(mask, src_dim, 0)
- return torch.where(mask, src_val,input_tensor)
+ return torch.where(mask, src_val, input_tensor)
+
def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
if enable_experimental_decompositions:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2023-12-06 09:08:13.915012+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2023-12-06 09:12:02.062349+00:00
@@ -418,11 +418,10 @@
0,
DECIMALS_OF_AGREEMENT,
f"MaxPool3d TRT outputs don't match with the original model.",
)
-
def test_lowering_select_scatter_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -435,11 +434,10 @@
torch.ops.aten.lt.default,
torch.ops.aten.lt.default,
torch.ops.aten.expand.default,
torch.ops.aten.eq.default,
torch.ops.aten.where.default,
-
}
unexpected_ops = {torch.ops.aten.select_scatter}
inputs = [torch.randn(2, 2), torch.ones(2)]
@@ -485,7 +483,8 @@
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)
+
if __name__ == "__main__":
run_tests()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2023-12-19 18:39:51.699972+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2023-12-19 18:41:49.917712+00:00
@@ -418,11 +418,10 @@
0,
DECIMALS_OF_AGREEMENT,
f"MaxPool3d TRT outputs don't match with the original model.",
)
-
def test_lowering_select_scatter_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -435,11 +434,10 @@
torch.ops.aten.lt.default,
torch.ops.aten.lt.default,
torch.ops.aten.expand.default,
torch.ops.aten.eq.default,
torch.ops.aten.where.default,
-
}
unexpected_ops = {torch.ops.aten.select_scatter}
inputs = [torch.randn(2, 2), torch.ones(2)]
@@ -485,7 +483,8 @@
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)
+
if __name__ == "__main__":
run_tests()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
if start is not None and start < 0: | ||
start = start + dim_size | ||
if end is not None and end < 0: | ||
end = end + dim_size | ||
if start is None: | ||
start = 0 | ||
if end is None: | ||
end = dim_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider switching to use get_positive_dim
utility.
|
||
if start == 0 and end == dim_size and step == 0: | ||
return input_tensor | ||
index_tensor = np.arange(start, end_dim, step) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work with torch.arange
?
end = dim_size | ||
|
||
src_dim = src_tensor.shape | ||
step_dim = torch.floor_divide(end - start, step) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(end - start) // step
if step_dim > src_dim[dim]: | ||
end_dim = src_dim[dim] | ||
else: | ||
indices = torch.Tensor(np.arange(0, step_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.arange
unbind_source_tensors = torch.unbind(src, dim) | ||
unbind_source_tensors_list = list(unbind_source_tensors) | ||
|
||
for i, index in enumerate(index_tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
range(start, end_dim, step)
instead of index tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See other comment
afeba1e
to
a0b031f
Compare
|
||
if start == 0 and end == dim_size and step == 0: | ||
return input_tensor | ||
index_tensor = torch.arange(start, end_dim, step) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this tensor needed; could it be replaced with range
, as below?
unbind_source_tensors = torch.unbind(src, dim) | ||
unbind_source_tensors_list = list(unbind_source_tensors) | ||
|
||
for i, index in enumerate(index_tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See other comment
if step_dim > src_dim[dim]: | ||
end_dim = src_dim[dim] | ||
else: | ||
indices = torch.Tensor(torch.arange(0, step_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.arange
should already return a Tensor, so the cast should not be needed
indices = indices.to(torch.int32) | ||
src = torch.index_select(src, dim, indices) | ||
|
||
if start == 0 and end == dim_size and step == 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be step == 1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be step == 0
since step == 1
would result in tensors being inserted in source tensor at step 1 interval.
end_dim = src_dim[dim] | ||
else: | ||
indices = torch.arange(0, step_dim) | ||
indices = indices.to(torch.int32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the indices are int64
, it is fine to leave them as-is and not change the data type, since later operators may expect or require int64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't this be required for the subsequent torch.index_select
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does torch
expect int64
for the indices in index_select
, or TensorRT? If it is TensorRT, then there is no need to perform the cast, because the outputs of the above operation will already have been handled in the TRTInterpreter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes torch would expect int64 input for the indices. Since this is a constant, I think TRTInterpretor should be able to handle it. Yes I will remove this.
cec6a4e
to
8fb696e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-20 19:59:59.374321+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-20 20:01:49.660284+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-20 20:01:49.759276+00:00
@@ -30,16 +30,18 @@
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""
- device_type: Optional[
- trt.DeviceType
- ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ device_type: Optional[trt.DeviceType] = (
+ None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ )
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
- allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ allow_gpu_fallback: bool = (
+ False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ )
def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-20 20:01:49.959821+00:00
@@ -26,16 +26,16 @@
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
- shape_mode: Optional[
- _ShapeMode
- ] = None #: Is input statically or dynamically shaped
- shape: Optional[
- Tuple[int, ...] | Dict[str, Tuple[int, ...]]
- ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ shape_mode: Optional[_ShapeMode] = (
+ None #: Is input statically or dynamically shaped
+ )
+ shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+ None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ )
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-20 20:01:50.013227+00:00
@@ -212,13 +212,13 @@
"precision": precision,
"debug": debug,
"device": device,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
- "torch_executed_ops": torch_executed_ops
- if torch_executed_ops is not None
- else set(),
+ "torch_executed_ops": (
+ torch_executed_ops if torch_executed_ops is not None else set()
+ ),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-20 20:01:50.235895+00:00
@@ -26,13 +26,13 @@
from packaging import version
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class UnsupportedOperatorException(RuntimeError):
pass
@@ -90,13 +90,13 @@
self.input_specs_iter = 0
self._cur_node_name: Optional[str] = None
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
self.compilation_settings = compilation_settings
# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-20 20:01:50.278485+00:00
@@ -322,17 +322,15 @@
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
- ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
- ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-20 20:01:50.623768+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket
aten = torch.ops.aten
-_core_aten_decompositions: Dict[
- OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+ core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
@@ -179,13 +179,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
- OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+ get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-20 20:01:50.628829+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after lowering linear:\n{gm.graph}")
return gm
-def linear_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def linear_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for linear"""
# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-20 20:01:50.665412+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
return gm
-def view_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
- ]
-):
+def view_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
"""Constructs the original and replacement functions for view"""
# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-20 20:01:50.681914+00:00
@@ -58,16 +58,14 @@
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
return gm
-def scaled_dot_product_attention_replacement() -> (
- Tuple[
- Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+ Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for efficient attention"""
# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-20 20:01:50.959439+00:00
@@ -99,25 +99,29 @@
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.hidden_output_binding_indices_in_order
]
def _check_initialized(self) -> None:
if not self.initialized:
@@ -165,13 +169,15 @@
self.__dict__.update(state)
if self.engine:
self.context = self.engine.create_execution_context()
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:Forward"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+ if self.profiling_enabled
+ else nullcontext()
+ ):
self._check_initialized()
# If in safe mode, check at each iteration for for whether a switch is required
if (
torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
torch.cuda.set_device(device_id)
inputs = tuple([tensor.to(device) for tensor in inputs])
logger.warning(f"Moved all input Tensors to cuda:{device_id}")
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:ProcessInputs"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:ProcessInputs"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
assert len(inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@
self.context.set_binding_shape(
idx, tuple(contiguous_inputs[i].shape)
)
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:ProcessOutputs"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:ProcessOutputs"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
# create output tensors
outputs: List[torch.Tensor] = []
for i, idx in enumerate(self.output_binding_indices_in_order):
shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
dtype=self.hidden_output_dtypes[i],
device=torch.cuda.current_device(),
)
bindings[idx] = output.data_ptr()
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:TensorRTRuntime"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:TensorRTRuntime"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-20 20:01:51.233651+00:00
@@ -315,25 +315,21 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
- "stride": args[2]
- if len(args) > 2
- else (None, None)
- if len(args[1]) == 2
- else (None, None, None),
- "padding": args[3]
- if len(args) > 3
- else (0, 0)
- if len(args[1]) == 2
- else (0, 0, 0),
- "dilation": args[4]
- if len(args) > 4
- else (1, 1)
- if len(args[1]) == 2
- else (1, 1, 1),
+ "stride": (
+ args[2]
+ if len(args) > 2
+ else (None, None) if len(args[1]) == 2 else (None, None, None)
+ ),
+ "padding": (
+ args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+ ),
+ "dilation": (
+ args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+ ),
"ceil_mode": args[5] if len(args) > 5 else False,
}
return acc_ops_converters.acc_ops_max_poolnd(
network, target, None, kwargs_new, name
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-20 20:01:51.283354+00:00
@@ -124,25 +124,29 @@
interpreter = TRTInterpreter(
mod,
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
- logger_level=trt.Logger.VERBOSE
- if self.lower_setting.verbose_log
- else trt.Logger.WARNING,
+ logger_level=(
+ trt.Logger.VERBOSE
+ if self.lower_setting.verbose_log
+ else trt.Logger.WARNING
+ ),
)
interp_result: TRTInterpreterResult = interpreter.run(
max_batch_size=self.lower_setting.max_batch_size,
max_workspace_size=self.lower_setting.max_workspace_size,
lower_precision=self.lower_setting.lower_precision,
strict_type_constraints=self.lower_setting.strict_type_constraints,
algorithm_selector=algo_selector,
timing_cache=cache_data,
- profiling_verbosity=trt.ProfilingVerbosity.DETAILED
- if self.lower_setting.verbose_profile
- else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+ profiling_verbosity=(
+ trt.ProfilingVerbosity.DETAILED
+ if self.lower_setting.verbose_profile
+ else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+ ),
tactic_sources=self.lower_setting.tactic_sources,
)
# Update timing cache file if needed
timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
module.half()
# A custom conversion function can be passed to the lowerer to
# handle inputs with custom types. By default, just handle
# tensors and NoneType.
if fp16_conversion_fn is None:
- conversion_fn = (
- lambda x: x.half()
- if x is not None and x.dtype == torch.float32
- else x
+ conversion_fn = lambda x: (
+ x.half() if x is not None and x.dtype == torch.float32 else x
)
else:
conversion_fn = fp16_conversion_fn
inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-20 20:01:51.328023+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -73,13 +73,13 @@
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-20 20:01:51.545029+00:00
@@ -194,13 +194,15 @@
lowering_start_time = datetime.datetime.now()
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ (
+ additional_submodule_inputs[submod_name]
+ if additional_submodule_inputs
+ else None
+ ),
)
lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
)
setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"ACC submodule graph: {submod.graph}")
lowering_start_time = datetime.datetime.now()
self.lower_setting.additional_inputs = (
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ (
+ additional_submodule_inputs[submod_name]
+ if additional_submodule_inputs
+ else None
+ ),
)
lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-20 20:01:51.722875+00:00
@@ -193,13 +193,11 @@
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
- kwargs2[
- "msg"
- ] = (
+ kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-20 20:01:51.782883+00:00
@@ -536,13 +536,13 @@
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
- reshape_batch_size_inferred_source: Optional[
- fx.Node
- ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+ reshape_batch_size_inferred_source: Optional[fx.Node] = (
+ get_reshape_batch_size_inferred_source(reshape_batch_size)
+ )
if not reshape_batch_size_inferred_source:
continue
reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-20 19:59:59.394321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-20 20:01:52.206806+00:00
@@ -21,13 +21,15 @@
inputs = [torch.randn(1, 10)]
self.run_test(
Split(),
inputs,
expected_ops={
- acc_ops.split
- if isinstance(split_size_or_sections, int)
- else acc_ops.slice_tensor
+ (
+ acc_ops.split
+ if isinstance(split_size_or_sections, int)
+ else acc_ops.slice_tensor
+ )
},
test_explicit_batch_dim=False,
)
@parameterized.expand(
@@ -68,13 +70,15 @@
]
self.run_test_with_dynamic_shape(
Split(),
input_specs,
expected_ops={
- acc_ops.split
- if isinstance(split_size_or_sections, int)
- else acc_ops.slice_tensor
+ (
+ acc_ops.split
+ if isinstance(split_size_or_sections, int)
+ else acc_ops.slice_tensor
+ )
},
)
# Testing with (-1, -1, -1) results into following error:
# AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-20 19:59:59.394321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-20 20:01:52.903172+00:00
@@ -152,13 +152,13 @@
mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
interpreter_result = interpreter.run(
- lower_precision=LowerPrecision.FP16
- if fp16_mode
- else LowerPrecision.FP32
+ lower_precision=(
+ LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+ )
)
trt_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-20 19:59:59.398321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-20 20:01:53.269384+00:00
@@ -67,25 +67,29 @@
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes: Sequence[torch.dtype] = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.hidden_output_binding_indices_in_order
]
def _check_initialized(self):
if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-20 19:59:59.398321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-20 20:01:53.546949+00:00
@@ -404,13 +404,13 @@
"inputs": inputs if inputs is not None else [],
# "input_signature": input_signature,
"device": device,
"disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
- "enabled_precisions": enabled_precisions
- if enabled_precisions is not None
- else set(), # Enabling FP16 kernels
+ "enabled_precisions": (
+ enabled_precisions if enabled_precisions is not None else set()
+ ), # Enabling FP16 kernels
"refit": refit, # enable refit
"debug": debug, # enable debuggable engine
"capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels
"num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels
"workspace_size": workspace_size, # Maximum size of workspace given to TensorRT
Monitoring the CI to see if this error comes in the test-
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2024-02-27 08:54:58.869787+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2024-02-27 08:56:47.352375+00:00
@@ -187,11 +187,11 @@
step_dim = (end - start) // step
end_dim = end
if step_dim > src_dim[dim]:
end_dim = src_dim[dim]
else:
- #In this case src first step_dim need to be selected
+ # In this case src first step_dim need to be selected
indices = torch.Tensor(torch.arange(0, step_dim))
indices = indices.to(torch.int32)
src = torch.index_select(src_tensor, dim, indices)
if start == 0 and end == dim_size and step == 0:
13bbdab
to
f7e0642
Compare
df7d401
to
1bd061b
Compare
dim_size = input_tensor.shape[dim] | ||
start = get_positive_dim(start, input_tensor.shape[dim]) | ||
if end is None: | ||
end = dim_size | ||
end = get_positive_dim(end, input_tensor.shape[dim]) | ||
if step is None: | ||
step = 1 | ||
|
||
src_dim = src_tensor.shape | ||
# step == 0 is not a valid torch case | ||
# also src_dim should be equal to slice dimension | ||
|
||
if start == 0 and end == dim_size and step == 1: | ||
return src_tensor | ||
|
||
cat_tensors = [] | ||
index_tensor_shape = [] | ||
for i, src_each_dim in enumerate(list(src_dim)): | ||
if i != dim: | ||
index_tensor_shape.append(src_each_dim) | ||
for index in range(start, end, step): | ||
cat_tensors.append(index * torch.ones(index_tensor_shape)) | ||
index_tensor = torch.stack(cat_tensors, dim) | ||
index_tensor = index_tensor.to(torch.int64).cuda() | ||
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src) | ||
return output_tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this potentially be simplified to avoid for
-loops using torch.arange
? For instance, see this implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @gs-olive I tried the above implementation.
I am not sure how get_expanded_index
works, but I think it will be difficult to achieve the above behavior without for loops.
I tried two alternate
indices = torch.arange(start,stop, step)
cat_tensors = torch.unsqueeze(indices,1) * torch.ones(index_tensor_shape)).split(1, dim = 0)
#or
cat_tensors = indices(:, None) * torch.ones(index_tensor_shape)).split(1, dim = 0)
The thing is we need to unsqueeze indices n no of times, where n is the dimension of index_tensor_shape. While the above would work for cases
input = torch.ones(8,8)
src = torch.ones(8,2)
out = torch.slice_scatter(input, src, 1, 6, 8, 1)
or
input = torch.ones(8,8)
src = torch.ones(8,1)
out = torch.slice_scatter(input, src, 1, 6, 7, 1)
it would start failing for input and src with sizes torch.ones(8,8,8)
and torch.zeros(8,2,8)
or torch.zeros(8,1,8)
respectively. We would have to unsqueeze n no of times, eg: torch.unsqueeze(indices,1,1) or indices[:,None,None] would work, but then that would again be a for loop.
I cannot think of another way on top of my mind, if you have any suggestion you could let me know.
For now the test cases pass with for loop so I have reverted back to that,
8c37797
to
498ff5e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good - added a few comments/questions
for index in range(start, end, step): | ||
cat_tensors.append(index * torch.ones(index_tensor_shape)) | ||
index_tensor = torch.stack(cat_tensors, dim) | ||
index_tensor = index_tensor.to(torch.int64).cuda() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will cause a graph break if it inserts a cast in the graph representation, since TRT cannot support Int64 casts. What is the resultant output graph in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This operation might be avoidable by specifying dtype=torch.long
in the torch.ones(...)
call, though if the index tensor is a constant and not an ITensor
, it may not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gs-olive the torch.long
was present since otherwise torch would have complained that torch requires int64
input for torch.scatter in this line.
The case torch.slice_scatter(torch.zeros(8,8), torch.ones(8,2), 1, 6, None, 1)
leads to this with the cast to index_tensor = index_tensor.to(torch.int64).cuda()
-
Pre-AOT Autograd graph:=============
graph():
%l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
%l_src_ : torch.Tensor [num_users=1] = placeholder[target=L_src_]
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_src_,), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
%slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter](args = (%clone_default, %clone_default_1, 1, 6,
None, 1), kwargs = {})
return (slice_scatter,)
Post AOT Autograd graph:=============
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
%clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
%empty_strided : [num_users=1] = call_function[target=torch.ops.aten.empty_strided.default](args = ([8], [1]), kwargs = {dtype: to$
ch.int64, layout: torch.strided, device: cpu, pin_memory: False})
%full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%empty_strided, 1), kwargs = {pin_memo$
y: False})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%full_like, 6), kwargs = {})
%empty_strided_1 : [num_users=1] = call_function[target=torch.ops.aten.empty_strided.default](args = ([8], [1]), kwargs = {dtype: t
orch.int64, layout: torch.strided, device: cpu, pin_memory: False})
%full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%empty_strided_1, 1), kwargs = {pin_m
emory: False})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%full_like_1, 7), kwargs = {})
%unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul, 1), kwargs = {})
%unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul_1, 1), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze, %unsqueeze_1], 1), kwargs = {})
%_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%cat,), kwargs = {dtype: torch.int64, lay
out: torch.strided, device: cuda:0})
%scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.src](args = (%clone_1, 1, %_to_copy, %clone), kwargs = {})
return (scatter,)
Post lowering Autograd graph:=============
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.src](args = (%arg0_1, 1, %_frozen_param0, %arg1_1), kwargs =
{})
return (scatter,)
As mentioned by you since it is a frozen param and a constant, there are no graph breaks and not necessary. Not sure if this would be the case always though.
Hence I changed it to torch.ones()
with dtype torch.long
as suggested.
A side general question- Would the graph break lead to significant performance impact? That is the reason we should avoid them as far as possible?
} | ||
unexpected_ops = {torch.ops.aten.select_scatter} | ||
|
||
inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this case be modified to be 3D, as in your comment above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kept the old test case and added another with the 3D.
9bc7d6c
to
2b101dd
Compare
changing decomposition pattern slice scatter changes Review comments address Removing arange and replacing with range slice_scatter adding to decomposition group using aten::scatter in aten.slice_scatter Correcting the slice_scatter case with aten::scatter use removing unnecessary cases from slice_scatter impl and adding test case changing for loop to torch.arange Reverting back the torch.arange to for loop Adding test case for 3d cases and removing the casting to torch.int64 and including it torch.ones Removing aten.index in the decomposition ops
Fixes #2434
This PR would be dependant on #2664 and #2669. Major changes
aten::scatter.src
get_attr
call due to which different device locationmeta
andcpu
in torch