
Commit

update checkpoint handling and testing
t-vi committed Jan 14, 2025
1 parent fa8b80c commit a5ea9b3
Showing 2 changed files with 12 additions and 18 deletions.
thunder/distributed/checkpoint.py (16 changes: 3 additions & 13 deletions)
@@ -167,15 +167,9 @@ def load_model_state_dict(state_dict: dict[str, Any], module: Module, options: StateDictOptions, rank: int) -> None:
         module.load_state_dict(state_dict, strict=options.strict)
 
     elif options.full_state_dict:
-        if not hasattr(module, "process_group_for_ddp"):
-            raise RuntimeError(f"Expected {module} to be FSDP transformed")
-        process_group = module.process_group_for_ddp
-        device = next(module.parameters()).device
-        _unshard_params(module, process_group, options.cpu_offload)
-        if not options.rank0_only or rank == 0:
-            module.load_state_dict(state_dict, strict=options.strict)
-        # with rank0_only enabled, it's useful to broadcast so that the other shards are still loaded as expected
-        _shard_params(module, process_group, device, 0 if options.rank0_only else None)
+        # TODO: broadcast rank0 to others if options.rank0_only?
+        module.load_original_state_dict(state_dict)
     else:
         state_dict = tree_map(lambda t: DTensor.to_local(t) if isinstance(t, DTensor) else t, state_dict)
         module.load_state_dict(state_dict, strict=options.strict)
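
The full-state-dict branch now delegates to the jitted module's own load_original_state_dict, which handles resharding internally, rather than manually unsharding, loading, and resharding parameters. A minimal usage sketch of the new path, assuming StateDictOptions and load_model_state_dict are importable from thunder.distributed.checkpoint and using a hypothetical toy module in place of the test's MyModel:

    import torch
    import torch.nn as nn
    import thunder
    import thunder.distributed
    from thunder.distributed.checkpoint import StateDictOptions, load_model_state_dict  # assumed import path

    class ToyModel(nn.Module):  # hypothetical stand-in for the test's MyModel
        def __init__(self, n: int) -> None:
            super().__init__()
            self.l = nn.Linear(n, n)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.l(x)

    model = ToyModel(4).to("cuda")
    sharded_model = thunder.distributed.fsdp(thunder.jit(model))  # jit first, then fsdp, as in the updated tests

    full_sd = torch.load("full_checkpoint.pt")  # an unsharded checkpoint; the path is illustrative
    options = StateDictOptions(full_state_dict=True)
    # With full_state_dict=True, this now reduces to
    # sharded_model.load_original_state_dict(full_sd).
    load_model_state_dict(full_sd, sharded_model, options, 0)  # last argument is the rank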
@@ -209,10 +203,6 @@ def load(module_state: dict[str, Any], path: Path, **kwargs: Any) -> None:
 
 def _split_state_dict(module: Module) -> tuple[dict[str, Any], dict[str, Any]]:
     """A flavor of ``module.state_dict()`` that returns parameters separated to everything else."""
-    params = {
-        param_name: param.detach()
-        for module_name, submodule in module.named_modules()
-        for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name)
-    }
+    params = {param_name: param.detach() for param_name, param in module.named_parameters()}
     rest = {k: v for k, v in module.state_dict().items() if k not in params}
     return params, rest
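
_split_state_dict now collects parameters in a single module.named_parameters() pass; for standard nn.Modules this yields the same fully qualified names as the removed nested loop over named_modules(). A small self-contained check of that equivalence, using an illustrative toy model:

    import torch.nn as nn

    # Toy model with both parameters and buffers (BatchNorm running stats).
    module = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

    # The removed nested-loop comprehension.
    old = {
        param_name: param.detach()
        for module_name, submodule in module.named_modules()
        for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name)
    }
    # The new single-pass comprehension.
    new = {param_name: param.detach() for param_name, param in module.named_parameters()}
    assert old.keys() == new.keys()  # e.g. "0.weight", "0.bias", "1.weight", "1.bias"

    # Buffers stay in "rest", exactly as before.
    rest = {k: v for k, v in module.state_dict().items() if k not in new}
    assert "1.running_mean" in rest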
thunder/tests/distributed/test_checkpoint.py (14 changes: 9 additions & 5 deletions)
@@ -137,7 +137,7 @@ def test_get_model_state_dict(self):
         with pytest.raises(ValueError, match="cannot be used"):
             get_model_state_dict(model, options, self.rank)
 
-        sharded_model = thunder.distributed.fsdp(model)
+        sharded_model = thunder.distributed.fsdp(thunder.jit(model))
         assert has_fsdp_modules(model)
 
         # Sharding - full state dict
@@ -207,23 +207,27 @@ def test_load_model_state_dict(self):
 
         # Sharding - full state dict
         for kwargs in ({"cpu_offload": True}, {"cpu_offload": False}, {"rank0_only": True}):
+            unsharded_model = MyModel(4).to(device=device)
+            unsharded_model.load_state_dict(state_dict)
             model = MyModel(4).to(device=device)
-            sharded_model = thunder.distributed.fsdp(model)
+            sharded_model = thunder.distributed.fsdp(thunder.jit(model))
             options = StateDictOptions(full_state_dict=True, **kwargs)
+            print("before", sharded_model.get_parameter("l.weight"))
             load_model_state_dict(state_dict, sharded_model, options, self.rank)
+            print("after", sharded_model.get_parameter("l.weight"))
             _unshard_params(sharded_model, pg)
-            torch.testing.assert_close(model.state_dict(), state_dict)
+            torch.testing.assert_close(unsharded_model.state_dict(), state_dict)
 
         # Create a sharded state_dict that can be loaded
         model = MyModel(4).to(device=device)
-        sharded_model_expected = thunder.distributed.fsdp(model)
+        sharded_model_expected = thunder.distributed.fsdp(thunder.jit(model))
         options = StateDictOptions(full_state_dict=False)
         sharded_state_dict = get_model_state_dict(sharded_model_expected, options, self.rank)
 
         # Sharding - sharded state dict
         for kwargs in ({"cpu_offload": True}, {"cpu_offload": False}, {"rank0_only": True}):
             model = MyModel(4).to(device=device)
-            sharded_model = thunder.distributed.fsdp(model)
+            sharded_model = thunder.distributed.fsdp(thunder.jit(model))
             options = StateDictOptions(full_state_dict=False, **kwargs)
             load_model_state_dict(sharded_state_dict, sharded_model, options, self.rank)
             torch.testing.assert_close(sharded_model.state_dict(), sharded_model_expected.state_dict())
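
For reference, the sharded round trip that the last loop exercises looks roughly like the sketch below; the import path and the distributed setup are assumptions, and the model is a toy stand-in for the test's MyModel:

    import torch.distributed as dist
    import thunder
    import thunder.distributed
    from thunder.distributed.checkpoint import (  # assumed import path
        StateDictOptions,
        get_model_state_dict,
        load_model_state_dict,
    )

    rank = dist.get_rank()  # assumes a process group is already initialized
    src = thunder.distributed.fsdp(thunder.jit(MyModel(4).to("cuda")))
    dst = thunder.distributed.fsdp(thunder.jit(MyModel(4).to("cuda")))

    options = StateDictOptions(full_state_dict=False)      # keep per-rank shards
    sharded_sd = get_model_state_dict(src, options, rank)  # each rank extracts its shard
    load_model_state_dict(sharded_sd, dst, options, rank)  # and loads it shard-for-shard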
