forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
72dcfeb
commit d4ef9c5
Showing
7 changed files
with
295 additions
and
292 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +0,0 @@ | ||
# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py | ||
|
||
import math | ||
from typing import Tuple | ||
import torch | ||
|
||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): | ||
ndim = x.ndim | ||
assert 0 <= 1 < ndim | ||
assert freqs_cis.shape == (x.shape[1], x.shape[-1]) | ||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] | ||
return freqs_cis.view(*shape) | ||
|
||
def apply_scaling(freqs: torch.Tensor): | ||
# Values obtained from grid search | ||
scale_factor = 8 | ||
low_freq_factor = 1 | ||
high_freq_factor = 4 | ||
old_context_len = 8192 # original llama3 length | ||
|
||
low_freq_wavelen = old_context_len / low_freq_factor | ||
high_freq_wavelen = old_context_len / high_freq_factor | ||
new_freqs = [] | ||
for freq in freqs: | ||
wavelen = 2 * math.pi / freq | ||
if wavelen < high_freq_wavelen: | ||
new_freqs.append(freq) | ||
elif wavelen > low_freq_wavelen: | ||
new_freqs.append(freq / scale_factor) | ||
else: | ||
assert low_freq_wavelen != high_freq_wavelen | ||
smooth = (old_context_len / wavelen - low_freq_factor) / ( | ||
high_freq_factor - low_freq_factor | ||
) | ||
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) | ||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) | ||
|
||
def apply_rotary_emb( | ||
xq: torch.Tensor, | ||
xk: torch.Tensor, | ||
freqs_cis: torch.Tensor, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | ||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | ||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) | ||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | ||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | ||
return xq_out.type_as(xq), xk_out.type_as(xk) | ||
|
||
def precompute_freqs_cis( | ||
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False | ||
): | ||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) | ||
t = torch.arange(end, device=freqs.device, dtype=torch.float32) | ||
if use_scaled: | ||
freqs = apply_scaling(freqs) | ||
freqs = torch.outer(t, freqs) | ||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 | ||
return freqs_cis | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +0,0 @@ | ||
# Taken from: | ||
# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py | ||
# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py | ||
|
||
import torch | ||
from torch import nn | ||
|
||
# Special modules | ||
class RMSNorm(torch.nn.Module): | ||
def __init__(self, dim: int, eps: float = 1e-6): | ||
super().__init__() | ||
self.eps = eps | ||
self.weight = nn.Parameter(torch.ones(dim)) | ||
|
||
def _norm(self, x): | ||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | ||
|
||
def forward(self, x): | ||
output = self._norm(x.float()).type_as(x) | ||
return output * self.weight | ||
|
||
# Sampling | ||
def sample_top_p(probs, p): | ||
""" | ||
Perform top-p (nucleus) sampling on a probability distribution. | ||
Args: | ||
probs (torch.Tensor): Probability distribution tensor. | ||
p (float): Probability threshold for top-p sampling. | ||
Returns: | ||
torch.Tensor: Sampled token indices. | ||
Note: | ||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass | ||
exceeds the threshold p. The distribution is renormalized based on the selected tokens. | ||
""" | ||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) | ||
probs_sum = torch.cumsum(probs_sort, dim=-1) | ||
mask = probs_sum - probs_sort > p | ||
probs_sort[mask] = 0.0 | ||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) | ||
next_token = torch.multinomial(probs_sort, num_samples=1) | ||
next_token = torch.gather(probs_idx, -1, next_token) | ||
return next_token | ||
|
||
# GQA | ||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)""" | ||
bs, slen, n_kv_heads, head_dim = x.shape | ||
if n_rep == 1: | ||
return x | ||
return ( | ||
x[:, :, :, None, :] | ||
.expand(bs, slen, n_kv_heads, n_rep, head_dim) | ||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) | ||
) | ||
Oops, something went wrong.