Add explicit external mask
gordicaleksa committed Aug 8, 2024
1 parent d4ef9c5 commit b25e325
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions train_gpt2.py
@@ -168,11 +168,7 @@ def __init__(self, config):
self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

- # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
- self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
- .view(1, 1, config.block_size, config.block_size))
-
- def forward(self, x, freqs_cis=None, start_pos=None):
+ def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
qkv = self.c_attn(x)
@@ -194,12 +190,13 @@ def forward(self, x, freqs_cis=None, start_pos=None):

if FLASH:
# flashattention
- y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+ y = F.scaled_dot_product_attention(q, k, v, mask)
else:
# manual implementation of attention
# this materializes the large (T,T) matrix for all the queries and keys
scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd))
- scores = scores.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(scores.dtype).min)
+ if mask is not None:
+ scores.masked_fill_(mask, torch.finfo(scores.dtype).min)
att = F.softmax(scores.float(), dim=-1).type_as(q)
y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD)
y = y.transpose(1, 2).contiguous().view(B, T, C)
@@ -239,8 +236,8 @@ def __init__(self, config):
self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
self.mlp = MLP(config)

- def forward(self, x, freqs_cis=None, start_pos=None):
- x = x + self.attn(self.ln_1(x), freqs_cis, start_pos)
+ def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
+ x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
x = x + self.mlp(self.ln_2(x))
return x

@@ -296,8 +293,10 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
freqs_cis = self.freqs_cis[start_pos:start_pos+t]

+ mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)
+
for i, block in enumerate(self.transformer.h):
- x = block(x, freqs_cis, start_pos)
+ x = block(x, freqs_cis, start_pos, mask)
x = self.transformer.ln_f(x)

if targets is not None:
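For reference, below is a minimal standalone sketch of how an explicit boolean causal mask like the one added in this commit drives both attention paths. The tensor shapes are made-up toy values, not the train_gpt2.py configuration. One detail worth noting: F.scaled_dot_product_attention interprets a boolean attn_mask as True = attend, which is the opposite of the masked_fill_ convention used in the manual path, so the mask is inverted before the fused call in this sketch.

import math
import torch
import torch.nn.functional as F

# Toy shapes for illustration only: batch, heads, sequence length, head dim.
B, NH, T, HD = 2, 4, 8, 16
q, k, v = (torch.randn(B, NH, T, HD) for _ in range(3))

# Boolean causal mask: True above the diagonal marks future positions to hide,
# matching the masked_fill_ convention of the manual attention path.
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Manual path: set masked scores to the most negative representable value.
scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(HD))
scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
att = F.softmax(scores.float(), dim=-1).type_as(q)
y_manual = att @ v

# Fused path: SDPA's boolean attn_mask means True = attend, so pass the inverse.
y_flash = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)

print(torch.allclose(y_manual, y_flash, atol=1e-5))  # the two paths should agree up to numerics

The same agreement check is a handy sanity test whenever a fused kernel's masking convention differs from the one used in a reference implementation.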
