model = VideoChatGPTLlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    # torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float,
)
Can I use bfloat16 for training? I've found that with bfloat16 I can train on 24 GB GPUs, but I'm not sure how much it affects model performance. Can you give me some advice?
I see the code uses torch.float by default.
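For context on the precision question: bfloat16 keeps float32's full 8-bit exponent (so the same dynamic range, which is why it usually trains stably without loss scaling, unlike float16) but truncates the mantissa from 23 bits to 7. A minimal pure-Python sketch of that truncation, independent of this repo's code and requiring no torch, assuming only the standard `struct` module:

```python
import struct

def to_bfloat16(x: float) -> float:
    """Simulate bfloat16 by keeping only the top 16 bits of a float32.

    bfloat16 is literally the upper half of the float32 bit pattern:
    sign (1 bit) + exponent (8 bits) + mantissa (7 of 23 bits).
    It loses precision, not range.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

# Same dynamic range: a value near float32 max (~3.4e38) stays finite.
print(to_bfloat16(3.0e38))

# Reduced precision: only ~2-3 decimal digits of mantissa survive.
print(to_bfloat16(1.2345678))  # rounds down to 1.234375
```

So the trade-off is precision of individual weights/activations, not representable range; in practice many LLaMA-family training setups use bf16 for exactly the memory savings described above.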