Skip to content

Commit

Permalink
Support accessing the module reference for the process group (#96)
Browse files (browse the repository at this point in the history)
  • Loading branch information
carmocca authored Apr 4, 2024
1 parent 6babe4e commit c915335
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
16 changes: 15 additions & 1 deletion thunder/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,23 @@ def skip_data_parallel_grad_sync() -> Generator[Any, Any, Any]:
def _sync_grads(module: torch.nn.Module) -> None:
import thunder

if hasattr(module, "process_group_for_ddp"):
# This branch is required when a function that takes the model as an input is jitted instead
# of the model itself. In that case, the user won't have a reference to a `ThunderModule`, so this needs to use
# the reference that `ddp` and `fsdp` set directly on the module.
process_group = module.process_group_for_ddp
elif (cd := thunder.compile_data(module)) is not None:
# The ordinary jitted module branch
process_group = cd.process_group_for_ddp
else:
raise RuntimeError(
f"Expected `{type(module).__name__}` to have been jitted or to contain a `process_group_for_ddp` attribute"
)

params_with_grad = [p for p in module.parameters() if p.grad is not None]
if not params_with_grad:
return
grads = [p.grad for p in params_with_grad]
- process_group = thunder.compile_data(module).process_group_for_ddp
torch._foreach_div_(grads, process_group.size())
with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
for g in grads:
Expand Down
4 changes: 4 additions & 0 deletions thunder/tests/distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,10 @@ def fwd_loss(m, x):
fwd_loss = thunder.jit(fwd_loss)
fwd_loss(model, x)

# Note that we cannot call `model.no_sync()` here because `model` is not a `ThunderModule`
with thunder.ThunderModule.no_sync(model):
fwd_loss(model, x)


common_utils.instantiate_parametrized_tests(CompileDDPTest)

Expand Down

0 comments on commit c915335

Please sign in to comment.