NCCL ops correction changes #3387

Open · wants to merge 4 commits into base: main
17 changes: 13 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
```diff
@@ -59,17 +59,26 @@ def aot_torch_tensorrt_aten_backend(
         _pretraced_backend, settings=settings, engine_cache=engine_cache
     )
     settings_aot_autograd = {}
-    settings_aot_autograd["decompostions"] = get_decompositions(
+    settings_aot_autograd["decompositions"] = get_decompositions(
         settings.enable_experimental_decompositions
     )
     # This is added since detach lowering leads to alias nodes
     # Error - View operation returned a tensor that is the same as the input base tensor
     # torch nop_decompositions in torch/_decomp/decompositions.py
     if aten.detach in settings_aot_autograd["decompositions"]:
         del settings_aot_autograd["decompositions"][aten.detach]
+    # transpose key deleted since not desirable to lower it to permute
+    to_delete = {
```
Inline review comment on the `to_delete` block:

**Collaborator:** Will this apply to all cases, not just NCCL?

**Collaborator (author):** You mean in the non-distributed example? I am not sure about that. I added this for the llama3 example because I was seeing issues in model lowering: graph breaks were being generated at the wrong points, leading to a complex-input error. It can be applied to all cases if we decide we do not want to lower transpose to permute.

```diff
+        key
+        for key in settings_aot_autograd["decompositions"]
+        if "transpose" in key._name
+    }
+
+    for key in to_delete:
+        del settings_aot_autograd["decompositions"][key]
+
     remove_detach(gm, settings)
     return aot_autograd(
         fw_compiler=_pretraced_backend_autograd,
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
+        decompositions=settings_aot_autograd["decompositions"],
     )(gm, sample_inputs)
```


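For context, here is a minimal, self-contained sketch of the decomposition-filtering pattern the hunk above introduces. It uses plain PyTorch's `core_aten_decompositions` as a stand-in for torch_tensorrt's `get_decompositions`, so the exact table contents are illustrative:

```python
import torch
from torch._decomp import core_aten_decompositions

aten = torch.ops.aten

# Stand-in for get_decompositions(...): a dict mapping OpOverload -> callable.
decompositions = core_aten_decompositions()

# Drop detach: its decomposition produces alias/view nodes that later fail with
# "View operation returned a tensor that is the same as the input base tensor".
decompositions.pop(aten.detach.default, None)

# Drop every transpose decomposition so transpose is not lowered to permute.
# Keys are OpOverload objects, so match on their string names.
to_delete = {key for key in decompositions if "transpose" in str(key)}
for key in to_delete:
    del decompositions[key]

print(f"Removed {len(to_delete)} transpose decompositions")
```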
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
```diff
@@ -3,6 +3,7 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union

+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )

-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)

 if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_gather(
+        return impl.nccl_ops.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_reduce_scatter(
+        return impl.nccl_ops.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             [args[0]],
         )

-        breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"
```
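For context, the `if load_tensorrt_llm(): ... else: _LOGGER.debug(...)` layout above is a conditional-registration pattern: the NCCL converters are only defined when the optional TensorRT-LLM plugins load. A minimal sketch of the same pattern follows; `plugins_available` and the `CONVERTERS` dict are hypothetical stand-ins, not the torch_tensorrt API:

```python
import importlib.util
import logging
from typing import Any, Callable, Dict

_LOGGER: logging.Logger = logging.getLogger(__name__)

# Hypothetical registry; torch_tensorrt registers converters via decorators.
CONVERTERS: Dict[str, Callable[..., Any]] = {}

def plugins_available() -> bool:
    """Stand-in for load_tensorrt_llm(): probe the optional dependency."""
    return importlib.util.find_spec("tensorrt") is not None

if plugins_available():

    def fused_nccl_gather_converter(*args: Any, **kwargs: Any) -> Any:
        # In the real file this delegates to impl.nccl_ops.nccl_gather.
        raise NotImplementedError

    CONVERTERS["tensorrt::fused_nccl_gather"] = fused_nccl_gather_converter
else:
    _LOGGER.debug(
        "Did not load torch.distributed converters since the plugins are not available"
    )
```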
```diff
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:

     if op_target in shape_inference_funcs:
         new_shape = shape_inference_funcs[op_target](node)
-        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
+        new_node_dtype = None
+        if node.meta["val"].dtype == torch.complex64:
+            new_node_dtype = torch.float32
+        else:
+            new_node_dtype = torch.float64
+        real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
         node.meta["val"] = fake_mode.from_tensor(real_tensor)
     else:
         print("No shape for the inference function", {op_name})
```
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
```diff
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
-            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@ def fuse_distributed_ops(
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                fused_node = gm.graph.create_node(
-                    op="call_function",
-                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                )
+                with gm.graph.inserting_after(wait_tensor_node):
+                    fused_node = gm.graph.create_node(
+                        op="call_function",
+                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                    )

             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
```
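The fix above moves the reduce-scatter branch's `create_node` call inside `gm.graph.inserting_after(wait_tensor_node)`, matching the all-gather branch, so the fused node is inserted right after the node it replaces. A minimal sketch of this FX insertion-point pattern (the traced function and the `torch.relu` insert are illustrative, not from the PR):

```python
import operator

import torch
from torch.fx import symbolic_trace

def f(x: torch.Tensor) -> torch.Tensor:
    y = x + 1
    return y * 2

gm = symbolic_trace(f)
anchor = next(n for n in gm.graph.nodes if n.target == operator.add)

# Nodes created inside inserting_after(anchor) are placed directly after
# `anchor`, keeping the graph topologically ordered; created outside the
# context manager, they land at the graph's default insertion point.
with gm.graph.inserting_after(anchor):
    relu_node = gm.graph.create_node(
        op="call_function", target=torch.relu, args=(anchor,)
    )

anchor.replace_all_uses_with(relu_node)
# replace_all_uses_with also rewrote relu's own input; point it back at anchor.
relu_node.args = (anchor,)
gm.recompile()

print(gm(torch.tensor([-2.0, 3.0])))  # tensor([0., 8.])
```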
```diff
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
+
+        for i, contiguous_input in enumerate(contiguous_inputs):
+            if contiguous_input.dtype == torch.complex64:
+                contiguous_input_real = contiguous_input.real
+                contiguous_input_imag = contiguous_input.imag
+                contiguous_inputs[i] = torch.stack(
+                    (contiguous_input_real, contiguous_input_imag), dim=-1
+                )
+
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
```
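For context, the stacking above reproduces PyTorch's native complex-to-real layout, so a complex64 input reaches the engine as float32 data with a trailing dimension of size 2. A quick equivalence check:

```python
import torch

x = torch.randn(4, 8, dtype=torch.complex64)

# Stacking real and imaginary parts on a new last dimension matches
# torch.view_as_real's (..., 2) float32 layout.
stacked = torch.stack((x.real, x.imag), dim=-1)
assert torch.equal(stacked, torch.view_as_real(x))
assert stacked.dtype == torch.float32 and stacked.shape == (4, 8, 2)
```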