From a8efee72a015adade811dcd3f4e4f67bde800157 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:49:49 -0800 Subject: [PATCH 01/12] init --- examples/apple/coreml/llama/export.py | 210 +++++++++ .../apple/coreml/llama/extract_and_combine.py | 69 +++ .../apple/coreml/llama/llama_transformer.py | 407 ++++++++++++++++++ 3 files changed, 686 insertions(+) create mode 100644 examples/apple/coreml/llama/export.py create mode 100644 examples/apple/coreml/llama/extract_and_combine.py create mode 100644 examples/apple/coreml/llama/llama_transformer.py diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py new file mode 100644 index 0000000000..243055f9d0 --- /dev/null +++ b/examples/apple/coreml/llama/export.py @@ -0,0 +1,210 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +import argparse +import json + +import coremltools as ct +import torch +from executorch.backends.apple.coreml.compiler import CoreMLBackend # pyre-ignore +from executorch.backends.apple.coreml.partition import CoreMLPartitioner # pyre-ignore +from executorch.examples.models.llama.source_transformation.quantize import ( + EmbeddingQuantHandler, +) + +from executorch.exir.backend.utils import format_delegated_graph +from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.quant_fusion_pass import QuantFusionPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.extension.export_util.utils import export_to_edge, save_pte_program + +import sys +sys.path.insert(0, "..") +from llama.llama_transformer import ( + ModelArgs, + Transformer, +) + + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", + "--output_name", + default="model.pte", + help="Override the output filename of the saved pte model file.", + ) + parser.add_argument( + "-p", + "--params", + help="config.json", + ) + parser.add_argument( + "-c", + "--checkpoint", + help="checkpoint path", + ) + parser.add_argument( + "--static_seq_length", + type=int, + default=1, # set to 1 for decode + help="length sequence to evaluate", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=128, + help="maximum length sequence to evaluate", + ) + parser.add_argument( + "-E", + "--embedding-quantize", + default=None, + type=str, + help="type of embedding quantization, ',', e.g., '8,1024'.", + ) + parser.add_argument( + "--coreml-quantize", + default="c4w", + choices=["b4w", "c4w"], + help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", + ) + + export_args = parser.parse_args() + params_path = export_args.params + checkpoint_path = export_args.checkpoint + + # Load model args + with open(params_path, "r") as f: + params = json.loads(f.read()) + + args = ModelArgs( + max_seq_len=export_args.max_seq_length, + generate_full_logits=False, + **params, + ) + + with torch.device("meta"): + model = Transformer(args) + + checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True) + if "model" in checkpoint: + checkpoint = checkpoint["model"] + + missing, unexpected = model.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + print("Missing keys: ", missing) + print("Unexpected keys: ", unexpected) + + float_dtype = torch.float16 # dtype for model/inputs + + assert export_args.static_seq_length < args.max_seq_len + + cache_shape = ( + args.n_layers, + args.max_batch_size, + args.n_kv_heads, + args.max_seq_len - export_args.static_seq_length, + args.head_dim, + ) + attn_mask_shape = (export_args.static_seq_length, args.max_seq_len) + + example_inputs = ( + torch.tensor( + [0 for _ in range(export_args.static_seq_length)], dtype=torch.long + ).reshape(1, -1), # tokens + torch.tensor([0], dtype=torch.long), # input_pos + torch.zeros(cache_shape, dtype=float_dtype), # k_cache + torch.zeros(cache_shape, dtype=float_dtype), # v_cache + torch.zeros(attn_mask_shape, dtype=float_dtype), # attn_mask + ) + model.eval() + model.to(float_dtype) + + if export_args.embedding_quantize: + bitwidth, group_size = export_args.embedding_quantize.split(",") + if group_size == "none" or group_size == "None" or group_size == "0": + group_size = None + else: + group_size = int(group_size) + bitwidth = int(bitwidth) + model = EmbeddingQuantHandler( + model, + bitwidth=bitwidth, + group_size=group_size, + packed=(bitwidth in [2, 4]), + ).quantized_model() + + if export_args.coreml_quantize == "b4w": + op_linear_quantizer_config = { + "mode": "linear_symmetric", + "dtype": "int4", + "granularity": "per_block", + "block_size": 32, + "weight_threshold": 512, + } + elif export_args.coreml_quantize == "c4w": + op_linear_quantizer_config = { + "mode": "linear_symmetric", + "dtype": "int4", + "granularity": "per_channel", + } + else: + raise ValueError("Invalid coreml_quantize arg") + + compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision(ct.precision.FLOAT16.value), + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] + op_linear_quantizer_config=op_linear_quantizer_config, + ) + partitioner = CoreMLPartitioner( # pyre-fixme[16] + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[ + "quantized_decomposed.embedding_4bit.dtype", + "aten.embedding.default", + ], + ) + + edge_manager = export_to_edge( + model, + example_inputs, + edge_compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_type_promotion=(float_dtype == torch.float16), + _skip_dim_order=True, + ), + ) + print("Edge program") + print(edge_manager.exported_program()) + + edge_manager = edge_manager.to_backend(partitioner) + + print("Delegated program") + + print(format_delegated_graph(edge_manager.exported_program().graph_module)) + + executorch_program = edge_manager.to_executorch( + ExecutorchBackendConfig( + extract_delegate_segments=True, + passes=[ + QuantFusionPass(), + ], + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + ) + ) + + filename = save_pte_program(executorch_program, export_args.output_name) + print(f"Saved Executorch program to local {filename}") + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/examples/apple/coreml/llama/extract_and_combine.py b/examples/apple/coreml/llama/extract_and_combine.py new file mode 100644 index 0000000000..404553094e --- /dev/null +++ b/examples/apple/coreml/llama/extract_and_combine.py @@ -0,0 +1,69 @@ +import coremltools as ct +import argparse +import os +import subprocess +import shutil + +if __name__ == "__main__": + """ + Extract mlpackage from two CoreML pte files, and combine them into one mlpackage using multifunction + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "-m1", + "--model1_path", + type=str, + help="Model1 path.", + ) + parser.add_argument( + "-m2", + "--model2_path", + type=str, + help="Model2 path.", + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + help="Output path to save combined model", + ) + + args = parser.parse_args() + model1_path = str(args.model1_path) + model2_path = str(args.model2_path) + output_dir = str(args.output_dir) + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) + + extract_script_path = os.path.join(os.path.dirname(__file__), "../scripts/extract_coreml_models.py") + extracted_path = "extracted_coreml_models/model_1/lowered_module/model.mlpackage" + + subprocess.run(["python", extract_script_path, "--model", model1_path]) + items = os.listdir("extracted_coreml_models") + assert len(items) == 1, "Expected one CoreML partition" + shutil.copytree(extracted_path, f"{output_dir}/model1.mlpackage") + shutil.rmtree("extracted_coreml_models") + + subprocess.run(["python", extract_script_path, "--model", model2_path]) + items = os.listdir("extracted_coreml_models") + assert len(items) == 1, "Expected one CoreML partition" + shutil.copytree(extracted_path, f"{output_dir}/model2.mlpackage") + shutil.rmtree("extracted_coreml_models") + + + desc = ct.utils.MultiFunctionDescriptor() + + desc.add_function( + f"{output_dir}/model1.mlpackage", + src_function_name="main", + target_function_name="model1" + ) + desc.add_function( + f"{output_dir}/model2.mlpackage", + src_function_name="main", + target_function_name="model2" + ) + desc.default_function_name = "model1" + ct.utils.save_multifunction(desc, f"{output_dir}/combined.mlpackage") diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py new file mode 100644 index 0000000000..d4711af21b --- /dev/null +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -0,0 +1,407 @@ +# @lint-ignore-every LICENSELINT +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +# Please refer to README.md in the same folder for more information. + +from dataclasses import dataclass +from functools import partial +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from executorch.examples.models.llama.llama_transformer import RMSNorm + +from executorch.examples.models.llama.rope import ( + hf_apply_rotary_emb, + hf_precompute_freqs_cis, + precompute_freqs_cis, + RotaryEmbedding, +) + +from torch import nn + + +# These are just to prevent to_edge from decomposing SDPA +# A better method is to use the to_edge_transform_and_lower API for CoreML +# and not decompose SDPA +@torch.library.custom_op("coreml::sdpa", mutates_args=()) +def sdpa( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor +) -> torch.Tensor: + """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion.""" + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, attn_mask=attn_mask + ) + + +@torch.library.register_fake("coreml::sdpa") +def _( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor +) -> torch.Tensor: + """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing.""" + expected_shape = list(q.shape) + expected_shape[-1] = v.shape[-1] + return q.new_empty(expected_shape) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + dim: int = 2048 + n_layers: int = 16 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = 128256 + hidden_dim: Optional[int] = None + head_dim: Optional[int] = None # Optional customized head_dim + multiple_of: int = 256 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 1 + max_seq_len: int = 128 + max_context_len: int = 2048 + moe: bool = False # True to enable the MoE (Mixture of Experts) + num_experts: int = 8 # Number of experts + num_activated_experts: int = 2 # Number of experts to activate + + # Generate logits for all inputs. When it's True, it would take big memory usage + # at runtime. Enable it only necessary (e.g., use perplexity tools that requires + # logits for all input tokens.) + generate_full_logits: bool = False + # A dictionary mapping from pruned token-id to original token-id + input_prune_map: Optional[Dict[int, int]] = None + # A dictionary mapping from pruned token-id to original token-id + output_prune_map: Optional[Dict[int, int]] = None + use_hf_rope: bool = False # Use HuggingFace's RoPE implementation + rope_theta: Optional[float] = ( + None # The official name to override self.rope_freq_base. + ) + rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. + use_scaled_rope: bool = True # Use scaled RoPE, introduced in llama3.1. + # Additional Model Metadata needed at runtime + rope_scale_factor: int = 8 + bos_idx: int = 1 + eos_idx: int = 3 + bos_count: int = -1 # i.e., a single EOS is used as BOS + eos_count: int = 2 + + quantization_args: Optional[dict] = None + lora_args: Optional[dict] = None + + def __post_init__(self): + if self.n_kv_heads is None: + self.n_kv_heads = self.n_heads + + # rope_theta overrides rope_freq_base since it's the official name. + if self.rope_theta is not None: + self.rope_freq_base = self.rope_theta + + if self.hidden_dim is None: + # If hidden_dim is not explicitly set in the ModelArgs, + # then calculate implicitly based on dim and also multiple of `args.multiple_of` + multiple_of = self.multiple_of + hidden_dim = 4 * self.dim + hidden_dim = int(2 * hidden_dim / 3) + if self.ffn_dim_multiplier is not None: + hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) + self.hidden_dim = find_multiple(hidden_dim, multiple_of) + + if self.head_dim is None: + self.head_dim = self.dim // self.n_heads + + +class Rope(torch.nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + if self.params.use_hf_rope: + self.precompute_freqs_cis = hf_precompute_freqs_cis + else: + self.precompute_freqs_cis = partial( + precompute_freqs_cis, use_scaled=self.params.use_scaled_rope + ) + freqs_cos, freqs_sin = self.precompute_freqs_cis( + self.params.head_dim, + ( + self.params.max_context_len # Normal llama2. + if self.params.ffn_dim_multiplier is None + else self.params.max_context_len * 2 # Sharded checkpoint. + ), + self.params.rope_freq_base, + scale_factor=8, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + if self.params.use_hf_rope: + self.apply_rotary_emb = hf_apply_rotary_emb + else: + self.apply_rotary_emb = RotaryEmbedding() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get the precomputed frequencies for the given input position and sequence length. + + Args: + input_pos (torch.Tensor): The input position tensor. + seq_len (int): The sequence length. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. + """ + assert ( + input_pos is not None + ), "input_pos must be provided when use_kv_cache is True" + input_pos_item = input_pos[-1].item() + + # CoreML partitioner is not picking up _check_is_size + # So instead use _check as workaround. Should be easy fix for partitioner + # torch._check_is_size(input_pos_item) + torch._check(input_pos_item >= 0) + torch._check(input_pos_item + seq_len <= self.params.max_seq_len) + # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + # pyre-ignore: Incompatible parameter type [6] + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + + return freqs_cos, freqs_sin + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.hidden_dim is not None + hidden_dim: int = args.hidden_dim + self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConditionalFeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.dim = args.dim + hidden_dim = args.hidden_dim + if hidden_dim is None: + # If hidden_dim is not explicitly set in the ModelArgs, + # then calculate implicitly based on dim and also multiple of `args.multiple_of` + multiple_of = args.multiple_of + hidden_dim = 4 * self.dim + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.num_experts = args.num_experts + + def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor: + w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D] + w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taio -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights, expert_indices = torch.topk(scores, 2, dim=-1) # [T, A], [T, A] + expert_weights = expert_weights.softmax(dim=-1) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_seq_len = args.max_seq_len + self.dim = args.dim + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + + self.rope = rope + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_mask: torch.Tensor, + ): + bsz, seqlen, _ = x.shape + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + # We need view_copy elimination + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # RoPE relative positional embeddings + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + new_k = k + new_v = v + + k = torch.concat([k_cache, k], dim=2) + v = torch.concat([v_cache, v], dim=2) + + # grouped multiquery attention: expand out keys and values + if self.n_rep > 1: + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + output = torch.ops.coreml.sdpa(q, k, v, attn_mask) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + output = self.wo(output) + + return output, new_k, new_v + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + self.attention = Attention(args, layer_id, rope) + if args.moe: + self.block_sparse_moe = MOEFeedForward(args) + else: + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + freqs_cos, + freqs_sin, + k_cache, + v_cache, + attn_mask, + ): # x: 1xN + norm_emb = self.attention_norm(x) + h, new_k, new_v = self.attention.forward( + norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask + ) + + h = x + h + out = h + self.feed_forward(self.ffn_norm(h)) + return out, new_k, new_v + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.rope = Rope(params) + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params, self.rope)) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + self.generate_full_logits = params.generate_full_logits + self.max_seq_len = params.max_seq_len + self.input_prune_map = params.input_prune_map + self.output_prune_map = params.output_prune_map + + def forward( + self, + tokens: torch.LongTensor, # tokens + input_pos: torch.LongTensor, + k_cache: torch.FloatTensor, + v_cache: torch.FloatTensor, + attn_mask: torch.LongTensor, + h: Optional[torch.FloatTensor] = None, # embeddings + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if (tokens is None) ^ (h is not None): + raise ValueError( + "You cannot specify both tokens and h at the same time, and must specify either one" + ) + if tokens is not None and h is None: + h = self.tok_embeddings(tokens) + seqlen = h.shape[1] + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen) + + k_out = [] + v_out = [] + + for i, layer in enumerate(self.layers): + h, new_k, new_v = layer( + h, + freqs_cos, + freqs_sin, + k_cache[i,:,:,:,:], + v_cache[i,:,:,:,:], + attn_mask, + ) + k_out.append(new_k) + v_out.append(new_v) + + if not self.generate_full_logits: + # Only the last logit is used for the new generated token + h = h[:, - 1, :] + + h = self.norm(h) + + logits = self.output(h) + + return logits, torch.stack(k_out, dim=0), torch.stack(v_out, dim=0) From 3266e008f0d7e42d19ae111c323f0a43aea38ecb Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 13 Feb 2025 13:30:29 -0800 Subject: [PATCH 02/12] up --- examples/apple/coreml/llama/extract_and_combine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/apple/coreml/llama/extract_and_combine.py b/examples/apple/coreml/llama/extract_and_combine.py index 404553094e..5cf22b1c27 100644 --- a/examples/apple/coreml/llama/extract_and_combine.py +++ b/examples/apple/coreml/llama/extract_and_combine.py @@ -34,9 +34,13 @@ output_dir = str(args.output_dir) if os.path.exists(output_dir): - shutil.rmtree(output_dir) + raise Exception(f"Output directory {output_dir} already exists. Please make delete it before running script.") os.makedirs(output_dir) + coreml_extract_path = os.path.join(os.getcwd(), "extracted_coreml_models") + if os.path.exists(coreml_extract_path): + raise Exception(f"{coreml_extract_path} already exists. Please delete it before running script.") + extract_script_path = os.path.join(os.path.dirname(__file__), "../scripts/extract_coreml_models.py") extracted_path = "extracted_coreml_models/model_1/lowered_module/model.mlpackage" @@ -44,14 +48,11 @@ items = os.listdir("extracted_coreml_models") assert len(items) == 1, "Expected one CoreML partition" shutil.copytree(extracted_path, f"{output_dir}/model1.mlpackage") - shutil.rmtree("extracted_coreml_models") subprocess.run(["python", extract_script_path, "--model", model2_path]) items = os.listdir("extracted_coreml_models") assert len(items) == 1, "Expected one CoreML partition" shutil.copytree(extracted_path, f"{output_dir}/model2.mlpackage") - shutil.rmtree("extracted_coreml_models") - desc = ct.utils.MultiFunctionDescriptor() From 63dba5512cacea1930c7c75793e808cdf83558c6 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 18 Feb 2025 17:34:10 -0800 Subject: [PATCH 03/12] up --- examples/apple/coreml/llama/export.py | 50 ++++------ .../apple/coreml/llama/llama_transformer.py | 98 ++++++++++++++++++- examples/apple/coreml/llama/run.py | 54 ++++++++++ 3 files changed, 165 insertions(+), 37 deletions(-) create mode 100644 examples/apple/coreml/llama/run.py diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 243055f9d0..f83f5e9473 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -5,6 +5,8 @@ import argparse import json +import sys + import coremltools as ct import torch from executorch.backends.apple.coreml.compiler import CoreMLBackend # pyre-ignore @@ -20,13 +22,8 @@ from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.export_util.utils import export_to_edge, save_pte_program -import sys sys.path.insert(0, "..") -from llama.llama_transformer import ( - ModelArgs, - Transformer, -) - +from llama.llama_transformer import InputManager, ModelArgs, Transformer def main() -> None: @@ -68,7 +65,7 @@ def main() -> None: ) parser.add_argument( "--coreml-quantize", - default="c4w", + default=None, choices=["b4w", "c4w"], help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", ) @@ -103,27 +100,6 @@ def main() -> None: print("Unexpected keys: ", unexpected) float_dtype = torch.float16 # dtype for model/inputs - - assert export_args.static_seq_length < args.max_seq_len - - cache_shape = ( - args.n_layers, - args.max_batch_size, - args.n_kv_heads, - args.max_seq_len - export_args.static_seq_length, - args.head_dim, - ) - attn_mask_shape = (export_args.static_seq_length, args.max_seq_len) - - example_inputs = ( - torch.tensor( - [0 for _ in range(export_args.static_seq_length)], dtype=torch.long - ).reshape(1, -1), # tokens - torch.tensor([0], dtype=torch.long), # input_pos - torch.zeros(cache_shape, dtype=float_dtype), # k_cache - torch.zeros(cache_shape, dtype=float_dtype), # v_cache - torch.zeros(attn_mask_shape, dtype=float_dtype), # attn_mask - ) model.eval() model.to(float_dtype) @@ -141,6 +117,9 @@ def main() -> None: packed=(bitwidth in [2, 4]), ).quantized_model() + model = model.to(float_dtype) + + op_linear_quantizer_config = None if export_args.coreml_quantize == "b4w": op_linear_quantizer_config = { "mode": "linear_symmetric", @@ -155,8 +134,6 @@ def main() -> None: "dtype": "int4", "granularity": "per_channel", } - else: - raise ValueError("Invalid coreml_quantize arg") compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=ct.target.iOS18, @@ -174,12 +151,20 @@ def main() -> None: ], ) + input_manager = InputManager( + model_args=args, + seq_length=export_args.static_seq_length, + dtype=float_dtype, + minus_infinity=-30000, + ) + example_inputs = input_manager.get_inputs(tokens=torch.tensor([0])) + edge_manager = export_to_edge( model, example_inputs, edge_compile_config=EdgeCompileConfig( _check_ir_validity=False, - _skip_type_promotion=(float_dtype == torch.float16), + # _skip_type_promotion=(float_dtype == torch.float16), _skip_dim_order=True, ), ) @@ -205,6 +190,7 @@ def main() -> None: filename = save_pte_program(executorch_program, export_args.output_name) print(f"Saved Executorch program to local {filename}") - + + if __name__ == "__main__": main() # pragma: no cover diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index d4711af21b..ac4db0868d 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -367,6 +367,7 @@ def forward( self, tokens: torch.LongTensor, # tokens input_pos: torch.LongTensor, + input_length: torch.LongTensor, # input_length k_cache: torch.FloatTensor, v_cache: torch.FloatTensor, attn_mask: torch.LongTensor, @@ -383,14 +384,13 @@ def forward( k_out = [] v_out = [] - for i, layer in enumerate(self.layers): h, new_k, new_v = layer( h, freqs_cos, freqs_sin, - k_cache[i,:,:,:,:], - v_cache[i,:,:,:,:], + k_cache[i, :, :, :, :], + v_cache[i, :, :, :, :], attn_mask, ) k_out.append(new_k) @@ -398,10 +398,98 @@ def forward( if not self.generate_full_logits: # Only the last logit is used for the new generated token - h = h[:, - 1, :] + h = h[:, input_length - 1, :] h = self.norm(h) logits = self.output(h) - return logits, torch.stack(k_out, dim=0), torch.stack(v_out, dim=0) + return ( + logits, + torch.stack(k_out, dim=0), + torch.stack(v_out, dim=0), + ) + + +class InputManager: + def __init__( + self, + model_args: ModelArgs, + seq_length, + dtype=torch.float16, + minus_infinity=-torch.inf, + ): + self.n_layers = model_args.n_layers + self.max_batch_size = model_args.max_batch_size + self.n_kv_heads = model_args.n_kv_heads + self.head_dim = model_args.head_dim + + self.seq_length = seq_length + self.max_seq_length = model_args.max_seq_len + + self.k_cache = torch.zeros( + self.get_cache_shape(self.max_seq_length - self.seq_length) + ).to(dtype) + self.v_cache = torch.zeros( + self.get_cache_shape(self.max_seq_length - self.seq_length) + ).to(dtype) + + attn_cache = minus_infinity * torch.ones( + seq_length, self.max_seq_length - self.seq_length + ) # attn for past tokens + attn_seq = torch.triu( + minus_infinity * torch.ones(self.seq_length, self.seq_length), diagonal=1 + ) # attn for current tokens + self.attn_mask = torch.concat([attn_cache, attn_seq], dim=-1).to(dtype) + assert self.attn_mask.shape == (self.seq_length, self.max_seq_length) + + self.input_pos = 0 + + def get_cache_shape(self, length): + return ( + self.n_layers, + self.max_batch_size, + self.n_kv_heads, + length, + self.head_dim, + ) + + def update(self, input_length, new_k_cache, new_v_cache): + assert new_k_cache.shape == self.get_cache_shape(self.seq_length) + assert new_v_cache.shape == self.get_cache_shape(self.seq_length) + + self.k_cache[:, :, :, (self.input_pos) : (self.input_pos + input_length), :] = ( + new_k_cache[:, :, :, 0:input_length, :] + ) + self.v_cache[:, :, :, (self.input_pos) : (self.input_pos + input_length), :] = ( + new_v_cache[:, :, :, 0:input_length, :] + ) + self.attn_mask[:, (self.input_pos) : (self.input_pos + input_length)] = 0.0 + self.input_pos += input_length + + def get_inputs(self, tokens): + assert tokens.dim() == 1 + assert tokens.dtype == torch.int64 + input_length = len(tokens) + assert input_length <= self.seq_length + + return ( + # tokens + torch.concat( + [ + tokens, + torch.zeros(self.seq_length - input_length, dtype=torch.int64), + ], + axis=-1, + ).reshape(1, -1), + # input_pos + torch.tensor([self.input_pos], dtype=torch.long), + # input_length + torch.tensor([input_length], dtype=torch.long), + # k_cache + self.k_cache, + # v_cache + self.v_cache, + # attn_mask + self.attn_mask, + ) diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py new file mode 100644 index 0000000000..504faa68d6 --- /dev/null +++ b/examples/apple/coreml/llama/run.py @@ -0,0 +1,54 @@ +import sys + +import torch + +sys.path.insert(0, "..") +import json + +from llama.llama_transformer import InputManager, ModelArgs, Transformer + +params_path = "/Users/scroy/models/stories110M/params.json" +max_seq_length = 512 +seq_length = 64 +# Load model args +with open(params_path, "r") as f: + params = json.loads(f.read()) + +args = ModelArgs( + max_seq_len=max_seq_length, + generate_full_logits=False, + **params, +) +input_manager = InputManager( + model_args=args, + seq_length=seq_length, + dtype=torch.float16, + minus_infinity=-30000, +) + + +filename = "/Users/scroy/Desktop/model.pte" + +# Test PTE +from executorch.runtime import Runtime + +from transformers import AutoTokenizer + +# Load the tokenizer for LLaMA 3 +tokenizer = AutoTokenizer.from_pretrained("neuralmagic/llama2.c-stories110M-pruned50") + +text = "Once upon a time," +runtime = Runtime.get() +program = runtime.load_program(filename) +method = program.load_method("forward") +print(text) +tokens = tokenizer.encode(text) +while input_manager.input_pos + len(tokens) < max_seq_length: + inputs = input_manager.get_inputs(torch.tensor(tokens, dtype=torch.long)) + logits, k, v = method.execute(inputs) + input_manager.update(input_length=len(tokens), new_k_cache=k, new_v_cache=v) + + new_token = logits.argmax(-1).item() + tokens = [new_token] + decoded_text = tokenizer.decode(tokens) + print(decoded_text, end=" ") From 9a011c4ea32c5ac9c2dbd8076c4f863df4947a39 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 18 Feb 2025 18:32:05 -0800 Subject: [PATCH 04/12] up --- examples/apple/coreml/llama/export.py | 2 +- .../apple/coreml/llama/llama_transformer.py | 16 +++++++++----- examples/apple/coreml/llama/run.py | 21 ++++++++++++------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index f83f5e9473..96e6316915 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -157,7 +157,7 @@ def main() -> None: dtype=float_dtype, minus_infinity=-30000, ) - example_inputs = input_manager.get_inputs(tokens=torch.tensor([0])) + example_inputs = input_manager.get_inputs(tokens=[0]) edge_manager = export_to_edge( model, diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index ac4db0868d..d436a72a08 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -467,9 +467,7 @@ def update(self, input_length, new_k_cache, new_v_cache): self.attn_mask[:, (self.input_pos) : (self.input_pos + input_length)] = 0.0 self.input_pos += input_length - def get_inputs(self, tokens): - assert tokens.dim() == 1 - assert tokens.dtype == torch.int64 + def get_inputs(self, tokens: List[int]): input_length = len(tokens) assert input_length <= self.seq_length @@ -477,7 +475,7 @@ def get_inputs(self, tokens): # tokens torch.concat( [ - tokens, + torch.tensor(tokens, dtype=torch.int64), torch.zeros(self.seq_length - input_length, dtype=torch.int64), ], axis=-1, @@ -493,3 +491,11 @@ def get_inputs(self, tokens): # attn_mask self.attn_mask, ) + + def get_inputs_and_remaining_tokens(self, tokens: List[int]): + processed_tokens = min(self.seq_length, len(tokens)) + return ( + self.get_inputs(tokens[0:processed_tokens]), + processed_tokens, + tokens[processed_tokens:], + ) diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index 504faa68d6..d7330588bf 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -43,12 +43,17 @@ method = program.load_method("forward") print(text) tokens = tokenizer.encode(text) -while input_manager.input_pos + len(tokens) < max_seq_length: - inputs = input_manager.get_inputs(torch.tensor(tokens, dtype=torch.long)) - logits, k, v = method.execute(inputs) - input_manager.update(input_length=len(tokens), new_k_cache=k, new_v_cache=v) - - new_token = logits.argmax(-1).item() - tokens = [new_token] +while input_manager.input_pos + len(tokens) < max_seq_length - seq_length: + while len(tokens) > 0: + inputs, processed_tokens, remaining_tokens = ( + input_manager.get_inputs_and_remaining_tokens(tokens) + ) + logits, k, v = method.execute(inputs) + input_manager.update( + input_length=processed_tokens, new_k_cache=k, new_v_cache=v + ) + tokens = remaining_tokens + + tokens = [logits.argmax(-1).item()] decoded_text = tokenizer.decode(tokens) - print(decoded_text, end=" ") + print(decoded_text, end=" ", flush=True) From 37fb4f87f5cbff5bd1a627218abb0a750f4f31a7 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:38:53 -0800 Subject: [PATCH 05/12] up --- examples/apple/coreml/llama/export.py | 6 +- .../apple/coreml/llama/llama_transformer.py | 1 + examples/apple/coreml/llama/readme.md | 13 ++ examples/apple/coreml/llama/run.py | 135 ++++++++++++------ 4 files changed, 106 insertions(+), 49 deletions(-) create mode 100644 examples/apple/coreml/llama/readme.md diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 96e6316915..05a958208d 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -45,7 +45,7 @@ def main() -> None: help="checkpoint path", ) parser.add_argument( - "--static_seq_length", + "--seq_length", type=int, default=1, # set to 1 for decode help="length sequence to evaluate", @@ -153,7 +153,7 @@ def main() -> None: input_manager = InputManager( model_args=args, - seq_length=export_args.static_seq_length, + seq_length=export_args.seq_length, dtype=float_dtype, minus_infinity=-30000, ) @@ -164,7 +164,7 @@ def main() -> None: example_inputs, edge_compile_config=EdgeCompileConfig( _check_ir_validity=False, - # _skip_type_promotion=(float_dtype == torch.float16), + _skip_type_promotion=(float_dtype == torch.float16), _skip_dim_order=True, ), ) diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index d436a72a08..1089c27484 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -457,6 +457,7 @@ def get_cache_shape(self, length): def update(self, input_length, new_k_cache, new_v_cache): assert new_k_cache.shape == self.get_cache_shape(self.seq_length) assert new_v_cache.shape == self.get_cache_shape(self.seq_length) + assert self.input_pos + input_length <= self.max_seq_length - self.seq_length self.k_cache[:, :, :, (self.input_pos) : (self.input_pos + input_length), :] = ( new_k_cache[:, :, :, 0:input_length, :] diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md new file mode 100644 index 0000000000..47fb4bb148 --- /dev/null +++ b/examples/apple/coreml/llama/readme.md @@ -0,0 +1,13 @@ +This directory contains static, ANE-friendly Llama models. + +Export model with: +``` +python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w +``` + +Run model with: +``` +python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," +``` + +The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index d7330588bf..9a06ba2426 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -1,3 +1,4 @@ +import argparse import sys import torch @@ -5,55 +6,97 @@ sys.path.insert(0, "..") import json -from llama.llama_transformer import InputManager, ModelArgs, Transformer +import sentencepiece as spm +from executorch.runtime import Runtime -params_path = "/Users/scroy/models/stories110M/params.json" -max_seq_length = 512 -seq_length = 64 -# Load model args -with open(params_path, "r") as f: - params = json.loads(f.read()) +from llama.llama_transformer import InputManager, ModelArgs -args = ModelArgs( - max_seq_len=max_seq_length, - generate_full_logits=False, - **params, -) -input_manager = InputManager( - model_args=args, - seq_length=seq_length, - dtype=torch.float16, - minus_infinity=-30000, -) +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + help="model.pte", + ) + parser.add_argument( + "-p", + "--params", + help="config.json", + ) + parser.add_argument( + "-t", + "--tokenizer", + help="tokenizer.model path", + ) + parser.add_argument( + "--seq_length", + type=int, + default=1, # set to 1 for decode + help="length sequence to evaluate", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=128, + help="maximum length sequence to evaluate", + ) + parser.add_argument( + "--prompt", + type=str, + default="Once upon a time,", + ) -filename = "/Users/scroy/Desktop/model.pte" + args = parser.parse_args() + params_path = args.params + + # Load model args + with open(params_path, "r") as f: + params = json.loads(f.read()) + + model_args = ModelArgs( + max_seq_len=args.max_seq_length, + generate_full_logits=False, + **params, + ) + + input_manager = InputManager( + model_args=model_args, + seq_length=args.seq_length, + dtype=torch.float16, + minus_infinity=-30000, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.tokenizer) + + runtime = Runtime.get() + program = runtime.load_program(args.model) + method = program.load_method("forward") + generated_tokens = [] + tokens = sp.encode(args.prompt) + generated_tokens.extend(tokens) + while ( + input_manager.input_pos + args.seq_length + < args.max_seq_length - args.seq_length + ): + while len(tokens) > 0: + inputs, processed_tokens, remaining_tokens = ( + input_manager.get_inputs_and_remaining_tokens(tokens) + ) + logits, k, v = method.execute(inputs) + input_manager.update( + input_length=processed_tokens, new_k_cache=k, new_v_cache=v + ) + tokens = remaining_tokens + + tokens = [logits.argmax(-1).item()] + generated_tokens.extend(tokens) + print(sp.decode(generated_tokens[-1]), end=" ", flush=True) + + print("\n\nFull text:") + print(sp.decode(generated_tokens)) -# Test PTE -from executorch.runtime import Runtime -from transformers import AutoTokenizer - -# Load the tokenizer for LLaMA 3 -tokenizer = AutoTokenizer.from_pretrained("neuralmagic/llama2.c-stories110M-pruned50") - -text = "Once upon a time," -runtime = Runtime.get() -program = runtime.load_program(filename) -method = program.load_method("forward") -print(text) -tokens = tokenizer.encode(text) -while input_manager.input_pos + len(tokens) < max_seq_length - seq_length: - while len(tokens) > 0: - inputs, processed_tokens, remaining_tokens = ( - input_manager.get_inputs_and_remaining_tokens(tokens) - ) - logits, k, v = method.execute(inputs) - input_manager.update( - input_length=processed_tokens, new_k_cache=k, new_v_cache=v - ) - tokens = remaining_tokens - - tokens = [logits.argmax(-1).item()] - decoded_text = tokenizer.decode(tokens) - print(decoded_text, end=" ", flush=True) +if __name__ == "__main__": + main() From 4c7eec3b2023a498554fa8963b9c54e6548c6d16 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:41:44 -0800 Subject: [PATCH 06/12] up --- .../apple/coreml/llama/llama_transformer.py | 126 +++++++++++++----- examples/apple/coreml/llama/readme.md | 12 +- examples/apple/coreml/llama/run.py | 100 ++++++++++++-- 3 files changed, 191 insertions(+), 47 deletions(-) diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index 1089c27484..d99c89077a 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -97,6 +97,8 @@ class ModelArgs: quantization_args: Optional[dict] = None lora_args: Optional[dict] = None + use_cache_list: bool = True + def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads @@ -362,14 +364,15 @@ def __init__(self, params: ModelArgs): self.max_seq_len = params.max_seq_len self.input_prune_map = params.input_prune_map self.output_prune_map = params.output_prune_map + self.use_cache_list = params.use_cache_list def forward( self, tokens: torch.LongTensor, # tokens input_pos: torch.LongTensor, input_length: torch.LongTensor, # input_length - k_cache: torch.FloatTensor, - v_cache: torch.FloatTensor, + k_caches: List[torch.FloatTensor], + v_caches: List[torch.FloatTensor], attn_mask: torch.LongTensor, h: Optional[torch.FloatTensor] = None, # embeddings ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -389,8 +392,8 @@ def forward( h, freqs_cos, freqs_sin, - k_cache[i, :, :, :, :], - v_cache[i, :, :, :, :], + k_caches[i] if self.use_cache_list else k_caches[i,:,:,:,:], + v_caches[i] if self.use_cache_list else v_caches[i,:,:,:,:], attn_mask, ) k_out.append(new_k) @@ -404,11 +407,10 @@ def forward( logits = self.output(h) - return ( - logits, - torch.stack(k_out, dim=0), - torch.stack(v_out, dim=0), - ) + if not self.use_cache_list: + k_out = torch.stack(k_out, dim=0) + v_out = torch.stack(v_out, dim=0) + return logits, k_out, v_out class InputManager: @@ -418,34 +420,59 @@ def __init__( seq_length, dtype=torch.float16, minus_infinity=-torch.inf, + cache_size = None, ): + if cache_size is None: + cache_size = model_args.max_seq_len - seq_length + self.cache_size = cache_size + assert self.cache_size + seq_length <= model_args.max_seq_len + + self.n_layers = model_args.n_layers self.max_batch_size = model_args.max_batch_size self.n_kv_heads = model_args.n_kv_heads self.head_dim = model_args.head_dim self.seq_length = seq_length - self.max_seq_length = model_args.max_seq_len - - self.k_cache = torch.zeros( - self.get_cache_shape(self.max_seq_length - self.seq_length) - ).to(dtype) - self.v_cache = torch.zeros( - self.get_cache_shape(self.max_seq_length - self.seq_length) - ).to(dtype) + self.use_cache_list = model_args.use_cache_list + + if self.use_cache_list: + self.k_caches = [ + torch.zeros(self.get_cache_shape(self.cache_size)).to( + dtype + ) + for _ in range(self.n_layers) + ] + self.v_caches = [ + torch.zeros(self.get_cache_shape(self.cache_size)).to( + dtype + ) + for _ in range(self.n_layers) + ] + else: + self.k_caches = torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) + self.v_caches = torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) attn_cache = minus_infinity * torch.ones( - seq_length, self.max_seq_length - self.seq_length + seq_length, self.cache_size ) # attn for past tokens attn_seq = torch.triu( minus_infinity * torch.ones(self.seq_length, self.seq_length), diagonal=1 ) # attn for current tokens self.attn_mask = torch.concat([attn_cache, attn_seq], dim=-1).to(dtype) - assert self.attn_mask.shape == (self.seq_length, self.max_seq_length) + assert self.attn_mask.shape == (self.seq_length, self.cache_size + self.seq_length) self.input_pos = 0 + self.cache_pos = 0 def get_cache_shape(self, length): + if self.use_cache_list: + return ( + self.max_batch_size, + self.n_kv_heads, + length, + self.head_dim, + ) return ( self.n_layers, self.max_batch_size, @@ -454,18 +481,52 @@ def get_cache_shape(self, length): self.head_dim, ) - def update(self, input_length, new_k_cache, new_v_cache): - assert new_k_cache.shape == self.get_cache_shape(self.seq_length) - assert new_v_cache.shape == self.get_cache_shape(self.seq_length) - assert self.input_pos + input_length <= self.max_seq_length - self.seq_length + def _update_cache(self, start, length, new_k_caches, new_v_caches): + """ + Copies new cache data from start to start + length to cache + """ + assert self.cache_pos + length <= self.cache_size + assert start + length <= self.seq_length + + if self.use_cache_list: + for i in range(self.n_layers): + assert new_k_caches[i].shape == self.get_cache_shape(self.seq_length) + assert new_v_caches[i].shape == self.get_cache_shape(self.seq_length) + + self.k_caches[i][ + :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_k_caches[i][:, :, start:(start+length), :] + self.v_caches[i][ + :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_v_caches[i][:, :, start:(start+length), :] + else: + assert new_k_caches.shape == self.get_cache_shape(self.seq_length) + assert new_v_caches.shape == self.get_cache_shape(self.seq_length) + self.k_caches[ + :, :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_k_caches[:, :, :, start:(start+length), :] + self.v_caches[ + :, :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_v_caches[:, :, :, start:(start+length), :] + + self.cache_pos += length + if self.cache_pos == self.cache_size: + self.cache_pos = 0 + + + def update(self, input_length, new_k_caches, new_v_caches): + # Copy as much new cache data into cache as possible without wrapping + amount_to_copy = min(input_length, self.cache_size - self.cache_pos) + self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches) + if self.input_pos <= self.cache_size: + self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = 0.0 + + # Copy remainder (cache is now wrapped around and has more room) + # Attention mask needs no further updates. Attention is paid to the whole cache + remaining_to_copy = min(input_length - amount_to_copy, self.cache_size - self.cache_pos) + if remaining_to_copy > 0: + self._update_cache(amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches) - self.k_cache[:, :, :, (self.input_pos) : (self.input_pos + input_length), :] = ( - new_k_cache[:, :, :, 0:input_length, :] - ) - self.v_cache[:, :, :, (self.input_pos) : (self.input_pos + input_length), :] = ( - new_v_cache[:, :, :, 0:input_length, :] - ) - self.attn_mask[:, (self.input_pos) : (self.input_pos + input_length)] = 0.0 self.input_pos += input_length def get_inputs(self, tokens: List[int]): @@ -486,9 +547,9 @@ def get_inputs(self, tokens: List[int]): # input_length torch.tensor([input_length], dtype=torch.long), # k_cache - self.k_cache, + self.k_caches, # v_cache - self.v_cache, + self.v_caches, # attn_mask self.attn_mask, ) @@ -497,6 +558,5 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]): processed_tokens = min(self.seq_length, len(tokens)) return ( self.get_inputs(tokens[0:processed_tokens]), - processed_tokens, tokens[processed_tokens:], ) diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 47fb4bb148..d5f45f9ef7 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -1,13 +1,19 @@ -This directory contains static, ANE-friendly Llama models. +# ANE-friendly Llama models + +This directory contains ANE-friendly Llama models. Export model with: ``` python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w ``` + +The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. + + Run model with: ``` -python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," +python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512 ``` -The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. +The model here is based on a "sliding" cache, where old tokens are evicted from the cache. By default, the cache size is max_seq_length - seq_length, but you can explicitly pass in a smaller cache size (e.g., --cache_size 512). This can speed up computation and reduce memory. Keep in mind that once cache_size is reached, older tokens get evicted from the cache and do not participate in attention. diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index 9a06ba2426..bd54b264b0 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -1,16 +1,83 @@ import argparse +from multiprocessing import process import sys import torch +from pathlib import Path + sys.path.insert(0, "..") import json import sentencepiece as spm +import tiktoken +from tiktoken.load import load_tiktoken_bpe + from executorch.runtime import Runtime from llama.llama_transformer import InputManager, ModelArgs +class Tokenizer: + def __init__(self, model_path: str): + # Try sentence piece + try: + print("Trying to load sentencepiece") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + self.tokenizer = sp + except: + print("Trying to tiktoken") + self.num_reserved_special_tokens = 256 + self.pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.tokenizer = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + def encode(self, text): + return self.tokenizer.encode(text) + + def encode_prompt(self, text): + if isinstance(self.tokenizer, spm.SentencePieceProcessor): + return self.tokenizer.encode(text) + + get_prompt = lambda x: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{x}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + return self.tokenizer.encode(get_prompt(text), allowed_special={"<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"}) + + def decode(self, tokens): + return self.tokenizer.decode(tokens) + + def stop_tokens(self): + if isinstance(self.tokenizer, spm.SentencePieceProcessor): + return [self.tokenizer.eos_id()] + if isinstance(self.tokenizer, tiktoken.Encoding): + return [ + self.tokenizer.encode("<|eot_id|>", allowed_special={"<|eot_id|>"})[0], + self.tokenizer.encode("<|end_of_text|>", allowed_special={"<|end_of_text|>"})[0], + ] def main() -> None: parser = argparse.ArgumentParser() @@ -46,6 +113,16 @@ def main() -> None: type=str, default="Once upon a time,", ) + parser.add_argument( + "--n_steps", + type=int, + ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + help="Cache size. Old items are evicted from cache", + ) args = parser.parse_args() params_path = args.params @@ -57,6 +134,7 @@ def main() -> None: model_args = ModelArgs( max_seq_len=args.max_seq_length, generate_full_logits=False, + use_cache_list=False, # cache_list does not work in pybindings **params, ) @@ -65,37 +143,37 @@ def main() -> None: seq_length=args.seq_length, dtype=torch.float16, minus_infinity=-30000, + cache_size=args.cache_size, ) - sp = spm.SentencePieceProcessor() - sp.load(args.tokenizer) + tokenizer = Tokenizer(args.tokenizer) runtime = Runtime.get() program = runtime.load_program(args.model) method = program.load_method("forward") generated_tokens = [] - tokens = sp.encode(args.prompt) + tokens = tokenizer.encode_prompt(args.prompt) generated_tokens.extend(tokens) - while ( - input_manager.input_pos + args.seq_length - < args.max_seq_length - args.seq_length - ): + while input_manager.input_pos < args.n_steps: while len(tokens) > 0: - inputs, processed_tokens, remaining_tokens = ( + inputs, remaining_tokens = ( input_manager.get_inputs_and_remaining_tokens(tokens) ) + processed_tokens = len(tokens) - len(remaining_tokens) logits, k, v = method.execute(inputs) input_manager.update( - input_length=processed_tokens, new_k_cache=k, new_v_cache=v + input_length=processed_tokens, new_k_caches=k, new_v_caches=v ) tokens = remaining_tokens tokens = [logits.argmax(-1).item()] generated_tokens.extend(tokens) - print(sp.decode(generated_tokens[-1]), end=" ", flush=True) + if tokens[-1] in tokenizer.stop_tokens(): + break + print(tokenizer.decode([generated_tokens[-1]]), end=" ", flush=True) print("\n\nFull text:") - print(sp.decode(generated_tokens)) + print(tokenizer.decode(generated_tokens)) if __name__ == "__main__": From 92528e95ffe370a1228ba4604891ac5dda220722 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:45:14 -0800 Subject: [PATCH 07/12] lint --- examples/apple/coreml/llama/export.py | 13 +++++ .../apple/coreml/llama/extract_and_combine.py | 23 ++++++--- .../apple/coreml/llama/llama_transformer.py | 51 ++++++++++--------- examples/apple/coreml/llama/run.py | 38 +++++++++----- 4 files changed, 81 insertions(+), 44 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 05a958208d..f86b5306cc 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -56,6 +56,12 @@ def main() -> None: default=128, help="maximum length sequence to evaluate", ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + help="Cache size. Old items are evicted from cache", + ) parser.add_argument( "-E", "--embedding-quantize", @@ -69,6 +75,11 @@ def main() -> None: choices=["b4w", "c4w"], help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", ) + parser.add_argument( + "--use-cache-list", + action="store_true", + help="Use cache list to speed up model computation (does not work in pybindings)", + ) export_args = parser.parse_args() params_path = export_args.params @@ -81,6 +92,7 @@ def main() -> None: args = ModelArgs( max_seq_len=export_args.max_seq_length, generate_full_logits=False, + use_cache_list=export_args.use_cache_list, **params, ) @@ -156,6 +168,7 @@ def main() -> None: seq_length=export_args.seq_length, dtype=float_dtype, minus_infinity=-30000, + cache_size=export_args.cache_size, ) example_inputs = input_manager.get_inputs(tokens=[0]) diff --git a/examples/apple/coreml/llama/extract_and_combine.py b/examples/apple/coreml/llama/extract_and_combine.py index 5cf22b1c27..f73b5713bb 100644 --- a/examples/apple/coreml/llama/extract_and_combine.py +++ b/examples/apple/coreml/llama/extract_and_combine.py @@ -1,8 +1,9 @@ -import coremltools as ct import argparse import os -import subprocess import shutil +import subprocess + +import coremltools as ct if __name__ == "__main__": """ @@ -34,16 +35,22 @@ output_dir = str(args.output_dir) if os.path.exists(output_dir): - raise Exception(f"Output directory {output_dir} already exists. Please make delete it before running script.") + raise Exception( + f"Output directory {output_dir} already exists. Please make delete it before running script." + ) os.makedirs(output_dir) coreml_extract_path = os.path.join(os.getcwd(), "extracted_coreml_models") if os.path.exists(coreml_extract_path): - raise Exception(f"{coreml_extract_path} already exists. Please delete it before running script.") + raise Exception( + f"{coreml_extract_path} already exists. Please delete it before running script." + ) - extract_script_path = os.path.join(os.path.dirname(__file__), "../scripts/extract_coreml_models.py") + extract_script_path = os.path.join( + os.path.dirname(__file__), "../scripts/extract_coreml_models.py" + ) extracted_path = "extracted_coreml_models/model_1/lowered_module/model.mlpackage" - + subprocess.run(["python", extract_script_path, "--model", model1_path]) items = os.listdir("extracted_coreml_models") assert len(items) == 1, "Expected one CoreML partition" @@ -59,12 +66,12 @@ desc.add_function( f"{output_dir}/model1.mlpackage", src_function_name="main", - target_function_name="model1" + target_function_name="model1", ) desc.add_function( f"{output_dir}/model2.mlpackage", src_function_name="main", - target_function_name="model2" + target_function_name="model2", ) desc.default_function_name = "model1" ct.utils.save_multifunction(desc, f"{output_dir}/combined.mlpackage") diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index d99c89077a..cbc9579800 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -392,8 +392,8 @@ def forward( h, freqs_cos, freqs_sin, - k_caches[i] if self.use_cache_list else k_caches[i,:,:,:,:], - v_caches[i] if self.use_cache_list else v_caches[i,:,:,:,:], + k_caches[i] if self.use_cache_list else k_caches[i, :, :, :, :], + v_caches[i] if self.use_cache_list else v_caches[i, :, :, :, :], attn_mask, ) k_out.append(new_k) @@ -420,14 +420,13 @@ def __init__( seq_length, dtype=torch.float16, minus_infinity=-torch.inf, - cache_size = None, + cache_size=None, ): if cache_size is None: cache_size = model_args.max_seq_len - seq_length self.cache_size = cache_size assert self.cache_size + seq_length <= model_args.max_seq_len - - + self.n_layers = model_args.n_layers self.max_batch_size = model_args.max_batch_size self.n_kv_heads = model_args.n_kv_heads @@ -438,15 +437,11 @@ def __init__( if self.use_cache_list: self.k_caches = [ - torch.zeros(self.get_cache_shape(self.cache_size)).to( - dtype - ) + torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) for _ in range(self.n_layers) ] self.v_caches = [ - torch.zeros(self.get_cache_shape(self.cache_size)).to( - dtype - ) + torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) for _ in range(self.n_layers) ] else: @@ -460,7 +455,10 @@ def __init__( minus_infinity * torch.ones(self.seq_length, self.seq_length), diagonal=1 ) # attn for current tokens self.attn_mask = torch.concat([attn_cache, attn_seq], dim=-1).to(dtype) - assert self.attn_mask.shape == (self.seq_length, self.cache_size + self.seq_length) + assert self.attn_mask.shape == ( + self.seq_length, + self.cache_size + self.seq_length, + ) self.input_pos = 0 self.cache_pos = 0 @@ -495,37 +493,42 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches): self.k_caches[i][ :, :, (self.cache_pos) : (self.cache_pos + length), : - ] = new_k_caches[i][:, :, start:(start+length), :] + ] = new_k_caches[i][:, :, start : (start + length), :] self.v_caches[i][ :, :, (self.cache_pos) : (self.cache_pos + length), : - ] = new_v_caches[i][:, :, start:(start+length), :] + ] = new_v_caches[i][:, :, start : (start + length), :] else: assert new_k_caches.shape == self.get_cache_shape(self.seq_length) assert new_v_caches.shape == self.get_cache_shape(self.seq_length) - self.k_caches[ - :, :, :, (self.cache_pos) : (self.cache_pos + length), : - ] = new_k_caches[:, :, :, start:(start+length), :] - self.v_caches[ - :, :, :, (self.cache_pos) : (self.cache_pos + length), : - ] = new_v_caches[:, :, :, start:(start+length), :] + self.k_caches[:, :, :, (self.cache_pos) : (self.cache_pos + length), :] = ( + new_k_caches[:, :, :, start : (start + length), :] + ) + self.v_caches[:, :, :, (self.cache_pos) : (self.cache_pos + length), :] = ( + new_v_caches[:, :, :, start : (start + length), :] + ) self.cache_pos += length if self.cache_pos == self.cache_size: self.cache_pos = 0 - def update(self, input_length, new_k_caches, new_v_caches): # Copy as much new cache data into cache as possible without wrapping amount_to_copy = min(input_length, self.cache_size - self.cache_pos) self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches) if self.input_pos <= self.cache_size: - self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = 0.0 + self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = ( + 0.0 + ) # Copy remainder (cache is now wrapped around and has more room) # Attention mask needs no further updates. Attention is paid to the whole cache - remaining_to_copy = min(input_length - amount_to_copy, self.cache_size - self.cache_pos) + remaining_to_copy = min( + input_length - amount_to_copy, self.cache_size - self.cache_pos + ) if remaining_to_copy > 0: - self._update_cache(amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches) + self._update_cache( + amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches + ) self.input_pos += input_length diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index bd54b264b0..803949046c 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -1,9 +1,9 @@ import argparse -from multiprocessing import process import sys +from multiprocessing import process +from pathlib import Path import torch -from pathlib import Path sys.path.insert(0, "..") @@ -11,11 +11,12 @@ import sentencepiece as spm import tiktoken -from tiktoken.load import load_tiktoken_bpe from executorch.runtime import Runtime from llama.llama_transformer import InputManager, ModelArgs +from tiktoken.load import load_tiktoken_bpe + class Tokenizer: def __init__(self, model_path: str): @@ -56,29 +57,42 @@ def __init__(self, model_path: str): mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens, ) - + def encode(self, text): return self.tokenizer.encode(text) - + def encode_prompt(self, text): if isinstance(self.tokenizer, spm.SentencePieceProcessor): return self.tokenizer.encode(text) - get_prompt = lambda x: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{x}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" - return self.tokenizer.encode(get_prompt(text), allowed_special={"<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"}) + get_prompt = ( + lambda x: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{x}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + ) + return self.tokenizer.encode( + get_prompt(text), + allowed_special={ + "<|begin_of_text|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eot_id|>", + }, + ) def decode(self, tokens): return self.tokenizer.decode(tokens) - + def stop_tokens(self): if isinstance(self.tokenizer, spm.SentencePieceProcessor): return [self.tokenizer.eos_id()] if isinstance(self.tokenizer, tiktoken.Encoding): return [ self.tokenizer.encode("<|eot_id|>", allowed_special={"<|eot_id|>"})[0], - self.tokenizer.encode("<|end_of_text|>", allowed_special={"<|end_of_text|>"})[0], + self.tokenizer.encode( + "<|end_of_text|>", allowed_special={"<|end_of_text|>"} + )[0], ] + def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( @@ -134,7 +148,7 @@ def main() -> None: model_args = ModelArgs( max_seq_len=args.max_seq_length, generate_full_logits=False, - use_cache_list=False, # cache_list does not work in pybindings + use_cache_list=False, # cache_list does not work in pybindings **params, ) @@ -156,8 +170,8 @@ def main() -> None: generated_tokens.extend(tokens) while input_manager.input_pos < args.n_steps: while len(tokens) > 0: - inputs, remaining_tokens = ( - input_manager.get_inputs_and_remaining_tokens(tokens) + inputs, remaining_tokens = input_manager.get_inputs_and_remaining_tokens( + tokens ) processed_tokens = len(tokens) - len(remaining_tokens) logits, k, v = method.execute(inputs) From cffa508d89e22ea4e78ec73950e383234524c6ed Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 19 Feb 2025 20:12:25 -0800 Subject: [PATCH 08/12] up --- examples/apple/coreml/llama/export.py | 56 +++++++++++++++++++ .../apple/coreml/llama/llama_transformer.py | 3 +- examples/apple/coreml/llama/readme.md | 1 + 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index f86b5306cc..d4ee6363e4 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -26,6 +26,50 @@ from llama.llama_transformer import InputManager, ModelArgs, Transformer +class SplitLinearModule(torch.nn.Module): + def __init__(self, in_features, out_features, target_size): + super(SplitLinearModule, self).__init__() + self.num_splits = max(out_features // target_size, 1) + self.common_size = out_features // self.num_splits + self.remainder = out_features % self.num_splits + self.splits = torch.nn.ModuleList( + [ + torch.nn.Linear(in_features, self.common_size) + for _ in range(self.num_splits) + ] + ) + if self.remainder > 0: + self.splits.append(torch.nn.Linear(in_features, self.remainder)) + + def split_sizes(self): + return [split.out_features for split in self.splits] + + def forward(self, x): + return torch.cat([split(x) for split in self.splits], dim=-1) + + +def replace_linear_with_split_linear(model, target_size): + for name, module in model.named_children(): + if isinstance(module, torch.nn.Linear): + new_module = SplitLinearModule( + module.in_features, module.out_features, target_size + ) + split_sizes = new_module.split_sizes() + if module.bias is not None: + split_bias = module.bias.split(split_sizes) + split_weights = module.weight.split(split_sizes, dim=0) + for i, split in enumerate(new_module.splits): + split.weight = torch.nn.Parameter(split_weights[i]) + if module.bias is not None: + split.bias = torch.nn.Parameter(split_bias[i]) + else: + split.bias = None + setattr(model, name, new_module) + else: + replace_linear_with_split_linear(module, target_size) + + + def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( @@ -80,6 +124,12 @@ def main() -> None: action="store_true", help="Use cache list to speed up model computation (does not work in pybindings)", ) + parser.add_argument( + "--target_size", + type=int, + default=None, + help="Split linear layers into smaller chunks of target_size", + ) export_args = parser.parse_args() params_path = export_args.params @@ -129,6 +179,9 @@ def main() -> None: packed=(bitwidth in [2, 4]), ).quantized_model() + if export_args.target_size is not None: + replace_linear_with_split_linear(model, export_args.target_size) + model = model.to(float_dtype) op_linear_quantizer_config = None @@ -184,6 +237,9 @@ def main() -> None: print("Edge program") print(edge_manager.exported_program()) + for node in edge_manager.exported_program().graph_module.graph.nodes: + print(node.name, node.target, node.args, node.kwargs) + edge_manager = edge_manager.to_backend(partitioner) print("Delegated program") diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index cbc9579800..861bb212b7 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -120,7 +120,6 @@ def __post_init__(self): if self.head_dim is None: self.head_dim = self.dim // self.n_heads - class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() @@ -401,7 +400,7 @@ def forward( if not self.generate_full_logits: # Only the last logit is used for the new generated token - h = h[:, input_length - 1, :] + h = h[:, input_length - 1, :].squeeze(1) h = self.norm(h) diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index d5f45f9ef7..038fdc571f 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -7,6 +7,7 @@ Export model with: python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w ``` +For better performance, use "--use_cache_list" export arg (does not work with pybindings). You can also set "--target_size", which splits linear layers into smaller sizes for the ANE (it defaults to no splitting). This can have substantial impact on performance. For example, on Llama1B by setting "--target_size" to 1024, I see 1.34x increase in inference speed on M1 Pro (but loading time is increased). We need further experiments to tune this. The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. From efc53828bdb8b0f3dc307b824531e8ef396516ee Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 20 Feb 2025 14:12:23 -0800 Subject: [PATCH 09/12] up --- examples/apple/coreml/llama/export.py | 27 +++++-- .../apple/coreml/llama/extract_and_combine.py | 77 ------------------- examples/apple/coreml/llama/readme.md | 18 ++++- 3 files changed, 36 insertions(+), 86 deletions(-) delete mode 100644 examples/apple/coreml/llama/extract_and_combine.py diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index d4ee6363e4..58a480777c 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -27,9 +27,11 @@ class SplitLinearModule(torch.nn.Module): - def __init__(self, in_features, out_features, target_size): + def __init__(self, in_features, out_features, target_size, max_splits): super(SplitLinearModule, self).__init__() self.num_splits = max(out_features // target_size, 1) + if self.num_splits > max_splits: + self.num_splits = max_splits self.common_size = out_features // self.num_splits self.remainder = out_features % self.num_splits self.splits = torch.nn.ModuleList( @@ -38,7 +40,13 @@ def __init__(self, in_features, out_features, target_size): for _ in range(self.num_splits) ] ) + print( + f"Splitting out_features={out_features} into {self.num_splits} of size {self.common_size}" + ) if self.remainder > 0: + print( + f"Warning: remainder {self.remainder} after splitting out_features={out_features} into {self.num_splits} of size {self.common_size}" + ) self.splits.append(torch.nn.Linear(in_features, self.remainder)) def split_sizes(self): @@ -48,11 +56,11 @@ def forward(self, x): return torch.cat([split(x) for split in self.splits], dim=-1) -def replace_linear_with_split_linear(model, target_size): +def replace_linear_with_split_linear(model, target_size, max_splits): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): new_module = SplitLinearModule( - module.in_features, module.out_features, target_size + module.in_features, module.out_features, target_size, max_splits ) split_sizes = new_module.split_sizes() if module.bias is not None: @@ -66,8 +74,7 @@ def replace_linear_with_split_linear(model, target_size): split.bias = None setattr(model, name, new_module) else: - replace_linear_with_split_linear(module, target_size) - + replace_linear_with_split_linear(module, target_size, max_splits) def main() -> None: @@ -130,6 +137,12 @@ def main() -> None: default=None, help="Split linear layers into smaller chunks of target_size", ) + parser.add_argument( + "--max_splits", + type=int, + default=8, + help="Maximum number of splits to divide linear layers", + ) export_args = parser.parse_args() params_path = export_args.params @@ -180,7 +193,9 @@ def main() -> None: ).quantized_model() if export_args.target_size is not None: - replace_linear_with_split_linear(model, export_args.target_size) + replace_linear_with_split_linear( + model, export_args.target_size, export_args.max_splits + ) model = model.to(float_dtype) diff --git a/examples/apple/coreml/llama/extract_and_combine.py b/examples/apple/coreml/llama/extract_and_combine.py deleted file mode 100644 index f73b5713bb..0000000000 --- a/examples/apple/coreml/llama/extract_and_combine.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -import os -import shutil -import subprocess - -import coremltools as ct - -if __name__ == "__main__": - """ - Extract mlpackage from two CoreML pte files, and combine them into one mlpackage using multifunction - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "-m1", - "--model1_path", - type=str, - help="Model1 path.", - ) - parser.add_argument( - "-m2", - "--model2_path", - type=str, - help="Model2 path.", - ) - parser.add_argument( - "-o", - "--output_dir", - type=str, - help="Output path to save combined model", - ) - - args = parser.parse_args() - model1_path = str(args.model1_path) - model2_path = str(args.model2_path) - output_dir = str(args.output_dir) - - if os.path.exists(output_dir): - raise Exception( - f"Output directory {output_dir} already exists. Please make delete it before running script." - ) - os.makedirs(output_dir) - - coreml_extract_path = os.path.join(os.getcwd(), "extracted_coreml_models") - if os.path.exists(coreml_extract_path): - raise Exception( - f"{coreml_extract_path} already exists. Please delete it before running script." - ) - - extract_script_path = os.path.join( - os.path.dirname(__file__), "../scripts/extract_coreml_models.py" - ) - extracted_path = "extracted_coreml_models/model_1/lowered_module/model.mlpackage" - - subprocess.run(["python", extract_script_path, "--model", model1_path]) - items = os.listdir("extracted_coreml_models") - assert len(items) == 1, "Expected one CoreML partition" - shutil.copytree(extracted_path, f"{output_dir}/model1.mlpackage") - - subprocess.run(["python", extract_script_path, "--model", model2_path]) - items = os.listdir("extracted_coreml_models") - assert len(items) == 1, "Expected one CoreML partition" - shutil.copytree(extracted_path, f"{output_dir}/model2.mlpackage") - - desc = ct.utils.MultiFunctionDescriptor() - - desc.add_function( - f"{output_dir}/model1.mlpackage", - src_function_name="main", - target_function_name="model1", - ) - desc.add_function( - f"{output_dir}/model2.mlpackage", - src_function_name="main", - target_function_name="model2", - ) - desc.default_function_name = "model1" - ct.utils.save_multifunction(desc, f"{output_dir}/combined.mlpackage") diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 038fdc571f..112db16c28 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -7,8 +7,6 @@ Export model with: python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w ``` -For better performance, use "--use_cache_list" export arg (does not work with pybindings). You can also set "--target_size", which splits linear layers into smaller sizes for the ANE (it defaults to no splitting). This can have substantial impact on performance. For example, on Llama1B by setting "--target_size" to 1024, I see 1.34x increase in inference speed on M1 Pro (but loading time is increased). We need further experiments to tune this. - The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. @@ -17,4 +15,18 @@ Run model with: python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512 ``` -The model here is based on a "sliding" cache, where old tokens are evicted from the cache. By default, the cache size is max_seq_length - seq_length, but you can explicitly pass in a smaller cache size (e.g., --cache_size 512). This can speed up computation and reduce memory. Keep in mind that once cache_size is reached, older tokens get evicted from the cache and do not participate in attention. +The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though.tion. + + +## Export args +* seq_length: the number of tokens processed by the model. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked. +* max_seq_length: the maximum context tokens that can be processed. +* cache_size: the size of the KV cache sequences. This parameter is optional, and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024, but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 will participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction. +* use_cache_list: boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or if they are passed as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.) +* target_size: this option splits linear layers into chunks of target size. For example, if target_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatted. If not specified, the default is no splitting. +* max_splits: this controls the maximum number of splits for linear layers. It is only relevant if target_size is passed and defaults to 8. + +## Llama1B on iPhone 15 + +We are actively experimenting with different settings, but here are ones we've found that work well on iPhone 15 Pro for Llama1B: +* max_seq_length=1024, seq_length=64, use_cache_list, target_size=1024, max_splits=8 From 64f032178f6351df23278fa47f167c610a5abbf1 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 20 Feb 2025 17:37:40 -0800 Subject: [PATCH 10/12] up --- examples/apple/coreml/llama/export.py | 42 +++++++++---------- .../apple/coreml/llama/llama_transformer.py | 1 + examples/apple/coreml/llama/readme.md | 18 ++++++-- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 58a480777c..8cf10375f9 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -27,27 +27,25 @@ class SplitLinearModule(torch.nn.Module): - def __init__(self, in_features, out_features, target_size, max_splits): + def __init__(self, in_features, out_features, target_split_size, max_splits): super(SplitLinearModule, self).__init__() - self.num_splits = max(out_features // target_size, 1) - if self.num_splits > max_splits: - self.num_splits = max_splits - self.common_size = out_features // self.num_splits - self.remainder = out_features % self.num_splits + num_splits = max(out_features // target_split_size, 1) + if num_splits > max_splits: + num_splits = max_splits + + self.split_size = out_features // num_splits + self.split_remainder = out_features % num_splits self.splits = torch.nn.ModuleList( - [ - torch.nn.Linear(in_features, self.common_size) - for _ in range(self.num_splits) - ] + [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)] ) print( - f"Splitting out_features={out_features} into {self.num_splits} of size {self.common_size}" + f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}" ) - if self.remainder > 0: + if self.split_remainder > 0: print( - f"Warning: remainder {self.remainder} after splitting out_features={out_features} into {self.num_splits} of size {self.common_size}" + f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}" ) - self.splits.append(torch.nn.Linear(in_features, self.remainder)) + self.splits.append(torch.nn.Linear(in_features, self.split_remainder)) def split_sizes(self): return [split.out_features for split in self.splits] @@ -56,11 +54,11 @@ def forward(self, x): return torch.cat([split(x) for split in self.splits], dim=-1) -def replace_linear_with_split_linear(model, target_size, max_splits): +def replace_linear_with_split_linear(model, target_split_size, max_splits): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): new_module = SplitLinearModule( - module.in_features, module.out_features, target_size, max_splits + module.in_features, module.out_features, target_split_size, max_splits ) split_sizes = new_module.split_sizes() if module.bias is not None: @@ -74,7 +72,7 @@ def replace_linear_with_split_linear(model, target_size, max_splits): split.bias = None setattr(model, name, new_module) else: - replace_linear_with_split_linear(module, target_size, max_splits) + replace_linear_with_split_linear(module, target_split_size, max_splits) def main() -> None: @@ -98,7 +96,7 @@ def main() -> None: parser.add_argument( "--seq_length", type=int, - default=1, # set to 1 for decode + default=1, help="length sequence to evaluate", ) parser.add_argument( @@ -132,10 +130,10 @@ def main() -> None: help="Use cache list to speed up model computation (does not work in pybindings)", ) parser.add_argument( - "--target_size", + "--target_split_size", type=int, default=None, - help="Split linear layers into smaller chunks of target_size", + help="Split linear layers into smaller chunks of target_split_size.", ) parser.add_argument( "--max_splits", @@ -192,9 +190,9 @@ def main() -> None: packed=(bitwidth in [2, 4]), ).quantized_model() - if export_args.target_size is not None: + if export_args.target_split_size is not None: replace_linear_with_split_linear( - model, export_args.target_size, export_args.max_splits + model, export_args.target_split_size, export_args.max_splits ) model = model.to(float_dtype) diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index 861bb212b7..1c70aea7ee 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -120,6 +120,7 @@ def __post_init__(self): if self.head_dim is None: self.head_dim = self.dim // self.n_heads + class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 112db16c28..3daf91db1f 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -7,6 +7,8 @@ Export model with: python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w ``` +(Note the script should be run from the executorch/examples/apple/coreml/llama directory.) + The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant. @@ -15,7 +17,10 @@ Run model with: python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512 ``` -The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though.tion. + +(Note the script should be run from the executorch/examples/apple/coreml/llama directory.) + +The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though. ## Export args @@ -23,10 +28,15 @@ The model here is based on a "sliding" cache, where old tokens are evicted from * max_seq_length: the maximum context tokens that can be processed. * cache_size: the size of the KV cache sequences. This parameter is optional, and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024, but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 will participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction. * use_cache_list: boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or if they are passed as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.) -* target_size: this option splits linear layers into chunks of target size. For example, if target_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatted. If not specified, the default is no splitting. +* target_split_size: this option splits linear layers into chunks of target size. For example, if target_split_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatted. If not specified, the default is no splitting. * max_splits: this controls the maximum number of splits for linear layers. It is only relevant if target_size is passed and defaults to 8. ## Llama1B on iPhone 15 -We are actively experimenting with different settings, but here are ones we've found that work well on iPhone 15 Pro for Llama1B: -* max_seq_length=1024, seq_length=64, use_cache_list, target_size=1024, max_splits=8 +We are actively experimenting with different settings. But here are ones that we've found work well for Llama1B on iPhone 15 Pro: + +* Set use_cache_list +* Split linear layers with target_split_size=1024, max_splits=8 +* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance. seq_length=32 is better at decode and seq_length=64 is better at prefill. + +In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length. From 7b3bb13e621a4fba379ec405fe6e27a748a69021 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:37:52 -0800 Subject: [PATCH 11/12] up --- examples/apple/coreml/llama/export.py | 13 +- .../apple/coreml/llama/llama_transformer.py | 23 ++- examples/apple/coreml/llama/readme.md | 5 +- examples/apple/coreml/llama/run.py | 179 ++++++------------ 4 files changed, 84 insertions(+), 136 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 8cf10375f9..768048227a 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -22,8 +22,8 @@ from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.export_util.utils import export_to_edge, save_pte_program -sys.path.insert(0, "..") -from llama.llama_transformer import InputManager, ModelArgs, Transformer +sys.path.insert(0, ".") +from llama_transformer import InputManager, ModelArgs, Transformer class SplitLinearModule(torch.nn.Module): @@ -125,7 +125,7 @@ def main() -> None: help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", ) parser.add_argument( - "--use-cache-list", + "--use_cache_list", action="store_true", help="Use cache list to speed up model computation (does not work in pybindings)", ) @@ -230,7 +230,12 @@ def main() -> None: ) input_manager = InputManager( - model_args=args, + n_layers=args.n_layers, + max_batch_size=args.max_batch_size, + n_kv_heads=args.n_kv_heads, + max_seq_length=args.max_seq_len, + head_dim=args.head_dim, + use_cache_list=export_args.use_cache_list, seq_length=export_args.seq_length, dtype=float_dtype, minus_infinity=-30000, diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index 1c70aea7ee..5788bcd5e5 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -416,24 +416,29 @@ def forward( class InputManager: def __init__( self, - model_args: ModelArgs, - seq_length, + n_layers: int, + max_batch_size: int, + n_kv_heads: int, + max_seq_length: int, + head_dim: int, + use_cache_list: bool, + seq_length: int, dtype=torch.float16, minus_infinity=-torch.inf, cache_size=None, ): if cache_size is None: - cache_size = model_args.max_seq_len - seq_length + cache_size = max_seq_length - seq_length self.cache_size = cache_size - assert self.cache_size + seq_length <= model_args.max_seq_len + assert self.cache_size + seq_length <= max_seq_length - self.n_layers = model_args.n_layers - self.max_batch_size = model_args.max_batch_size - self.n_kv_heads = model_args.n_kv_heads - self.head_dim = model_args.head_dim + self.n_layers = n_layers + self.max_batch_size = max_batch_size + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim self.seq_length = seq_length - self.use_cache_list = model_args.use_cache_list + self.use_cache_list = use_cache_list if self.use_cache_list: self.k_caches = [ diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 3daf91db1f..353f0b5630 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -14,14 +14,11 @@ The runner is written in python and is only intended to serve as an example for Run model with: ``` -python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512 +python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time," ``` - (Note the script should be run from the executorch/examples/apple/coreml/llama directory.) -The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though. - ## Export args * seq_length: the number of tokens processed by the model. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked. diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index 803949046c..323fbe4601 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -1,21 +1,18 @@ import argparse import sys -from multiprocessing import process -from pathlib import Path - -import torch - - -sys.path.insert(0, "..") -import json import sentencepiece as spm import tiktoken +import torch + from executorch.runtime import Runtime -from llama.llama_transformer import InputManager, ModelArgs -from tiktoken.load import load_tiktoken_bpe + +sys.path.insert(0, ".") +from executorch.examples.models.llama.runner.generation import next_token +from executorch.examples.models.llama.tokenizer import tiktoken +from llama_transformer import InputManager class Tokenizer: @@ -27,70 +24,25 @@ def __init__(self, model_path: str): sp.load(model_path) self.tokenizer = sp except: - print("Trying to tiktoken") - self.num_reserved_special_tokens = 256 - self.pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [ - f"<|reserved_special_token_{i}|>" - for i in range(5, self.num_reserved_special_tokens - 5) - ] - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.tokenizer = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) + print("Trying to load tiktoken") + self.tokenizer = tiktoken.Tokenizer(model_path) - def encode(self, text): - return self.tokenizer.encode(text) + def encode(self, text, bos, eos): + if isinstance(self.tokenizer, spm.SentencePieceProcessor): + bos_string = "" if bos else "" + eos_string = "" if eos else "" + return self.tokenizer.encode(f"{bos_string}{text}{eos_string}") + return self.tokenizer.encode(text, bos=bos, eos=eos) - def encode_prompt(self, text): + def decode_token(self, token): if isinstance(self.tokenizer, spm.SentencePieceProcessor): - return self.tokenizer.encode(text) - - get_prompt = ( - lambda x: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{x}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" - ) - return self.tokenizer.encode( - get_prompt(text), - allowed_special={ - "<|begin_of_text|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eot_id|>", - }, - ) - - def decode(self, tokens): - return self.tokenizer.decode(tokens) + return f"{self.tokenizer.decode(token)} " + return self.tokenizer.decode_token(token) def stop_tokens(self): if isinstance(self.tokenizer, spm.SentencePieceProcessor): return [self.tokenizer.eos_id()] - if isinstance(self.tokenizer, tiktoken.Encoding): - return [ - self.tokenizer.encode("<|eot_id|>", allowed_special={"<|eot_id|>"})[0], - self.tokenizer.encode( - "<|end_of_text|>", allowed_special={"<|end_of_text|>"} - )[0], - ] + return self.tokenizer.stop_tokens def main() -> None: @@ -100,76 +52,68 @@ def main() -> None: "--model", help="model.pte", ) - parser.add_argument( - "-p", - "--params", - help="config.json", - ) parser.add_argument( "-t", "--tokenizer", help="tokenizer.model path", ) - parser.add_argument( - "--seq_length", - type=int, - default=1, # set to 1 for decode - help="length sequence to evaluate", - ) - parser.add_argument( - "--max_seq_length", - type=int, - default=128, - help="maximum length sequence to evaluate", - ) parser.add_argument( "--prompt", type=str, default="Once upon a time,", ) parser.add_argument( - "--n_steps", - type=int, + "--temperature", + type=float, + default=0.6, ) parser.add_argument( - "--cache_size", - type=int, - default=None, - help="Cache size. Old items are evicted from cache", + "--top_p", + type=float, + default=0.9, ) args = parser.parse_args() - params_path = args.params - # Load model args - with open(params_path, "r") as f: - params = json.loads(f.read()) + tokenizer = Tokenizer(args.tokenizer) - model_args = ModelArgs( - max_seq_len=args.max_seq_length, - generate_full_logits=False, - use_cache_list=False, # cache_list does not work in pybindings - **params, + runtime = Runtime.get() + program = runtime.load_program(args.model) + method = program.load_method("forward") + + metadata = method.metadata + print("Method metadata: ", metadata, "\n\n") + + assert ( + metadata.num_inputs() == 6 + ), "Do not export with --use_cache_list for use in pybindings" + # k_cache input + n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = ( + metadata.input_tensor_meta(3).sizes() ) + # mask input + seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes() + input_manager = InputManager( - model_args=model_args, - seq_length=args.seq_length, + n_layers=n_layers, + max_batch_size=max_batch_size, + n_kv_heads=n_kv_heads, + max_seq_length=max_seq_length, + head_dim=head_dim, + use_cache_list=False, + seq_length=seq_length, dtype=torch.float16, - minus_infinity=-30000, - cache_size=args.cache_size, + minus_infinity=-30000.0, + cache_size=cache_size, ) - tokenizer = Tokenizer(args.tokenizer) - - runtime = Runtime.get() - program = runtime.load_program(args.model) - method = program.load_method("forward") - generated_tokens = [] - tokens = tokenizer.encode_prompt(args.prompt) - generated_tokens.extend(tokens) - while input_manager.input_pos < args.n_steps: - while len(tokens) > 0: + print(args.prompt, end="") + tokens = tokenizer.encode(args.prompt, bos=True, eos=False) + while input_manager.input_pos + seq_length < max_seq_length: + while len(tokens) > 0 and ( + input_manager.input_pos + seq_length < max_seq_length + ): inputs, remaining_tokens = input_manager.get_inputs_and_remaining_tokens( tokens ) @@ -180,14 +124,11 @@ def main() -> None: ) tokens = remaining_tokens - tokens = [logits.argmax(-1).item()] - generated_tokens.extend(tokens) + tokens = [next_token(logits, args.temperature, args.top_p)] + if tokens[-1] in tokenizer.stop_tokens(): break - print(tokenizer.decode([generated_tokens[-1]]), end=" ", flush=True) - - print("\n\nFull text:") - print(tokenizer.decode(generated_tokens)) + print(tokenizer.decode_token(tokens[-1]), end="", flush=True) if __name__ == "__main__": From 92a1be871a346082b5d72334cc03a822f82951c6 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 21 Feb 2025 14:25:46 -0800 Subject: [PATCH 12/12] lint --- examples/apple/coreml/llama/export.py | 4 +++- examples/apple/coreml/llama/run.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 768048227a..58bc0859c7 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -160,7 +160,9 @@ def main() -> None: with torch.device("meta"): model = Transformer(args) - checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True) + checkpoint = torch.load( + checkpoint_path, map_location="cpu", mmap=True, weights_only=True + ) if "model" in checkpoint: checkpoint = checkpoint["model"] diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index 323fbe4601..65026e1f6b 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -2,7 +2,6 @@ import sys import sentencepiece as spm -import tiktoken import torch