[Bugfix] Fix dummy weight for fp8 (vllm-project#4916)
Allow the dummy load format for FP8;
torch.Tensor.uniform_ doesn't support FP8 at the moment.

Co-authored-by: Mor Zusman <[email protected]>
2 people authored and robertgshaw2-redhat committed Jul 14, 2024
1 parent 33643a4 commit 247dd03
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -388,4 +388,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)
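For reference, a minimal standalone sketch of the same workaround (illustrative only, not part of the commit; assumes a PyTorch build with the torch.float8_e4m3fn dtype, i.e. 2.1 or newer):

import torch

# An FP8 parameter such as those created when loading a model with dummy weights.
param = torch.empty(4, 4, dtype=torch.float8_e4m3fn)
low, high = -1e-3, 1e-3

if torch.finfo(param.dtype).bits < 16:
    # uniform_ is not implemented for sub-16-bit dtypes, so fill a
    # float16 copy and cast the result back to the original dtype.
    tmp = param.to(torch.float16).uniform_(low, high)
    param.copy_(tmp.to(param.dtype))
else:
    param.uniform_(low, high)

The upcast-fill-downcast round trip costs a temporary float16 copy per parameter, which is acceptable here because dummy initialization only runs once at load time.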
