From f1c91f8ae36e5a6b7ac2f799310a037d19d6ed78 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Thu, 8 Aug 2024 18:45:04 +0200
Subject: [PATCH] Address Andrej's PR comments

---
 train_llama3.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/train_llama3.py b/train_llama3.py
index f20b2c343..054c3ae0b 100644
--- a/train_llama3.py
+++ b/train_llama3.py
@@ -13,11 +13,7 @@
 # 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py
 
 Example launches to only benchmark the speed of bfloat16 compiled GPU training:
-1 GPU:
-python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16
-you can also turn on flash-attention by appending --flash=1
-4 GPU:
-torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16
+TODO: add the actual commands
 """
 
 import os
@@ -134,6 +130,9 @@ def precompute_freqs_cis(
 
 # -----------------------------------------------------------------------------
 # LLaMA building blocks
+# LLaMA reference code explicitly implemented RMSNorm so we copy pasted it
+# (https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py)
+# we could also use nn.RMSNorm, it has slightly different numeric properties, but equivalent
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -161,14 +160,13 @@ def __init__(self, config):
 
         self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection
-        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
         # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed
         if self.use_kv:
             self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
             self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
 
-    def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
+    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         qkv = self.c_attn(x)
@@ -216,7 +214,6 @@ def __init__(self, config):
         self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
         self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
         self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
-        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
 
     def forward(self, x):
         # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2
@@ -236,7 +233,7 @@ def __init__(self, config):
         self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
         self.mlp = MLP(config)
 
-    def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
+    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
         x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
         x = x + self.mlp(self.ln_2(x))
         return x
@@ -542,9 +539,7 @@ def generate(
             next_token = next_token.reshape(-1)
             # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
             tokens[:, cur_pos] = next_token
             if logprobs:
                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(