
[Utils] align_module_device #3204

Merged
merged 11 commits on Nov 1, 2024
6 changes: 5 additions & 1 deletion docs/source/package_reference/big_modeling.md
@@ -95,4 +95,8 @@ rendered properly in your Markdown viewer.

### has_offloaded_params

[[autodoc]] utils.has_offloaded_params
[[autodoc]] utils.has_offloaded_params

### align_module_device

[[autodoc]] utils.align_module_device
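For readers coming from the docs page, here is a minimal usage sketch of the new context manager (the toy `nn.Linear` model and the devices below are illustrative assumptions, not part of this PR):

```python
import torch
from torch import nn

from accelerate import cpu_offload
from accelerate.utils import align_module_device

# Toy module for illustration; cpu_offload replaces its parameters with meta tensors
model = nn.Linear(4, 4)
cpu_offload(model, execution_device=torch.device("cpu"))
assert model.weight.device == torch.device("meta")

# Inside the context, parameters are onloaded to the hook's execution device
with align_module_device(model):
    assert model.weight.device == torch.device("cpu")

# On exit, parameters are offloaded back to the meta device
assert model.weight.device == torch.device("meta")
```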
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -124,6 +124,7 @@
    is_xpu_available,
)
from .modeling import (
    align_module_device,
    calculate_maximum_sizes,
    check_device_map,
    check_tied_parameters_in_config,
81 changes: 50 additions & 31 deletions src/accelerate/utils/modeling.py
@@ -1531,22 +1531,12 @@ def get_state_dict_offloaded_model(model: nn.Module):
    for name, module in model.named_modules():
        if name == "":
            continue
        if has_offloaded_params(module):
            original_device = module._hf_hook.execution_device
            # assign hook execution device to cpu
            module._hf_hook.execution_device = "cpu"
            # onload meta tensors to execution device
            try:
                module._hf_hook.pre_forward(module)
            except MemoryError:
                raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
            module_state_dict = module.state_dict()
            # offload meta tensors from cpu
            module._hf_hook.post_forward(module, torch.tensor([]))
            # re-assign hook to original execution device
            module._hf_hook.execution_device = original_device
        else:
            module_state_dict = module.state_dict()

        try:
            with align_module_device(module, "cpu"):
                module_state_dict = module.state_dict()
        except MemoryError:
            raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None

        for key in module_state_dict:
            # ignore placeholder parameters that are still on the meta device
@@ -1586,22 +1576,12 @@ def get_state_dict_from_offload(
"""

root = module_name[: module_name.rfind(".")] # module name without .weight or .bias
preforward = False
if has_offloaded_params(module):
# assign the device to which the offloaded parameters will be sent
original_device = module._hf_hook.execution_device
module._hf_hook.execution_device = device_to_put_offload
module._hf_hook.pre_forward(module)
preforward = True

for m_key in module.state_dict():
params = module.state_dict()[m_key]
if (root + f".{m_key}") in state_dict:
state_dict[root + f".{m_key}"] = params

if preforward:
module._hf_hook.post_forward(module, torch.tensor([]))
module._hf_hook.execution_device = original_device
# assign the device to which the offloaded parameters will be sent
with align_module_device(module, device_to_put_offload):
for m_key, params in module.state_dict().items():
if (root + f".{m_key}") in state_dict:
state_dict[root + f".{m_key}"] = params

return state_dict

@@ -1919,3 +1899,42 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
    from ..hooks import AlignDevicesHook  # avoid circular import

    return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload


@contextlib.contextmanager
def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
    """
    Context manager that moves a module's parameters to the specified execution device.

    Args:
        module (`torch.nn.Module`):
            Module with parameters to align.
        execution_device (`torch.device`, *optional*):
            If provided, overrides the module's execution device within the context. Otherwise, the hook's execution
            device is used, or the parameters are left on their current device if the module has no hook.
    """
    if has_offloaded_params(module):
        if execution_device is not None:
            original_device = module._hf_hook.execution_device
            module._hf_hook.execution_device = execution_device

        try:
            module._hf_hook.pre_forward(module)
            yield
        finally:
            module._hf_hook.post_forward(module, None)
            if execution_device is not None:
                module._hf_hook.execution_device = original_device

    elif execution_device is not None:
        devices = {name: param.device for name, param in module.named_parameters()}
        try:
            for name in devices:
                set_module_tensor_to_device(module, name, execution_device)
            yield
        finally:
            for name, device in devices.items():
                set_module_tensor_to_device(module, name, device)

    else:
        yield
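The `elif` branch above covers plain modules with no offloading hook: each parameter is moved with `set_module_tensor_to_device`, and the recorded per-parameter devices are restored in the `finally` block even if the body raises. A small sketch of that behavior (the `nn.Linear` and the CUDA guard are illustrative assumptions):

```python
import torch
from torch import nn

from accelerate.utils import align_module_device

module = nn.Linear(8, 8)  # plain module, no hook; parameters start on CPU

if torch.cuda.is_available():  # guard so the sketch runs anywhere
    with align_module_device(module, execution_device=torch.device("cuda:0")):
        assert module.weight.device == torch.device("cuda:0")
    # the finally block restored the original per-parameter devices
    assert module.weight.device == torch.device("cpu")
```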
49 changes: 49 additions & 0 deletions tests/test_modeling_utils.py
@@ -26,6 +26,7 @@
from safetensors.torch import save_file

from accelerate import init_empty_weights
from accelerate.big_modeling import cpu_offload
from accelerate.test_utils import (
    require_cuda,
    require_huggingface_suite,
@@ -34,6 +35,7 @@
    torch_device,
)
from accelerate.utils.modeling import (
    align_module_device,
    check_device_map,
    clean_device_map,
    compute_module_sizes,
@@ -785,3 +787,50 @@ def test_convert_file_size(self):

        with self.assertRaises(ValueError):
            convert_file_size_to_int("-1GB")

    def test_align_module_device_simple(self):
        model = ModelForTest()
        execution_device = torch.device(torch_device)
        model_device = torch.device("cpu")

        # test default execution device
        with align_module_device(model.batchnorm):
            assert model.linear1.weight.device == model_device
            assert model.batchnorm.weight.device == model_device
            assert model.linear2.weight.device == model_device
        assert model.linear1.weight.device == model_device
        assert model.batchnorm.weight.device == model_device
        assert model.linear2.weight.device == model_device

        # test with explicit execution device
        with align_module_device(model.batchnorm, execution_device=execution_device):
            assert model.linear1.weight.device == model_device
            assert model.batchnorm.weight.device == execution_device
            assert model.linear2.weight.device == model_device
        assert model.linear1.weight.device == model_device
        assert model.batchnorm.weight.device == model_device
        assert model.linear2.weight.device == model_device

    def test_align_module_device_offloaded(self):
        model = ModelForTest()
        execution_device = torch.device(torch_device)
        offload_device = torch.device("meta")
        cpu_offload(model, execution_device=execution_device)

        # test default execution device
        with align_module_device(model.batchnorm):
            assert model.linear1.weight.device == offload_device
            assert model.batchnorm.weight.device == execution_device
            assert model.linear2.weight.device == offload_device
        assert model.linear1.weight.device == offload_device
        assert model.batchnorm.weight.device == offload_device
        assert model.linear2.weight.device == offload_device

        # test with explicit execution device
        with align_module_device(model.batchnorm, execution_device="cpu"):
            assert model.linear1.weight.device == offload_device
            assert model.batchnorm.weight.device == torch.device("cpu")
            assert model.linear2.weight.device == offload_device
        assert model.linear1.weight.device == offload_device
        assert model.batchnorm.weight.device == offload_device
        assert model.linear2.weight.device == offload_device