--fp8_base breaks --save_state. Requires updated safetensors package. #1078

Closed
feffy380 opened this issue Jan 25, 2024 · 0 comments · Fixed by #1079

Comments

@feffy380
Contributor

feffy380 commented Jan 25, 2024

When using `--fp8_base` together with `--save_state` in either of the train_network scripts (SD and SDXL), saving the state crashes. The traceback shows safetensors failing on a dict lookup of the dtype's byte size: `_SIZE[tensor.dtype]` raises `KeyError` for `torch.float8_e4m3fn`.
Updating safetensors to the latest version (0.4.2 as of writing) fixed the issue.
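Because the fix was a package upgrade, a training script could check the installed safetensors version before combining the two flags, instead of crashing only when the state is saved at the end of training. This is a minimal sketch: 0.4.2 is used as the threshold only because it is the version reported to work here (the first release that handles float8 is not stated in this issue), and `MIN_SAFETENSORS`, `parse_release`, `version_at_least`, and `check_safetensors_for_fp8_state` are hypothetical names, not sd-scripts code.

```python
# Hypothetical pre-flight check; not part of sd-scripts, accelerate, or safetensors.
import re
from importlib.metadata import PackageNotFoundError, version

# Assumed threshold: 0.4.2 is the version the report says works. An earlier
# safetensors release may already handle float8; that is not verified here.
MIN_SAFETENSORS = (0, 4, 2)


def parse_release(ver: str) -> tuple[int, ...]:
    """Extract the leading numeric release, e.g. '0.3.1rc1' -> (0, 3, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(installed: str, minimum: tuple[int, ...]) -> bool:
    """Compare release tuples numerically, zero-padding the shorter one."""
    current = parse_release(installed)
    width = max(len(current), len(minimum))
    current += (0,) * (width - len(current))
    minimum += (0,) * (width - len(minimum))
    return current >= minimum


def check_safetensors_for_fp8_state() -> None:
    """Raise early instead of failing inside accelerator.save_state."""
    try:
        installed = version("safetensors")
    except PackageNotFoundError:
        raise RuntimeError("safetensors is not installed") from None
    if not version_at_least(installed, MIN_SAFETENSORS):
        wanted = ".".join(map(str, MIN_SAFETENSORS))
        raise RuntimeError(
            f"safetensors {installed} may not serialize float8 tensors; "
            f"upgrade to >= {wanted} before combining --fp8_base with --save_state"
        )
```

It could be called once during argument validation, e.g. `if args.fp8_base and args.save_state: check_safetensors_for_fp8_state()` (the placement is hypothetical). Comparing integer tuples rather than strings matters: as strings, "0.10.0" sorts before "0.4.2". Pre-release suffixes such as "rc1" are ignored, so this is a coarse check.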

```
saving last state.
Traceback (most recent call last):
  File "/home/hope/src/sd/sd-scripts/train_network.py", line 1062, in <module>
    trainer.train(args)
  File "/home/hope/src/sd/sd-scripts/train_network.py", line 964, in train
    train_util.save_state_on_train_end(args, accelerator)
  File "/home/hope/src/sd/sd-scripts/library/train_util.py", line 4573, in save_state_on_train_end
    accelerator.save_state(state_dir)
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/accelerate/accelerator.py", line 2708, in save_state
    save_location = save_accelerator_state(
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/accelerate/checkpointing.py", line 99, in save_accelerator_state
    save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/accelerate/utils/other.py", line 181, in save
    save_func(obj, f)
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/safetensors/torch.py", line 232, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
                   ^^^^^^^^^^^^^^^^^
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/safetensors/torch.py", line 402, in _flatten
    return {
           ^
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/safetensors/torch.py", line 406, in <dictcomp>
    "data": _tobytes(v, k),
            ^^^^^^^^^^^^^^
  File "/home/hope/src/sd/sd-scripts/venv/lib/python3.11/site-packages/safetensors/torch.py", line 362, in _tobytes
    bytes_per_item = _SIZE[tensor.dtype]
                     ~~~~~^^^^^^^^^^^^^^
KeyError: torch.float8_e4m3fn
```
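The failing line is `bytes_per_item = _SIZE[tensor.dtype]`: the installed safetensors has no byte-size entry for `torch.float8_e4m3fn`, so the lookup raises `KeyError` before anything is written. A minimal sketch of a standalone probe for this, assuming a PyTorch build that defines `torch.float8_e4m3fn` and safetensors' in-memory `safetensors.torch.save` function (which serializes to bytes through the same code path as `save_file`); `fp8_safetensors_supported` is a hypothetical helper, not part of sd-scripts, accelerate, or safetensors.

```python
# Hypothetical probe; not part of sd-scripts, accelerate, or safetensors.
from typing import Optional


def fp8_safetensors_supported() -> Optional[bool]:
    """True if safetensors can serialize a float8_e4m3fn tensor, False if it
    raises the KeyError seen in the traceback, None if the probe cannot run."""
    try:
        import torch
        from safetensors.torch import save  # in-memory counterpart of save_file
    except ImportError:
        return None  # torch or safetensors is not installed

    fp8 = getattr(torch, "float8_e4m3fn", None)
    if fp8 is None:
        return None  # this PyTorch build has no float8 dtypes

    # Build in float32 and cast, so the probe relies only on dtype conversion.
    probe = torch.zeros(4, dtype=torch.float32).to(fp8)
    try:
        save({"probe": probe})  # returns bytes; nothing is written to disk
    except KeyError:
        return False  # dtype missing from safetensors' _SIZE table
    return True


if __name__ == "__main__":
    print("float8 safetensors support:", fp8_safetensors_supported())
```

Returning `None` rather than raising keeps the probe usable where torch is absent. Only `KeyError` is treated as "unsupported", matching the traceback above, so any other serialization error still surfaces.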