From d4ef9c5afdd9c5b2ffcba9b03077535045ee4ad7 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Thu, 8 Aug 2024 17:09:41 +0200
Subject: [PATCH] Remove llmc_py, single file

---
 llmc_py/__pycache__/rope.cpython-310.pyc      | Bin 0 -> 2167 bytes
 llmc_py/__pycache__/tokenizer.cpython-310.pyc | Bin 0 -> 5372 bytes
 llmc_py/__pycache__/utils.cpython-310.pyc     | Bin 0 -> 2245 bytes
 llmc_py/rope.py                               |  59 ----
 llmc_py/tokenizer.py                          | 173 ----------
 llmc_py/utils.py                              |  57 ----
 train_gpt2.py                                 | 298 +++++++++++++++++-
 7 files changed, 295 insertions(+), 292 deletions(-)
 create mode 100644 llmc_py/__pycache__/rope.cpython-310.pyc
 create mode 100644 llmc_py/__pycache__/tokenizer.cpython-310.pyc
 create mode 100644 llmc_py/__pycache__/utils.cpython-310.pyc

diff --git a/llmc_py/__pycache__/rope.cpython-310.pyc b/llmc_py/__pycache__/rope.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8797638e54d73f9b72c780cf7c209e54840afb2c
GIT binary patch
literal 2167
[base85-encoded binary data omitted]
diff --git a/llmc_py/rope.py b/llmc_py/rope.py
index 3caf58073..e69de29bb 100644
--- a/llmc_py/rope.py
+++ b/llmc_py/rope.py
@@ -1,59 +0,0 @@
-# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
-
-import math
-from typing import Tuple
-import torch
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
-
-    low_freq_wavelen = old_context_len / low_freq_factor
-    high_freq_wavelen = old_context_len / high_freq_factor
-    new_freqs = []
-    for freq in freqs:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / scale_factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (
-                high_freq_factor - low_freq_factor
-            )
-            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
-    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
-def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
-):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
\ No newline at end of file
diff --git a/llmc_py/tokenizer.py b/llmc_py/tokenizer.py
index 528de113c..2d2a3bd58 100644
--- a/llmc_py/tokenizer.py
+++ b/llmc_py/tokenizer.py
@@ -1,7 +1,5 @@
 # From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py

-import os
-from pathlib import Path
 from typing import (
     AbstractSet,
     Callable,
@@ -16,174 +14,3 @@
     cast,
 )

-import tiktoken
-from tiktoken.load import load_tiktoken_bpe
-
-# The tiktoken tokenizer can handle <=400k chars without
-# pyo3_runtime.PanicException.
-TIKTOKEN_MAX_ENCODE_CHARS = 400_000
-
-# https://github.com/openai/tiktoken/issues/195
-# Here we iterate over subsequences and split if we exceed the limit
-# of max consecutive non-whitespace or whitespace characters.
-MAX_NO_WHITESPACES_CHARS = 25_000
-
-
-class Tokenizer:
-    """
-    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
-    """
-
-    special_tokens: Dict[str, int]
-
-    num_reserved_special_tokens = 256
-
-    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
-
-    def __init__(self, model_path: str):
-        """
-        Initializes the Tokenizer with a Tiktoken model.
-
-        Args:
-            model_path (str): The path to the Tiktoken model file.
- """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - if allowed_special is None: - allowed_special = set() - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
-        return self.model.decode(cast(List[int], t))
-
-    @staticmethod
-    def _split_whitespaces_or_nonwhitespaces(
-        s: str, max_consecutive_slice_len: int
-    ) -> Iterator[str]:
-        """
-        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
-        consecutive whitespaces or consecutive non-whitespaces.
-        """
-        current_slice_len = 0
-        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
-        slice_start = 0
-
-        for i in range(len(s)):
-            is_now_space = s[i].isspace()
-
-            if current_slice_is_space ^ is_now_space:
-                current_slice_len = 1
-                current_slice_is_space = is_now_space
-            else:
-                current_slice_len += 1
-                if current_slice_len > max_consecutive_slice_len:
-                    yield s[slice_start:i]
-                    slice_start = i
-                    current_slice_len = 1
-        yield s[slice_start:]
\ No newline at end of file
diff --git a/llmc_py/utils.py b/llmc_py/utils.py
index ed023c78a..e69de29bb 100644
--- a/llmc_py/utils.py
+++ b/llmc_py/utils.py
@@ -1,57 +0,0 @@
-# Taken from:
-# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
-# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py
-
-import torch
-from torch import nn
-
-# Special modules
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-# Sampling
-def sample_top_p(probs, p):
-    """
-    Perform top-p (nucleus) sampling on a probability distribution.
-
-    Args:
-        probs (torch.Tensor): Probability distribution tensor.
-        p (float): Probability threshold for top-p sampling.
-
-    Returns:
-        torch.Tensor: Sampled token indices.
-
-    Note:
-        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
-        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
- """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - -# GQA -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) \ No newline at end of file diff --git a/train_gpt2.py b/train_gpt2.py index 4eea1f46e..4a8cc6e0e 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -29,9 +29,18 @@ import json from pathlib import Path from typing import ( + AbstractSet, + Callable, + Collection, + Dict, + Iterator, List, + Literal, Optional, + Sequence, Tuple, + Union, + cast, ) import numpy as np @@ -44,9 +53,8 @@ from torch.distributed.optim import ZeroRedundancyOptimizer import torch.distributed as dist -from llmc_py.tokenizer import Tokenizer -from llmc_py.rope import precompute_freqs_cis, apply_rotary_emb -from llmc_py.utils import repeat_kv, sample_top_p, RMSNorm +import tiktoken +from tiktoken.load import load_tiktoken_bpe # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the LLaMA 3.x model @@ -54,6 +62,91 @@ # using a global to toggle flash-attention FLASH = 0 +# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + +# ----------------------------------------------------------------------------- +# RoPE related + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out 
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+    if use_scaled:
+        freqs = apply_scaling(freqs)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+# -----------------------------------------------------------------------------
+# LLaMA building blocks
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
 class CausalSelfAttention(nn.Module):

     def __init__(self, config):
@@ -482,6 +575,205 @@ def generate(
             out_logprobs.append(probs)
         return (out_tokens, out_logprobs if logprobs else None)

+# -----------------------------------------------------------------------------
+# sampling utils
+
+def sample_top_p(probs, p):
+    """
+    Perform top-p (nucleus) sampling on a probability distribution.
+
+    Args:
+        probs (torch.Tensor): Probability distribution tensor.
+        p (float): Probability threshold for top-p sampling.
+
+    Returns:
+        torch.Tensor: Sampled token indices.
+
+    Note:
+        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
+        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+# -----------------------------------------------------------------------------
+# Llama 3.1 Tokenizer
+
+# The tiktoken tokenizer can handle <=400k chars without
+# pyo3_runtime.PanicException.
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+# https://github.com/openai/tiktoken/issues/195
+# Here we iterate over subsequences and split if we exceed the limit
+# of max consecutive non-whitespace or whitespace characters.
+MAX_NO_WHITESPACES_CHARS = 25_000
+
+
+class Tokenizer:
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+    def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a Tiktoken model.
+
+        Args:
+            model_path (str): The path to the Tiktoken model file.
+ """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
+        return self.model.decode(cast(List[int], t))
+
+    @staticmethod
+    def _split_whitespaces_or_nonwhitespaces(
+        s: str, max_consecutive_slice_len: int
+    ) -> Iterator[str]:
+        """
+        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
+        consecutive whitespaces or consecutive non-whitespaces.
+        """
+        current_slice_len = 0
+        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+        slice_start = 0
+
+        for i in range(len(s)):
+            is_now_space = s[i].isspace()
+
+            if current_slice_is_space ^ is_now_space:
+                current_slice_len = 1
+                current_slice_is_space = is_now_space
+            else:
+                current_slice_len += 1
+                if current_slice_len > max_consecutive_slice_len:
+                    yield s[slice_start:i]
+                    slice_start = i
+                    current_slice_len = 1
+        yield s[slice_start:]
+
 # -----------------------------------------------------------------------------
 # Our own simple Distributed Data Loader
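For reference, a minimal usage sketch of the pieces this patch folds into train_gpt2.py (illustrative only, not part of the diff; it assumes a Llama 3 tiktoken model file is available at the hypothetical path "llama3.1_tokenizer.model"):

import torch

# Build the inlined Tokenizer from a local tiktoken model file (hypothetical path).
tok = Tokenizer("llama3.1_tokenizer.model")
ids = tok.encode("Hello, world!", bos=True, eos=False)
print(ids)
print(tok.decode(ids))

# Nucleus sampling over a dummy next-token distribution of size n_words.
probs = torch.softmax(torch.randn(1, tok.n_words), dim=-1)
next_id = sample_top_p(probs, p=0.9)
print(int(next_id))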