
fp8 quantization with FSDP2 error #1929

Open
happynear opened this issue Mar 20, 2025 · 2 comments

Comments


happynear commented Mar 20, 2025

When I fp8 quantize a model and then shard it using FSDP2, it reports an error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 461, in <module>
[rank1]:     generate(args)
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 375, in generate
[rank1]:     wan_i2v = wan.WanI2V(
[rank1]:               ^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/image2video.py", line 218, in __init__
[rank1]:     self.model = shard_dit_fn(self.model, param_dtype=torch.float8_e4m3fn)
[rank1]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/distributed/fsdp.py", line 112, in shard_dit_model
[rank1]:     fully_shard_with_ignore_param(block, mesh=pm.get_dp_with_cp_mesh(), reshard_after_forward=True, mp_policy=mixed_fsdp2, ignored_params=ignored_states_set)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]:     updated = func(inp_module, *args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/pytorch_distributed_examples/src/tu_pth_dist/fsdp_compat.py", line 200, in fully_shard_with_ignore_param
[rank1]:     state._fsdp_param_group = FSDPParamGroup(
[rank1]:                               ^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 132, in __init__
[rank1]:     FSDPParam(
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 239, in __init__
[rank1]:     self._init_sharded_param(param, device, shard_placement_fn)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 368, in _init_sharded_param
[rank1]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank1]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank1]:     raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}

I can see that there is no aten.split implementation in https://github.com/pytorch/ao/blob/ab3792e3d91e04f85992a659c1664a6a1a6d733c/torchao/quantization/linear_activation_quantized_tensor.py . Could anyone provide one?

@happynear
Author

I tried to implement the split function myself:

@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    new_values = func(args[0].original_weight_tensor, *args[1:], **kwargs)

    def make_new_tensor(value):
        out = LinearActivationQuantizedTensor(
            value,
            args[0].input_quant_func,
            args[0].quant_kwargs,
        )
        return return_and_correct_aliasing(func, args, kwargs, out)

    return list(map(make_new_tensor, new_values))

but another error is reported:

[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank0]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 451, in _dispatch__torch_dispatch__
[rank0]:     return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 412, in wrapper
[rank0]:     return func(f, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 220, in _
[rank0]:     out_tensor = func(tensor.original_weight_tensor, *args[1:], **kwargs)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'int'>), kwarg_types={}

I'm not sure why the dispatch moves from LinearActivationQuantizedTensor to AffineQuantizedTensor. When I look into dtypes/affine_quantized_tensor.py, I can find no split implementation there either.
Any suggestions?
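For what it's worth, the behavior above is consistent with nested tensor subclasses: the outer LinearActivationQuantizedTensor appears to wrap an inner quantized tensor as original_weight_tensor, so delegating split to that inner tensor means the inner type (here AffineQuantizedTensor) also needs its own split support. A minimal toy sketch of this delegation pattern (plain Python stand-ins, not torchao code; all class names below are hypothetical):

```python
# Toy illustration of why implementing split only on the outer wrapper
# is not enough: the outer op delegates to the wrapped inner object,
# so the inner type must implement split as well.

class InnerQuantized:
    """Stand-in for the inner quantized tensor (e.g. AffineQuantizedTensor)."""
    def __init__(self, data):
        self.data = data

    def split(self, size):
        # If the inner type lacked this method, the outer delegation
        # below would fail -- mirroring the second NotImplementedError.
        return [InnerQuantized(self.data[i:i + size])
                for i in range(0, len(self.data), size)]


class OuterWrapper:
    """Stand-in for the outer wrapper (e.g. LinearActivationQuantizedTensor)."""
    def __init__(self, inner):
        self.inner = inner

    def split(self, size):
        # Same approach as the snippet above: delegate to the wrapped
        # tensor, then rewrap each resulting chunk in the outer type.
        return [OuterWrapper(chunk) for chunk in self.inner.split(size)]


w = OuterWrapper(InnerQuantized(list(range(8))))
chunks = w.split(4)
print(len(chunks), chunks[0].inner.data)  # 2 chunks of 4 elements each
```

Under this reading, fixing only the outer class moves the failure one level down, which matches the error sequence reported here.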

@vkuzo
Contributor

vkuzo commented Mar 20, 2025

cc @jerryzh168
