
Commit

Address Andrej's PR comments
gordicaleksa committed Aug 8, 2024
1 parent 89addd3 commit f1c91f8
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions train_llama3.py
@@ -13,11 +13,7 @@
# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py
Example launches to only benchmark the speed of bfloat16 compiled GPU training:
1 GPU:
python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16
you can also turn on flash-attention by appending --flash=1
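for example, the same 1 GPU command with flash attention enabled (just the flag appended):
python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 --flash=1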
4 GPU:
torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16
TODO: add the actual commands
"""

import os
@@ -134,6 +130,9 @@ def precompute_freqs_cis(
# -----------------------------------------------------------------------------
# LLaMA building blocks

# LLaMA reference code explicitly implemented RMSNorm so we copy pasted it
# (https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py)
# we could also use nn.RMSNorm; it has slightly different numeric properties but is otherwise equivalent
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
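For reference (a reader's sketch, not part of this diff), the normalization this class implements boils down to scaling by the inverse root-mean-square over the last dimension; exact dtype casting in the reference may differ:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize by the root-mean-square over the feature dimension, then apply the learned scale
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight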
@@ -161,14 +160,13 @@ def __init__(self, config):

self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1

# static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed
if self.use_kv:
self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
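As a rough illustration of how such a static cache is typically consumed during generation (an assumed usage pattern, not code from this diff; k, v, B, T and start_pos mirror the names used in forward):

if self.use_kv and start_pos is not None:
    self.cache_k[:B, start_pos:start_pos + T] = k  # write this step's keys into the cache
    self.cache_v[:B, start_pos:start_pos + T] = v  # write this step's values
    k = self.cache_k[:B, :start_pos + T]           # attend over everything cached so far
    v = self.cache_v[:B, :start_pos + T]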

def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
qkv = self.c_attn(x)
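The fused projection is then split back into separate query, key and value tensors; a sketch of that split under grouped-query attention (assuming the head counts are stored on self, as in the reference implementations):

# q gets n_head heads, while k and v each get n_kv_head heads (GQA)
q, k, v = qkv.split([self.n_head * self.hd,
                     self.n_kv_head * self.hd,
                     self.n_kv_head * self.hd], dim=-1)
q = q.view(B, T, self.n_head, self.hd)
k = k.view(B, T, self.n_kv_head, self.hd)
v = v.view(B, T, self.n_kv_head, self.hd)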
@@ -216,7 +214,6 @@ def __init__(self, config):
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1

def forward(self, x):
# SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2
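Written out (a sketch restating the expression in the comment above; the file's actual body may use intermediate variables):

return self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x))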
@@ -236,7 +233,7 @@ def __init__(self, config):
self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
self.mlp = MLP(config)

def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
x = x + self.mlp(self.ln_2(x))
return x
@@ -542,9 +539,7 @@ def generate(

next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
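For context on the torch.where call in the hunk above (an illustrative note, not part of the diff): while cur_pos is still inside the user's prompt, the prompt token is kept and the freshly sampled token is discarded; only positions past the prompt take the sampled token. A minimal, self-contained sketch with hypothetical tensors (padding with 0 is just for the illustration):

import torch

tokens = torch.tensor([[5, 7, 9, 0, 0]])  # prompt occupies positions 0..2, the rest is padding
input_text_mask = tokens != 0             # True wherever a prompt token is present
cur_pos, sampled = 2, torch.tensor([42])  # pretend the model just sampled token 42
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], sampled)
print(next_token)                         # tensor([9]) -- position 2 is still prompt, so keep it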
