diff --git a/.github/workflows/build-wheels-cuda.yaml b/.github/workflows/build-wheels-cuda.yaml index 3d410148f..745b2e602 100644 --- a/.github/workflows/build-wheels-cuda.yaml +++ b/.github/workflows/build-wheels-cuda.yaml @@ -61,11 +61,9 @@ jobs: - name: Setup Mamba uses: conda-incubator/setup-miniconda@v3.1.0 with: - activate-environment: "build" + activate-environment: "llamacpp" python-version: ${{ matrix.pyver }} - miniforge-variant: Mambaforge miniforge-version: latest - use-mamba: true add-pip-as-python-dependency: true auto-activate-base: false diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 8fa2b447f..343581dce 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -55,7 +55,13 @@ def __init__( if model is None: raise ValueError(f"Failed to load model from file: {path_model}") + vocab = llama_cpp.llama_model_get_vocab(model) + + if vocab is None: + raise ValueError(f"Failed to get vocab from model: {path_model}") + self.model = model + self.vocab = vocab def free_model(): if self.model is None: @@ -75,7 +81,7 @@ def vocab_type(self) -> int: return llama_cpp.llama_vocab_type(self.model) def n_vocab(self) -> int: - return llama_cpp.llama_n_vocab(self.model) + return llama_cpp.llama_n_vocab(self.vocab) def n_ctx_train(self) -> int: return llama_cpp.llama_n_ctx_train(self.model) @@ -84,7 +90,7 @@ def n_embd(self) -> int: return llama_cpp.llama_n_embd(self.model) def rope_freq_scale_train(self) -> float: - return llama_cpp.llama_rope_freq_scale_train(self.model) + return llama_cpp.llama_model_rope_freq_scale_train(self.model) def desc(self) -> str: buf = ctypes.create_string_buffer(1024) @@ -98,53 +104,53 @@ def n_params(self) -> int: return llama_cpp.llama_model_n_params(self.model) def get_tensor(self, name: str) -> ctypes.c_void_p: - return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) + raise NotImplementedError("get_tensor is not implemented in llama.cpp") # Vocab def token_get_text(self, token: int) -> str: - return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") + return llama_cpp.llama_token_get_text(self.vocab, token).decode("utf-8") def token_get_score(self, token: int) -> float: - return llama_cpp.llama_token_get_score(self.model, token) + return llama_cpp.llama_token_get_score(self.vocab, token) def token_get_attr(self, token: int) -> int: - return llama_cpp.llama_token_get_attr(self.model, token) + return llama_cpp.llama_token_get_attr(self.vocab, token) # Special tokens def token_bos(self) -> int: - return llama_cpp.llama_token_bos(self.model) + return llama_cpp.llama_token_bos(self.vocab) def token_eos(self) -> int: - return llama_cpp.llama_token_eos(self.model) + return llama_cpp.llama_token_eos(self.vocab) def token_cls(self) -> int: - return llama_cpp.llama_token_cls(self.model) + return llama_cpp.llama_token_cls(self.vocab) def token_sep(self) -> int: - return llama_cpp.llama_token_sep(self.model) + return llama_cpp.llama_token_sep(self.vocab) def token_nl(self) -> int: - return llama_cpp.llama_token_nl(self.model) + return llama_cpp.llama_token_nl(self.vocab) def token_prefix(self) -> int: - return llama_cpp.llama_token_prefix(self.model) + raise NotImplementedError("token_prefix is not implemented in llama.cpp") def token_middle(self) -> int: - return llama_cpp.llama_token_middle(self.model) + raise NotImplementedError("token_middle is not implemented in llama.cpp") def token_suffix(self) -> int: - return llama_cpp.llama_token_suffix(self.model) + raise NotImplementedError("token_suffix is not implemented in llama.cpp") def token_eot(self) -> int: - return llama_cpp.llama_token_eot(self.model) + return llama_cpp.llama_token_eot(self.vocab) def add_bos_token(self) -> bool: - return llama_cpp.llama_add_bos_token(self.model) + return llama_cpp.llama_add_bos_token(self.vocab) def add_eos_token(self) -> bool: - return llama_cpp.llama_add_eos_token(self.model) + return llama_cpp.llama_add_eos_token(self.vocab) # Tokenization @@ -152,13 +158,13 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool): n_ctx = self.n_ctx_train() tokens = (llama_cpp.llama_token * n_ctx)() n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_ctx, add_bos, special + self.vocab, text, len(text), tokens, n_ctx, add_bos, special ) if n_tokens < 0: n_tokens = abs(n_tokens) tokens = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_tokens, add_bos, special + self.vocab, text, len(text), tokens, n_tokens, add_bos, special ) if n_tokens < 0: raise RuntimeError( @@ -168,7 +174,7 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool): def token_to_piece(self, token: int, special: bool = False) -> bytes: buf = ctypes.create_string_buffer(32) - llama_cpp.llama_token_to_piece(self.model, token, buf, 32, 0, special) + llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special) return bytes(buf) def detokenize(self, tokens: List[int], special: bool = False) -> bytes: @@ -177,7 +183,7 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes: buffer = (ctypes.c_char * size)() for token in tokens: n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size, 0, special + self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special ) assert n <= size output += bytes(buffer[:n]) @@ -320,7 +326,8 @@ def get_embeddings(self): def set_rng_seed(self, seed: int): # TODO: Fix - llama_cpp.llama_set_rng_seed(self.ctx, seed) + # llama_cpp.llama_set_rng_seed(self.ctx, seed) + raise NotImplementedError("set_rng_seed is not implemented in llama.cpp") def sample_repetition_penalties( self, @@ -331,55 +338,63 @@ def sample_repetition_penalties( penalty_freq: float, penalty_present: float, ): - llama_cpp.llama_sample_repetition_penalties( - self.ctx, - llama_cpp.byref(candidates.candidates), - last_tokens_data, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ) + # llama_cpp.llama_sample_repetition_penalties( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # last_tokens_data, + # penalty_last_n, + # penalty_repeat, + # penalty_freq, + # penalty_present, + # ) + raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp") def sample_softmax(self, candidates: "_LlamaTokenDataArray"): - llama_cpp.llama_sample_softmax( - self.ctx, - llama_cpp.byref(candidates.candidates), - ) + # llama_cpp.llama_sample_softmax( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # ) + raise NotImplementedError("sample_softmax is not implemented in llama.cpp") def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): - llama_cpp.llama_sample_top_k( - self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep - ) + # llama_cpp.llama_sample_top_k( + # self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep + # ) + raise NotImplementedError("sample_top_k is not implemented in llama.cpp") def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - llama_cpp.llama_sample_top_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep - ) + # llama_cpp.llama_sample_top_p( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_top_p is not implemented in llama.cpp") def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - llama_cpp.llama_sample_min_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep - ) + # llama_cpp.llama_sample_min_p( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_min_p is not implemented in llama.cpp") def sample_typical( self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int ): - llama_cpp.llama_sample_typical( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep - ) + # llama_cpp.llama_sample_typical( + # self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + # ) + raise NotImplementedError("sample_typical is not implemented in llama.cpp") def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): - llama_cpp.llama_sample_temp( - self.ctx, llama_cpp.byref(candidates.candidates), temp - ) + # llama_cpp.llama_sample_temp( + # self.ctx, llama_cpp.byref(candidates.candidates), temp + # ) + raise NotImplementedError("sample_temp is not implemented in llama.cpp") def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): - llama_cpp.llama_sample_grammar( - self.ctx, - llama_cpp.byref(candidates.candidates), - grammar.grammar, - ) + # llama_cpp.llama_sample_grammar( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # grammar.grammar, + # ) + raise NotImplementedError("sample_grammar is not implemented in llama.cpp") def sample_token_mirostat( self, @@ -389,14 +404,15 @@ def sample_token_mirostat( m: int, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], ) -> int: - return llama_cpp.llama_sample_token_mirostat( - self.ctx, - llama_cpp.byref(candidates.candidates), - tau, - eta, - m, - mu, - ) + raise NotImplementedError("sample_token_mirostat is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_mirostat( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # tau, + # eta, + # m, + # mu, + # ) def sample_token_mirostat_v2( self, @@ -405,29 +421,33 @@ def sample_token_mirostat_v2( eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], ) -> int: - return llama_cpp.llama_sample_token_mirostat_v2( - self.ctx, - llama_cpp.byref(candidates.candidates), - tau, - eta, - mu, - ) + raise NotImplementedError("sample_token_mirostat_v2 is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_mirostat_v2( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # tau, + # eta, + # mu, + # ) def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: - return llama_cpp.llama_sample_token_greedy( - self.ctx, - llama_cpp.byref(candidates.candidates), - ) + raise NotImplementedError("sample_token_greedy is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token_greedy( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # ) def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: - return llama_cpp.llama_sample_token( - self.ctx, - llama_cpp.byref(candidates.candidates), - ) + raise NotImplementedError("sample_token is not implemented in llama.cpp") + # return llama_cpp.llama_sample_token( + # self.ctx, + # llama_cpp.byref(candidates.candidates), + # ) # Grammar def grammar_accept_token(self, grammar: LlamaGrammar, token: int): - llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token) + raise NotImplementedError("grammar_accept_token is not implemented in llama.cpp") + # llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token) def reset_timings(self): llama_cpp.llama_perf_context_reset(self.ctx) @@ -788,7 +808,7 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float): def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar): sampler = llama_cpp.llama_sampler_init_grammar( - model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8") + model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8") ) self._add_sampler(sampler) @@ -842,6 +862,7 @@ def get_seed(self) -> int: def sample(self, ctx: LlamaContext, idx: int) -> int: assert self.sampler is not None + assert ctx.ctx is not None return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx) def close(self): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a71750671..7e9a6af23 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -409,10 +409,10 @@ def __init__( ) ) - self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None + self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None if self.lora_path: - self._lora_adapter = llama_cpp.llama_lora_adapter_init( + self._lora_adapter = llama_cpp.llama_adapter_lora_init( self._model.model, self.lora_path.encode("utf-8"), ) @@ -424,12 +424,12 @@ def __init__( def free_lora_adapter(): if self._lora_adapter is None: return - llama_cpp.llama_lora_adapter_free(self._lora_adapter) + llama_cpp.llama_adapter_lora_free(self._lora_adapter) self._lora_adapter = None self._stack.callback(free_lora_adapter) - if llama_cpp.llama_lora_adapter_set( + if llama_cpp.llama_set_adapter_lora( self._ctx.ctx, self._lora_adapter, self.lora_scale ): raise RuntimeError( @@ -1155,9 +1155,9 @@ def _create_completion( bos_token_id: int = self.token_bos() cls_token_id: int = self._model.token_cls() sep_token_id: int = self._model.token_sep() - prefix_token_id: int = self._model.token_prefix() - middle_token_id: int = self._model.token_middle() - suffix_token_id: int = self._model.token_suffix() + prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix + middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix + suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix add_space_prefix: bool = ( self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true" ) @@ -1335,7 +1335,7 @@ def logit_bias_processor( logits_processor=logits_processor, grammar=grammar, ): - if llama_cpp.llama_token_is_eog(self._model.model, token): + if llama_cpp.llama_token_is_eog(self._model.vocab, token): text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) finish_reason = "stop" break diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 14f8b0b2c..205d89a0b 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -149,6 +149,10 @@ # define LLAMA_STATE_SEQ_VERSION 2 LLAMA_STATE_SEQ_VERSION = 2 +# struct llama_vocab; +llama_vocab_p = NewType("llama_vocab_p", int) +llama_vocab_p_ctypes = ctypes.c_void_p + # struct llama_model; llama_model_p = NewType("llama_model_p", int) llama_model_p_ctypes = ctypes.c_void_p @@ -640,9 +644,6 @@ class llama_model_kv_override(ctypes.Structure): # // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() # const float * tensor_split; -# // comma separated list of RPC servers to use for offloading -# const char * rpc_servers; - # // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. # // If the provided progress_callback returns true, model loading continues. # // If it returns false, model loading is immediately aborted. @@ -669,7 +670,6 @@ class llama_model_params(ctypes.Structure): split_mode (int): how to split the model across multiple GPUs main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored tensor_split (ctypes.Array[ctypes.ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() - rpc_servers (ctypes.c_char_p): comma separated list of RPC servers to use for offloading progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted. progress_callback_user_data (ctypes.ctypes.c_void_p): context pointer passed to the progress callback kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data @@ -683,7 +683,6 @@ class llama_model_params(ctypes.Structure): split_mode: int main_gpu: int tensor_split: CtypesArray[ctypes.c_float] - rpc_servers: ctypes.c_char_p progress_callback: Callable[[float, ctypes.c_void_p], bool] progress_callback_user_data: ctypes.c_void_p kv_overrides: CtypesArray[llama_model_kv_override] @@ -698,7 +697,6 @@ class llama_model_params(ctypes.Structure): ("split_mode", ctypes.c_int), ("main_gpu", ctypes.c_int32), ("tensor_split", ctypes.POINTER(ctypes.c_float)), - ("rpc_servers", ctypes.c_char_p), ("progress_callback", llama_progress_callback), ("progress_callback_user_data", ctypes.c_void_p), ("kv_overrides", ctypes.POINTER(llama_model_kv_override)), @@ -978,9 +976,9 @@ class llama_chat_message(ctypes.Structure): # // lora adapter -# struct llama_lora_adapter; -llama_lora_adapter_p = ctypes.c_void_p -llama_lora_adapter_p_ctypes = ctypes.POINTER(ctypes.c_void_p) +# struct llama_adapter_lora; +llama_adapter_lora_p = ctypes.c_void_p +llama_adapter_lora_p_ctypes = ctypes.POINTER(ctypes.c_void_p) # // Helpers for getting default parameters @@ -1062,6 +1060,18 @@ def llama_backend_init(): GGML_NUMA_STRATEGY_COUNT = 5 +# // Call once at the end of the program - currently only used for MPI +# LLAMA_API void llama_backend_free(void); +@ctypes_function( + "llama_backend_free", + [], + None, +) +def llama_backend_free(): + """Call once at the end of the program - currently only used for MPI""" + ... + + # //optional: # LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); @ctypes_function( @@ -1075,24 +1085,14 @@ def llama_numa_init(numa: int, /): # // Optional: an auto threadpool gets created in ggml if not passed explicitly # LLAMA_API void llama_attach_threadpool( -# struct llama_context * ctx, -# ggml_threadpool_t threadpool, -# ggml_threadpool_t threadpool_batch); +# struct llama_context * ctx, +# ggml_threadpool_t threadpool, +# ggml_threadpool_t threadpool_batch); +# TODO: Add llama_attach_threadpool # LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); - - -# // Call once at the end of the program - currently only used for MPI -# LLAMA_API void llama_backend_free(void); -@ctypes_function( - "llama_backend_free", - [], - None, -) -def llama_backend_free(): - """Call once at the end of the program - currently only used for MPI""" - ... +# TODO: Add llama_detach_threadpool # DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( @@ -1110,6 +1110,9 @@ def llama_load_model_from_file( ... +# // Load the model from a file +# // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf +# // If the split file name does not follow this pattern, use llama_model_load_from_splits # LLAMA_API struct llama_model * llama_model_load_from_file( # const char * path_model, # struct llama_model_params params); @@ -1121,6 +1124,31 @@ def llama_load_model_from_file( def llama_model_load_from_file( path_model: bytes, params: llama_model_params, / ) -> Optional[llama_model_p]: + """Load the model from a file + + If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + + If the split file name does not follow this pattern, use llama_model_load_from_splits""" + ... + + +# // Load the model from multiple splits (support custom naming scheme) +# // The paths must be in the correct order +# LLAMA_API struct llama_model * llama_model_load_from_splits( +# const char ** paths, +# size_t n_paths, +# struct llama_model_params params); +@ctypes_function( + "llama_model_load_from_splits", + [ctypes.POINTER(ctypes.c_char_p), ctypes.c_size_t, llama_model_params], + llama_model_p_ctypes, +) +def llama_model_load_from_splits( + paths: List[bytes], n_paths: int, params: llama_model_params, / +) -> Optional[llama_model_p]: + """Load the model from multiple splits (support custom naming scheme) + + The paths must be in the correct order""" ... @@ -1144,9 +1172,24 @@ def llama_model_free(model: llama_model_p, /): ... -# LLAMA_API struct llama_context * llama_new_context_with_model( +# LLAMA_API struct llama_context * llama_init_from_model( # struct llama_model * model, # struct llama_context_params params); +@ctypes_function( + "llama_init_from_model", + [llama_model_p_ctypes, llama_context_params], + llama_context_p_ctypes, +) +def llama_init_from_model( + model: llama_model_p, params: llama_context_params, / +) -> Optional[llama_context_p]: + ... + + +# DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( +# struct llama_model * model, +# struct llama_context_params params), +# "use llama_init_from_model instead"); @ctypes_function( "llama_new_context_with_model", [llama_model_p_ctypes, llama_context_params], @@ -1234,65 +1277,102 @@ def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... -# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); -@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_vocab(model: llama_model_p, /) -> int: - ... -# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); +# DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_ctx_train(model: llama_model_p, /) -> int: ... -# LLAMA_API int32_t llama_n_embd (const struct llama_model * model); +# DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_embd(model: llama_model_p, /) -> int: ... -# LLAMA_API int32_t llama_n_layer (const struct llama_model * model); +# DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_layer(model: llama_model_p, /) -> int: ... -# LLAMA_API int32_t llama_n_head (const struct llama_model * model); +# DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); @ctypes_function("llama_n_head", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_head(model: llama_model_p, /) -> int: ... -# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); +# DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); +@ctypes_function("llama_n_vocab", [llama_vocab_p_ctypes], ctypes.c_int32) +def llama_n_vocab(model: llama_vocab_p, /) -> int: + ... + + +# LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... -# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); +# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); @ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) def llama_pooling_type(ctx: llama_context_p, /) -> int: ... -# LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); -@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_vocab_type(model: llama_model_p, /) -> int: +# LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); +@ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p_ctypes) +def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]: + ... + + +# LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); +@ctypes_function("llama_model_rope_type", [llama_model_p_ctypes], ctypes.c_int) +def llama_model_rope_type(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); +@ctypes_function("llama_model_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_ctx_train(model: llama_model_p, /) -> int: ... -# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); -@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_rope_type(model: llama_model_p, /) -> int: +# LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); +@ctypes_function("llama_model_n_embd", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_embd(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); +@ctypes_function("llama_model_n_layer", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_layer(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); +@ctypes_function("llama_model_n_head", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_head(model: llama_model_p, /) -> int: ... # // Get the model's RoPE frequency scaling factor -# LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); -@ctypes_function("llama_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float) -def llama_rope_freq_scale_train(model: llama_model_p, /) -> float: - """Get the model's RoPE frequency scaling factor""" +# LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); +@ctypes_function("llama_model_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float) +def llama_model_rope_freq_scale_train(model: llama_model_p, /) -> float: + ... + + +# LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); +@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) +def llama_vocab_type(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); +@ctypes_function("llama_vocab_n_tokens", [llama_vocab_p_ctypes], ctypes.c_int32) +def llama_vocab_n_tokens(vocab: llama_vocab_p, /) -> int: ... @@ -1405,6 +1485,16 @@ def llama_model_size(model: llama_model_p, /) -> int: ... +# // Get the default chat template. Returns nullptr if not available +# // If name is NULL, returns the default chat template +# LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); +@ctypes_function("llama_model_chat_template", [llama_model_p_ctypes, ctypes.c_char_p], ctypes.c_char_p) +def llama_model_chat_template(model: llama_model_p, name: Optional[bytes], /) -> Optional[bytes]: + """Get the default chat template. Returns None if not available + If name is None, returns the default chat template""" + ... + + # // Returns the total number of parameters in the model # LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @ctypes_function("llama_model_n_params", [llama_model_p_ctypes], ctypes.c_uint64) @@ -1475,37 +1565,48 @@ def llama_model_quantize( # // Load a LoRA adapter from file -# // The loaded adapter will be associated to the given model, and will be free when the model is deleted -# LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( +# LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( # struct llama_model * model, # const char * path_lora); @ctypes_function( - "llama_lora_adapter_init", + "llama_adapter_lora_init", [llama_model_p_ctypes, ctypes.c_char_p], - llama_lora_adapter_p_ctypes, + llama_adapter_lora_p_ctypes, ) -def llama_lora_adapter_init( +def llama_adapter_lora_init( model: llama_model_p, path_lora: bytes, / -) -> Optional[llama_lora_adapter_p]: - """Load a LoRA adapter from file - The loaded adapter will be associated to the given model, and will be free when the model is deleted - """ +) -> Optional[llama_adapter_lora_p]: + ... + + +# // Manually free a LoRA adapter +# // Note: loaded adapters will be free when the associated model is deleted +# LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); +@ctypes_function( + "llama_adapter_lora_free", + [llama_adapter_lora_p_ctypes], + None, +) +def llama_adapter_lora_free(adapter: llama_adapter_lora_p, /): ... +# // The following functions operate on a llama_context, hence the naming: llama_verb_... + + # // Add a loaded LoRA adapter to given context # // This will not modify model's weight -# LLAMA_API int32_t llama_lora_adapter_set( +# LLAMA_API int32_t llama_set_adapter_lora( # struct llama_context * ctx, -# struct llama_lora_adapter * adapter, +# struct llama_adapter_lora * adapter, # float scale); @ctypes_function( - "llama_lora_adapter_set", - [llama_context_p_ctypes, llama_lora_adapter_p_ctypes, ctypes.c_float], + "llama_set_adapter_lora", + [llama_context_p_ctypes, llama_adapter_lora_p_ctypes, ctypes.c_float], ctypes.c_int32, ) -def llama_lora_adapter_set( - ctx: llama_context_p, adapter: llama_lora_adapter_p, scale: float, / +def llama_set_adapter_lora( + ctx: llama_context_p, adapter: llama_adapter_lora_p, scale: float, / ) -> int: """Add a loaded LoRA adapter to given context This will not modify model's weight""" @@ -1514,64 +1615,49 @@ def llama_lora_adapter_set( # // Remove a specific LoRA adapter from given context # // Return -1 if the adapter is not present in the context -# LLAMA_API int32_t llama_lora_adapter_remove( +# LLAMA_API int32_t llama_rm_adapter_lora( # struct llama_context * ctx, -# struct llama_lora_adapter * adapter); +# struct llama_adapter_lora * adapter); @ctypes_function( - "llama_lora_adapter_remove", - [llama_context_p_ctypes, llama_lora_adapter_p_ctypes], + "llama_rm_adapter_lora", + [llama_context_p_ctypes, llama_adapter_lora_p_ctypes], ctypes.c_int32, ) -def llama_lora_adapter_remove( - ctx: llama_context_p, adapter: llama_lora_adapter_p, / +def llama_rm_adapter_lora( + ctx: llama_context_p, adapter: llama_adapter_lora_p, / ) -> int: - """Remove a LoRA adapter from given context + """Remove a specific LoRA adapter from given context Return -1 if the adapter is not present in the context""" ... # // Remove all LoRA adapters from given context -# LLAMA_API void llama_lora_adapter_clear( -# struct llama_context * ctx); +# LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); @ctypes_function( - "llama_lora_adapter_clear", + "llama_clear_adapter_lora", [llama_context_p_ctypes], None, ) -def llama_lora_adapter_clear(ctx: llama_context_p, /): +def llama_clear_adapter_lora(ctx: llama_context_p, /): """Remove all LoRA adapters from given context""" ... -# // Manually free a LoRA adapter -# // Note: loaded adapters will be free when the associated model is deleted -# LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); -@ctypes_function( - "llama_lora_adapter_free", - [llama_lora_adapter_p_ctypes], - None, -) -def llama_lora_adapter_free(adapter: llama_lora_adapter_p, /): - """Manually free a LoRA adapter - Note: loaded adapters will be free when the associated model is deleted""" - ... - - # // Apply a loaded control vector to a llama_context, or if data is NULL, clear # // the currently loaded vector. # // n_embd should be the size of a single layer's control, and data should point # // to an n_embd x n_layers buffer starting from layer 1. # // il_start and il_end are the layer range the vector should apply to (both inclusive) # // See llama_control_vector_load in common to load a control vector. -# LLAMA_API int32_t llama_control_vector_apply( -# struct llama_context * lctx, +# LLAMA_API int32_t llama_apply_adapter_cvec( +# struct llama_context * ctx, # const float * data, # size_t len, # int32_t n_embd, # int32_t il_start, # int32_t il_end); @ctypes_function( - "llama_control_vector_apply", + "llama_apply_adapter_cvec", [ llama_context_p_ctypes, ctypes.POINTER(ctypes.c_float), @@ -1582,8 +1668,8 @@ def llama_lora_adapter_free(adapter: llama_lora_adapter_p, /): ], ctypes.c_int32, ) -def llama_control_vector_apply( - lctx: llama_context_p, +def llama_apply_adapter_cvec( + ctx: llama_context_p, data: CtypesPointerOrRef[ctypes.c_float], len: int, n_embd: int, @@ -2580,53 +2666,53 @@ def llama_get_embeddings_seq( # // -# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); +# LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); @ctypes_function( - "llama_token_get_text", [llama_model_p_ctypes, llama_token], ctypes.c_char_p + "llama_vocab_get_text", [llama_vocab_p_ctypes, llama_token], ctypes.c_char_p ) -def llama_token_get_text( - model: llama_model_p, token: Union[llama_token, int], / +def llama_vocab_get_text( + vocab: llama_vocab_p, token: Union[llama_token, int], / ) -> bytes: ... -# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); +# LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); @ctypes_function( - "llama_token_get_score", [llama_model_p_ctypes, llama_token], ctypes.c_float + "llama_vocab_get_score", [llama_vocab_p_ctypes, llama_token], ctypes.c_float ) -def llama_token_get_score( - model: llama_model_p, token: Union[llama_token, int], / +def llama_vocab_get_score( + vocab: llama_vocab_p, token: Union[llama_token, int], / ) -> float: ... -# LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); +# LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); @ctypes_function( - "llama_token_get_attr", [llama_model_p_ctypes, llama_token], ctypes.c_int + "llama_vocab_get_attr", [llama_vocab_p_ctypes, llama_token], ctypes.c_int ) -def llama_token_get_attr( - model: llama_model_p, token: Union[llama_token, int], / +def llama_vocab_get_attr( + vocab: llama_vocab_p, token: Union[llama_token, int], / ) -> int: ... # // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) -# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); +# LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); @ctypes_function( - "llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool + "llama_vocab_is_eog", [llama_vocab_p_ctypes, llama_token], ctypes.c_bool ) -def llama_token_is_eog(model: llama_model_p, token: Union[llama_token, int], /) -> bool: +def llama_vocab_is_eog(vocab: llama_vocab_p, token: Union[llama_token, int], /) -> bool: """Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)""" ... # // Identify if Token Id is a control token or a render-able token -# LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); +# LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); @ctypes_function( - "llama_token_is_control", [llama_model_p_ctypes, llama_token], ctypes.c_bool + "llama_vocab_is_control", [llama_vocab_p_ctypes, llama_token], ctypes.c_bool ) -def llama_token_is_control( - model: llama_model_p, token: Union[llama_token, int], / +def llama_vocab_is_control( + vocab: llama_vocab_p, token: Union[llama_token, int], / ) -> bool: """Identify if Token Id is a control token or a render-able token""" ... @@ -2635,110 +2721,335 @@ def llama_token_is_control( # // Special tokens -# LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence -@ctypes_function("llama_token_bos", [llama_model_p_ctypes], llama_token) -def llama_token_bos(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence +@ctypes_function("llama_vocab_bos", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_bos(vocab: llama_vocab_p, /) -> llama_token: """beginning-of-sentence""" ... -# LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence -@ctypes_function("llama_token_eos", [llama_model_p_ctypes], llama_token) -def llama_token_eos(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence +@ctypes_function("llama_vocab_eos", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_eos(vocab: llama_vocab_p, /) -> llama_token: """end-of-sentence""" ... -# LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn -@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) -def llama_token_eot(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn +@ctypes_function("llama_vocab_eot", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_eot(vocab: llama_vocab_p, /) -> llama_token: """end-of-turn""" ... -# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification -@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token) -def llama_token_cls(model: llama_model_p, /) -> int: - """classification""" +# LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator +@ctypes_function("llama_vocab_sep", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_sep(vocab: llama_vocab_p, /) -> llama_token: + """sentence separator""" ... -# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator -@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token) -def llama_token_sep(model: llama_model_p, /) -> int: - """sentence separator""" +# LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line +@ctypes_function("llama_vocab_nl", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_nl(vocab: llama_vocab_p, /) -> llama_token: + """next-line""" ... -# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line -@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token) -def llama_token_nl(model: llama_model_p, /) -> int: - """next-line""" +# LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding +@ctypes_function("llama_vocab_pad", [llama_vocab_p_ctypes], llama_token) +def llama_vocab_pad(vocab: llama_vocab_p, /) -> llama_token: + """padding""" + ... + +# LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_get_add_bos", + [llama_vocab_p_ctypes], + ctypes.c_bool, +) +def llama_vocab_get_add_bos(vocab: llama_vocab_p, /) -> bool: ... -# LLAMA_API bool llama_add_bos_token(const struct llama_model * model); -@ctypes_function("llama_add_bos_token", [llama_model_p_ctypes], ctypes.c_bool) -def llama_add_bos_token(model: llama_model_p, /) -> bool: +# LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_get_add_eos", + [llama_vocab_p_ctypes], + ctypes.c_bool, +) +def llama_vocab_get_add_eos(vocab: llama_vocab_p, /) -> bool: ... -# LLAMA_API bool llama_add_eos_token(const struct llama_model * model); -@ctypes_function("llama_add_eos_token", [llama_model_p_ctypes], ctypes.c_bool) -def llama_add_eos_token(model: llama_model_p, /) -> bool: +# LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_pre", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_pre(vocab: llama_vocab_p, /) -> llama_token: ... -# // Codellama infill tokens -# DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); -@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token) -def llama_token_prefix(model: llama_model_p) -> int: - """codellama infill tokens""" +# LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_suf", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_suf(vocab: llama_vocab_p, /) -> llama_token: ... -# DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); -@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) -def llama_token_middle(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_mid", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_mid(vocab: llama_vocab_p, /) -> llama_token: ... -# DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); -@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) -def llama_token_suffix(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_pad", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_pad(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); -@ctypes_function("llama_token_fim_pre", [llama_model_p_ctypes], llama_token) -def llama_token_fim_pre(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_rep", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_rep(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); -@ctypes_function("llama_token_fim_suf", [llama_model_p_ctypes], llama_token) -def llama_token_fim_suf(model: llama_model_p, /) -> int: + +# LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); +@ctypes_function( + "llama_vocab_fim_sep", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_fim_sep(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); -@ctypes_function("llama_token_fim_mid", [llama_model_p_ctypes], llama_token) -def llama_token_fim_mid(model: llama_model_p, /) -> int: + + +# DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); +@ctypes_function( + "llama_token_get_text", + [llama_vocab_p_ctypes, llama_token], + ctypes.c_char_p, +) +def llama_token_get_text( + vocab: llama_vocab_p, token: Union[llama_token, int], / +) -> bytes: + ... + + +# DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); +@ctypes_function( + "llama_token_get_score", + [llama_vocab_p_ctypes, llama_token], + ctypes.c_float, +) +def llama_token_get_score( + vocab: llama_vocab_p, token: Union[llama_token, int], / +) -> float: + ... + +# DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); +@ctypes_function( + "llama_token_get_attr", + [llama_vocab_p_ctypes, llama_token], + ctypes.c_int, +) +def llama_token_get_attr( + vocab: llama_vocab_p, token: Union[llama_token, int], / +) -> int: + ... + +# DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); +@ctypes_function( + "llama_token_is_eog", + [llama_vocab_p_ctypes, llama_token], + ctypes.c_bool, +) +def llama_token_is_eog( + vocab: llama_vocab_p, token: Union[llama_token, int], / +) -> bool: + ... + +# DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); +@ctypes_function( + "llama_token_is_control", + [llama_vocab_p_ctypes, llama_token], + ctypes.c_bool, +) +def llama_token_is_control( + vocab: llama_vocab_p, token: Union[llama_token, int], / +) -> bool: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); +@ctypes_function( + "llama_token_bos", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_bos(vocab: llama_vocab_p, /) -> int: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); +@ctypes_function( + "llama_token_eos", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_eos(vocab: llama_vocab_p, /) -> int: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); +@ctypes_function( + "llama_token_eot", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_eot(vocab: llama_vocab_p, /) -> int: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); +@ctypes_function( + "llama_token_cls", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_cls(vocab: llama_vocab_p, /) -> int: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); +@ctypes_function( + "llama_token_sep", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_sep(vocab: llama_vocab_p, /) -> int: + ... + + +# DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); +@ctypes_function( + "llama_token_nl", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_nl(vocab: llama_vocab_p, /) -> int: + ... + + +# DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); +@ctypes_function( + "llama_token_pad", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_pad(vocab: llama_vocab_p, /) -> int: + ... + + +# DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); +@ctypes_function( + "llama_add_bos_token", + [llama_vocab_p_ctypes], + ctypes.c_bool, +) +def llama_add_bos_token(vocab: llama_vocab_p, /) -> bool: + ... + +# DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); +@ctypes_function( + "llama_add_eos_token", + [llama_vocab_p_ctypes], + ctypes.c_bool, +) +def llama_add_eos_token(vocab: llama_vocab_p, /) -> bool: + ... + + +# DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); +@ctypes_function( + "llama_token_fim_pre", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_pre(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); -@ctypes_function("llama_token_fim_pad", [llama_model_p_ctypes], llama_token) -def llama_token_fim_pad(model: llama_model_p, /) -> int: +# DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); +@ctypes_function( + "llama_token_fim_suf", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_suf(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); -@ctypes_function("llama_token_fim_rep", [llama_model_p_ctypes], llama_token) -def llama_token_fim_rep(model: llama_model_p, /) -> int: +# DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); +@ctypes_function( + "llama_token_fim_mid", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_mid(vocab: llama_vocab_p, /) -> llama_token: ... -# LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); -@ctypes_function("llama_token_fim_sep", [llama_model_p_ctypes], llama_token) -def llama_token_fim_sep(model: llama_model_p, /) -> int: +# DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); +@ctypes_function( + "llama_token_fim_pad", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_pad(vocab: llama_vocab_p, /) -> llama_token: ... +# DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); +@ctypes_function( + "llama_token_fim_rep", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_rep(vocab: llama_vocab_p, /) -> llama_token: + ... + +# DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); +@ctypes_function( + "llama_token_fim_sep", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_token_fim_sep(vocab: llama_vocab_p, /) -> llama_token: + ... + +# // CLS is equivalent to BOS +# DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification +# "use llama_vocab_bos instead"); +@ctypes_function( + "llama_vocab_cls", + [llama_vocab_p_ctypes], + llama_token, +) +def llama_vocab_cls(vocab: llama_vocab_p, /) -> llama_token: + ... + + # // # // Tokenization # // @@ -2754,7 +3065,7 @@ def llama_token_fim_sep(model: llama_model_p, /) -> int: # /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated # /// as plaintext. Does not insert a leading space. # LLAMA_API int32_t llama_tokenize( -# const struct llama_model * model, +# const struct llama_vocab * vocab, # const char * text, # int32_t text_len, # llama_token * tokens, @@ -2764,7 +3075,7 @@ def llama_token_fim_sep(model: llama_model_p, /) -> int: @ctypes_function( "llama_tokenize", [ - llama_model_p_ctypes, + llama_vocab_p_ctypes, ctypes.c_char_p, ctypes.c_int32, llama_token_p, @@ -2775,7 +3086,7 @@ def llama_token_fim_sep(model: llama_model_p, /) -> int: ctypes.c_int32, ) def llama_tokenize( - model: llama_model_p, + vocab: llama_vocab_p, text: bytes, text_len: Union[ctypes.c_int, int], tokens: CtypesArray[llama_token], @@ -2787,7 +3098,7 @@ def llama_tokenize( """Convert the provided text into tokens. Args: - model: The model to use for tokenization. + vocab: The vocabulary to use for tokenization. text: The text to tokenize. text_len: The length of the text. tokens: The tokens pointer must be large enough to hold the resulting tokens. @@ -2808,7 +3119,7 @@ def llama_tokenize( # // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') # // @param special If true, special tokens are rendered in the output. # LLAMA_API int32_t llama_token_to_piece( -# const struct llama_model * model, +# const struct llama_vocab * vocab, # llama_token token, # char * buf, # int32_t length, @@ -2817,7 +3128,7 @@ def llama_tokenize( @ctypes_function( "llama_token_to_piece", [ - llama_model_p_ctypes, + llama_vocab_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32, @@ -2827,7 +3138,7 @@ def llama_tokenize( ctypes.c_int32, ) def llama_token_to_piece( - model: llama_model_p, + vocab: llama_vocab_p, token: Union[llama_token, int], buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]], length: Union[ctypes.c_int, int], @@ -2841,7 +3152,7 @@ def llama_token_to_piece( User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. Args: - model: The model to use for tokenization. + vocab: The vocabulary to use for tokenization. token: The token to convert. buf: The buffer to write the token to. length: The length of the buffer. @@ -2933,7 +3244,6 @@ def llama_detokenize( # /// @param length The size of the allocated buffer # /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. # LLAMA_API int32_t llama_chat_apply_template( -# const struct llama_model * model, # const char * tmpl, # const struct llama_chat_message * chat, # size_t n_msg, @@ -2943,20 +3253,37 @@ def llama_detokenize( @ctypes_function( "llama_chat_apply_template", [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.POINTER(llama_chat_message), - ctypes.c_size_t, + ctypes.c_char_p, # tmpl + ctypes.POINTER(llama_chat_message), # chat + ctypes.c_size_t, # n_msg + ctypes.c_bool, # add_ass (added) + ctypes.c_char_p, # buf + ctypes.c_int32, # length ], ctypes.c_int32, ) def llama_chat_apply_template( - model: llama_model_p, tmpl: bytes, chat: CtypesArray[llama_chat_message], n_msg: int, + add_ass: bool, # Added parameter + buf: bytes, + length: int, /, ) -> int: + """Apply chat template. + + Args: + tmpl: Template to use. If None, uses model's default + chat: Array of chat messages + n_msg: Number of messages + add_ass: Whether to end prompt with assistant token + buf: Output buffer + length: Buffer length + + Returns: + Number of bytes written, or needed if buffer too small + """ ... @@ -3025,7 +3352,6 @@ def llama_chat_builtin_templates( # // llama_sampler_free(smpl); # // # // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). -# // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab # // # typedef void * llama_sampler_context_t; @@ -3345,16 +3671,16 @@ def llama_sampler_init_mirostat_v2( # LLAMA_API struct llama_sampler * llama_sampler_init_grammar( -# const struct llama_model * model, +# const struct llama_vocab * vocab, # const char * grammar_str, # const char * grammar_root); @ctypes_function( "llama_sampler_init_grammar", - [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p], + [llama_vocab_p_ctypes, ctypes.c_char_p, ctypes.c_char_p], llama_sampler_p_ctypes, ) def llama_sampler_init_grammar( - model: llama_model_p, grammar_str: bytes, grammar_root: bytes, / + vocab: llama_vocab_p, grammar_str: bytes, grammar_root: bytes, / ) -> llama_sampler_p: ... @@ -3382,7 +3708,8 @@ def llama_sampler_init_penalties( # /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 # LLAMA_API struct llama_sampler * llama_sampler_init_dry( -# const struct llama_model * model, +# const struct llama_vocab * vocab, +# int32_t n_ctx_train, # float dry_multiplier, # float dry_base, # int32_t dry_allowed_length, @@ -3392,7 +3719,8 @@ def llama_sampler_init_penalties( @ctypes_function( "llama_sampler_init_dry", [ - llama_model_p_ctypes, + llama_vocab_p_ctypes, + ctypes.c_int32, ctypes.c_float, ctypes.c_float, ctypes.c_int32, @@ -3403,7 +3731,8 @@ def llama_sampler_init_penalties( llama_sampler_p_ctypes, ) def llama_sampler_init_dry( - model: llama_model_p, + vocab: llama_vocab_p, + n_ctx_train: int, dry_multiplier: float, dry_base: float, dry_allowed_length: int, @@ -3451,15 +3780,13 @@ def llama_sampler_init_logit_bias( # // 3. discard non-EOG tokens with low prob # // 4. if no tokens are left -> pick EOT # // -# LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); +# LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); @ctypes_function( "llama_sampler_init_infill", - [llama_model_p_ctypes], + [llama_vocab_p_ctypes], llama_sampler_p_ctypes, ) -def llama_sampler_init_infill(model: llama_model_p, /) -> llama_sampler_p: - """This sampler is meant to be used for fill-in-the-middle infilling. - """ +def llama_sampler_init_infill(vocab: llama_vocab_p, /) -> llama_sampler_p: ... diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f7cd13301..794fe23f2 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f7cd13301c2a88f97073fd119072b4cc92c08df1 +Subproject commit 794fe23f29fb40104975c91fe19f23798f7c726e