Handle bfloat16 weights in disk offload #460
Merged
This PR handles a bug reported in #454 concerning bfloat16 and disk offload.
NumPy does not support bfloat16, so when storing those tensors on disk we need to convert them to float32 first. float16 is not an option, since its narrower exponent range cannot represent every bfloat16 value. This is not very efficient (the offloaded weights take twice the space on disk), but it's the best we can offer until NumPy supports this dtype.
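To make the conversion concrete, here is a minimal sketch of the round trip, assuming PyTorch and NumPy are installed. The helper names `save_weight` and `load_weight` are hypothetical and are not the functions changed in this PR; the sketch only illustrates upcasting to float32 before writing and casting back after reading.

```python
import os
import tempfile

import numpy as np
import torch


def save_weight(weight: torch.Tensor, path: str) -> torch.dtype:
    """Hypothetical helper: write a tensor to `path` as a .npy file."""
    original_dtype = weight.dtype
    if original_dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype; float32 holds every bfloat16 value exactly.
        weight = weight.to(torch.float32)
    np.save(path, weight.cpu().numpy())
    # The caller must remember the original dtype to restore it on load.
    return original_dtype


def load_weight(path: str, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: read the .npy file and cast back to `dtype`."""
    return torch.from_numpy(np.load(path)).to(dtype)


original = torch.tensor([1.0, -3.140625, 1e30], dtype=torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weight.npy")
    dtype = save_weight(original, path)
    restored = load_weight(path, dtype)

# The float32 detour is lossless for bfloat16 values.
round_trip_ok = torch.equal(original, restored)
# By contrast, float16 cannot hold large bfloat16 values such as 1e30.
overflows_in_fp16 = torch.isinf(original.to(torch.float16)).any().item()
```

The last line shows why float16 is ruled out as the on-disk dtype: values beyond its range become infinite, whereas the float32 copy casts back to the identical bfloat16 tensor.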
A new test is added to check that disk offload works with bfloat16 weights. It doesn't need any specialized hardware, since no computation is performed on the bfloat16 tensor.