Skip to content

Commit

Permalink
flash attn
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 30, 2024
1 parent d45b34f commit 500bd2d
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def run_train(config: TrainConfig):
logger.warning(f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}")
model_config.max_position_embeddings = config.generate.max_length

if config.model.use_flash_attn:
model_config.attention_type = "flash"

model = AutoModelForCausalLM.from_pretrained(
config.model.name_or_path, torch_dtype=torch.bfloat16,
config=model_config,
Expand Down

0 comments on commit 500bd2d

Please sign in to comment.