diff --git a/Cargo.lock b/Cargo.lock
index ef6e1ab5..d4b5a65f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1278,6 +1278,7 @@ dependencies = [
  "clap",
  "llm-base",
  "llm-bloom",
+ "llm-falcon",
  "llm-gpt2",
  "llm-gptj",
  "llm-gptneox",
@@ -1335,6 +1336,14 @@ dependencies = [
  "zstd 0.12.3+zstd.1.5.2",
 ]
 
+[[package]]
+name = "llm-falcon"
+version = "0.2.0-dev"
+dependencies = [
+ "bytemuck",
+ "llm-base",
+]
+
 [[package]]
 name = "llm-gpt2"
 version = "0.2.0-dev"
diff --git a/binaries/llm-cli/Cargo.toml b/binaries/llm-cli/Cargo.toml
index dea1f3d3..bac2ff87 100644
--- a/binaries/llm-cli/Cargo.toml
+++ b/binaries/llm-cli/Cargo.toml
@@ -35,3 +35,6 @@ rusty-hook = "^0.11.2"
 cublas = ["llm/cublas"]
 clblast = ["llm/clblast"]
 metal = ["llm/metal"]
+
+# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
+falcon = ["llm/falcon"]
diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index 39a14fb4..5aec546f 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -44,6 +44,13 @@ pub enum Args {
         #[command(subcommand)]
         args: BaseArgs,
     },
+    /// Use a Falcon model
+    #[clap(id = "falcon")]
+    #[cfg(feature = "falcon")]
+    Falcon {
+        #[command(subcommand)]
+        args: BaseArgs,
+    },
 }
 
 #[derive(Subcommand, Debug)]
diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index 689b75c7..679a753e 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -33,6 +33,8 @@ fn main() -> Result<()> {
         Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
         Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
         Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
+        #[cfg(feature = "falcon")]
+        Args::Falcon { args } => handle_args::<llm::models::Falcon>(args),
     }
 }
 
diff --git a/binaries/precommit-check/src/main.rs b/binaries/precommit-check/src/main.rs
index 945881d1..04d3add0 100644
--- a/binaries/precommit-check/src/main.rs
+++ b/binaries/precommit-check/src/main.rs
@@ -1,16 +1,23 @@
 fn main() {
     // Ensure that these match `.github/workflows/rust.yml`.
-    cmd("cargo", &["check"]);
-    cmd("cargo", &["test", "--all"]);
-    cmd("cargo", &["fmt", "--check", "--all"]);
-    cmd("cargo", &["doc", "--workspace", "--exclude", "llm-cli"]);
-    cmd("cargo", &["clippy", "--", "-Dclippy::all"]);
+    cmd("cargo", &["check"], &[]);
+    cmd("cargo", &["test", "--all"], &[]);
+    cmd("cargo", &["fmt", "--check", "--all"], &[]);
+    cmd(
+        "cargo",
+        &["doc", "--workspace", "--exclude", "llm-cli"],
+        &[("RUSTDOCFLAGS", "-Dwarnings")],
+    );
+    cmd("cargo", &["clippy", "--", "-Dclippy::all"], &[]);
 }
 
-fn cmd(cmd: &str, args: &[&str]) {
+fn cmd(cmd: &str, args: &[&str], env: &[(&str, &str)]) {
     println!("=== Running command: {cmd} {args:?}");
-    let mut child = std::process::Command::new(cmd).args(args).spawn().unwrap();
+    let mut builder = std::process::Command::new(cmd);
+    builder.args(args);
+    builder.envs(env.iter().copied());
+    let mut child = builder.spawn().unwrap();
     if !child.wait().unwrap().success() {
-        panic!("Failed to run command: {} {:?}", cmd, args);
+        panic!("Failed to run command: {} {:?}", cmd, builder);
     }
 }
diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml
index 3a4f6bc7..035b8a83 100644
--- a/crates/llm/Cargo.toml
+++ b/crates/llm/Cargo.toml
@@ -15,6 +15,7 @@ llm-gptj = { path = "../models/gptj", optional = true, version = "0.2.0-dev" }
 llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" }
 llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" }
 llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" }
+llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" }
 
 serde = { workspace = true }
 
@@ -29,12 +30,16 @@ clap = { workspace = true }
 
 [features]
 default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
+
 llama = ["dep:llm-llama"]
 gpt2 = ["dep:llm-gpt2"]
 gptj = ["dep:llm-gptj"]
 bloom = ["dep:llm-bloom"]
 gptneox = ["dep:llm-gptneox"]
 mpt = ["dep:llm-mpt"]
+# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
+falcon = ["dep:llm-falcon"]
+
 cublas = ["llm-base/cublas"]
 clblast = ["llm-base/clblast"]
 metal = ["llm-base/metal"]
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index b5e12da7..8adda7e7 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -7,6 +7,7 @@
 //! - [GPT-NeoX](llm_gptneox)
 //! - [LLaMA](llm_llama)
 //! - [MPT](llm_mpt)
+//! - Falcon (currently disabled due to incompleteness)
 //!
 //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
 //! change in the future.
@@ -92,6 +93,8 @@ use serde::Serialize;
 pub mod models {
     #[cfg(feature = "bloom")]
     pub use llm_bloom::{self as bloom, Bloom};
+    #[cfg(feature = "falcon")]
+    pub use llm_falcon::{self as falcon, Falcon};
     #[cfg(feature = "gpt2")]
     pub use llm_gpt2::{self as gpt2, Gpt2};
     #[cfg(feature = "gptj")]
@@ -125,6 +128,9 @@ pub enum ModelArchitecture {
     #[cfg(feature = "mpt")]
     /// [MPT](llm_mpt)
     Mpt,
+    #[cfg(feature = "falcon")]
+    /// [Falcon](llm_falcon)
+    Falcon,
 }
 
 impl ModelArchitecture {
@@ -142,6 +148,8 @@ impl ModelArchitecture {
         Self::Llama,
         #[cfg(feature = "mpt")]
         Self::Mpt,
+        #[cfg(feature = "falcon")]
+        Self::Falcon,
     ];
 }
 
@@ -185,6 +193,8 @@ impl FromStr for ModelArchitecture {
             "llama" => Ok(Llama),
             #[cfg(feature = "mpt")]
             "mpt" => Ok(Mpt),
+            #[cfg(feature = "falcon")]
+            "falcon" => Ok(Falcon),
 
             _ => Err(UnsupportedModelArchitecture(format!(
                 "{s} is not a supported model architecture"
@@ -210,6 +220,8 @@ impl Display for ModelArchitecture {
             Llama => write!(f, "LLaMA"),
             #[cfg(feature = "mpt")]
             Mpt => write!(f, "MPT"),
+            #[cfg(feature = "falcon")]
+            Falcon => write!(f, "Falcon"),
         }
     }
 }
@@ -264,6 +276,10 @@ pub fn load_dynamic(
         }
         #[cfg(feature = "mpt")]
         Mpt => load_model::<models::Mpt>(path, vocabulary_source, params, load_progress_callback)?,
+        #[cfg(feature = "falcon")]
+        Falcon => {
+            load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?
+        }
     };
 
     Ok(model)
diff --git a/crates/models/falcon/Cargo.toml b/crates/models/falcon/Cargo.toml
new file mode 100644
index 00000000..e71c261c
--- /dev/null
+++ b/crates/models/falcon/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "llm-falcon"
+version = "0.2.0-dev"
+license = { workspace = true }
+repository = { workspace = true }
+description = "An implementation of Falcon for the `llm` ecosystem."
+edition = "2021"
+readme = "../../../README.md"
+
+[dependencies]
+llm-base = { path = "../../llm-base", version = "0.2.0-dev" }
+
+bytemuck = { workspace = true }
diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs
new file mode 100644
index 00000000..3b989e26
--- /dev/null
+++ b/crates/models/falcon/src/lib.rs
@@ -0,0 +1,419 @@
+//! An implementation of the [Falcon](https://falconllm.tii.ae/) model for the `llm` ecosystem.
+//!
+//! This implementation only works for Falcon 7B, and with 32-bit memory tensors (i.e. your inference session
+//! must be configured with a 32-bit [InferenceSessionConfig]).
+//!
+//! This model will not be generally available in the `llm` ecosystem until Falcon 40B and 16-bit memory is
+//! supported. It is currently only available as a preview.
+#![deny(missing_docs)]
+
+use std::sync::Arc;
+
+use ggml::Tensor;
+use llm_base::{
+    ggml,
+    model::{common, HyperparametersWriteError},
+    util, FileType, GraphOutputs, InferenceParameters, InferenceSession, InferenceSessionConfig,
+    KnownModel, LoadError, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
+};
+
+/// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae)
+///
+/// # Safety
+/// This implements [Send] and [Sync] as it is immutable after construction.
+pub struct Falcon {
+    // the context size ("memory") the model should use when evaluating a prompt
+    context_size: usize,
+
+    hyperparameters: Hyperparameters,
+
+    vocabulary: Vocabulary,
+
+    // model-global weights
+    // weighted token embeddings
+    tok_embeddings: Tensor,
+    output_norm: Tensor,
+    output_norm_b: Tensor,
+    lm_head: Tensor,
+
+    // weights for the model
+    layers: Vec<Layer>,
+
+    // must be kept alive for the model
+    context: Arc<ggml::Context>,
+}
+
+unsafe impl Send for Falcon {}
+unsafe impl Sync for Falcon {}
+
+impl KnownModel for Falcon {
+    type Hyperparameters = Hyperparameters;
+
+    fn new<E: std::error::Error>(
+        hyperparameters: Self::Hyperparameters,
+        params: ModelParameters,
+        vocabulary: Vocabulary,
+        tensor_loader: impl llm_base::TensorLoader<E>,
+    ) -> Result<Self, E> {
+        let mut tl = tensor_loader;
+
+        // model-global weights
+        let tok_embeddings = tl.load("transformer.word_embeddings.weight")?;
+        let output_norm = tl.load("transformer.ln_f.weight")?;
+        let output_norm_b = tl.load("transformer.ln_f.bias")?;
+        let lm_head = tl.load("lm_head.weight")?;
+
+        let mut layers = Vec::new();
+        for i in 0..hyperparameters.n_layer {
+            let layer = Layer {
+                attention_norm: tl.load(&format!("transformer.h.{i}.input_layernorm.weight"))?,
+                attention_norm_b: tl.load(&format!("transformer.h.{i}.input_layernorm.bias"))?,
+
+                query_key_value: tl.load(&format!(
+                    "transformer.h.{i}.self_attention.query_key_value.weight"
+                ))?,
+                wo: tl.load(&format!("transformer.h.{i}.self_attention.dense.weight"))?,
+
+                ffn_up: tl.load(&format!("transformer.h.{i}.mlp.dense_h_to_4h.weight"))?,
+                ffn_down: tl.load(&format!("transformer.h.{i}.mlp.dense_4h_to_h.weight"))?,
+            };
+
+            layers.push(layer);
+        }
+
+        let (context, _) = tl.finish();
+
+        let ModelParameters { context_size, .. } = params;
+
+        Ok(Falcon {
+            hyperparameters,
+            context_size,
+            vocabulary,
+            tok_embeddings,
+            output_norm,
+            output_norm_b,
+            lm_head,
+            layers,
+            context: Arc::new(context),
+        })
+    }
+
+    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
+        InferenceSession::new(
+            config,
+            self.context_size,
+            self.hyperparameters.n_layer,
+            self.hyperparameters.n_embd,
+            self.hyperparameters.n_vocab,
+        )
+    }
+
+    fn evaluate(
+        &self,
+        session: &mut InferenceSession,
+        params: &InferenceParameters,
+        input_tokens: &[TokenId],
+        output_request: &mut OutputRequest,
+    ) {
+        let input_len = input_tokens.len();
+        let session_len = session.n_past;
+        let num_threads = params.n_threads;
+        let ctx_size = self.context_size;
+
+        let Hyperparameters {
+            n_embd,
+            n_head,
+            n_vocab,
+            n_layer,
+            ..
+        } = self.hyperparameters;
+
+        let head_dim = n_embd / n_head;
+        let n = input_len;
+
+        let outputs = session.compute(self.context.clone(), input_tokens, |mut builder| {
+            let ctx0 = builder.ctx0;
+            let embd = builder.embd;
+            let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd);
+            let repeat_dummy = ctx0.new_tensor_3d(
+                input_layer.get_type(),
+                head_dim,
+                input_len + session_len,
+                n_head,
+            );
+
+            let f32_size = std::mem::size_of::<f32>();
+
+            let memory_k = builder.memory_k;
+            let memory_k_size = memory_k.element_size();
+
+            let memory_v = builder.memory_v;
+            let memory_v_size = memory_v.element_size();
+
+            let mut gf = ggml::ComputationGraph::new(num_threads);
+
+            let mut current: Tensor;
+            let mut layernorm_output: Tensor;
+
+            for il in 0..n_layer {
+                // attention uses first scratch buffer
+                builder.use_scratch(Some(0));
+
+                // self-attention
+                current = ctx0.op_norm(&input_layer);
+                current = ctx0.op_add(
+                    &ctx0.op_mul(
+                        &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
+                        &current,
+                    ),
+                    &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
+                );
+
+                layernorm_output = current.share();
+
+                // compute QKV
+                current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);
+
+                let fused_qkv_row_nb = (n_embd + 2 * (n_embd / n_head)) * f32_size;
+
+                let mut qcur = ctx0.op_view_3d(
+                    &current,
+                    (head_dim, n_head, n),
+                    (head_dim * f32_size, fused_qkv_row_nb),
+                    0,
+                );
+
+                let mut kcur = ctx0.op_view_3d(
+                    &current,
+                    (head_dim, 1, n),
+                    (head_dim * f32_size, fused_qkv_row_nb),
+                    n_embd * f32_size,
+                );
+
+                let vcur = ctx0.op_view_3d(
+                    &current,
+                    (head_dim, 1, n),
+                    (head_dim * f32_size, fused_qkv_row_nb),
+                    (n_embd + head_dim) * f32_size,
+                );
+
+                // using mode = 2 for neox mode
+                qcur = ctx0.op_rope_inplace(&qcur, session_len, head_dim, 2);
+                kcur = ctx0.op_rope_inplace(&kcur, session_len, head_dim, 2);
+
+                // store key and value to memory
+
+                let k = ctx0.op_view_1d(
+                    memory_k,
+                    n * head_dim,
+                    (memory_k_size * head_dim) * (il * ctx_size + session_len),
+                );
+                let v = ctx0.op_view_1d(
+                    memory_v,
+                    n * head_dim,
+                    (memory_v_size * head_dim) * (il * ctx_size + session_len),
+                );
+
+                gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
+                gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
+
+                // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+                let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3));
+
+                let mut bigk = ctx0.op_permute(
+                    &ctx0.op_reshape_3d(
+                        &ctx0.op_view_1d(
+                            memory_k,
+                            (session_len + n) * head_dim,
+                            il * ctx_size * memory_k_size * head_dim,
+                        ),
+                        head_dim,
+                        1,
+                        session_len + n,
+                    ),
+                    (0, 2, 1, 3),
+                );
+                // K * Q
+                bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
+                let big_kq = ctx0.op_mul_mat(&bigk, &bigq);
+
+                // KQ_scaled = KQ / sqrt(n_embd/n_head)
+                let big_kq_scaled = ctx0.op_scale_inplace(
+                    &big_kq,
+                    &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
+                );
+
+                let big_kq_masked = ctx0.op_diag_mask_inf_inplace(&big_kq_scaled, session_len);
+
+                let big_kq_softmax = ctx0.op_soft_max_inplace(&big_kq_masked);
+
+                let mut bigv = ctx0.op_permute(
+                    &ctx0.op_reshape_3d(
+                        &ctx0.op_view_1d(
+                            memory_v,
+                            (session_len + n) * head_dim,
+                            il * ctx_size * memory_v_size * head_dim,
+                        ),
+                        head_dim,
+                        1,
+                        session_len + n,
+                    ),
+                    (0, 2, 1, 3),
+                );
+                bigv = ctx0.op_cont(&ctx0.op_transpose(&ctx0.op_repeat(&bigv, &repeat_dummy)));
+
+                // KQV = transpose(V) * KQ_soft_max
+                let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax);
+                // KQV_merged = KQV.permute(0, 2, 1, 3)
+                let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3));
+
+                // cur = KQV_merged.contiguous().view(n_embd, N)
+                current = ctx0.op_cpy(
+                    &big_kqv_merged,
+                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n),
+                );
+
+                // projection
+                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
+
+                // feed forward uses second scratch buffer
+                builder.use_scratch(Some(1));
+
+                let inp_ff = layernorm_output.share();
+                let attn_out =
+                    ctx0.op_cpy(&current, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));
+
+                current = ctx0.op_mul_mat(&self.layers[il].ffn_up, &inp_ff);
+                current = ctx0.op_gelu(&current);
+                current = ctx0.op_mul_mat(&self.layers[il].ffn_down, &current);
+
+                current = ctx0.op_add(&current, &attn_out);
+                current = ctx0.op_add(&current, &input_layer);
+
+                input_layer = current.share();
+            }
+
+            builder.use_scratch(Some(0));
+
+            // norm
+            input_layer = ctx0.op_norm(&input_layer);
+
+            input_layer = ctx0.op_add(
+                &ctx0.op_mul(
+                    &ctx0.op_repeat(&self.output_norm, &input_layer),
+                    &input_layer,
+                ),
+                &ctx0.op_repeat(&self.output_norm_b, &input_layer),
+            );
+
+            let embeddings_tensor: ggml::Tensor = input_layer.share();
+
+            builder.use_scratch(None);
+
+            // lm_head
+            input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
+
+            (
+                gf,
+                GraphOutputs {
+                    result: input_layer,
+                    embedding_result: embeddings_tensor,
+                },
+            )
+        });
+
+        // finish evaluation
+        common::read_last_token(session, &outputs.result, n_vocab, input_len);
+        common::extract_logits(output_request, &outputs.result, n_vocab, input_len);
+        common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
+    }
+
+    /// Returns the vocabulary used by this model.
+    fn vocabulary(&self) -> &Vocabulary {
+        &self.vocabulary
+    }
+
+    fn context_size(&self) -> usize {
+        self.context_size
+    }
+
+    fn bot_token_id(&self) -> Option<TokenId> {
+        None
+    }
+
+    fn eot_token_id(&self) -> TokenId {
+        self.vocabulary.id("<|endoftext|>".as_bytes()).unwrap()
+    }
+
+    fn quantize_tensors() -> Vec<Regex> {
+        vec![Regex::new(".*weight").unwrap()]
+    }
+
+    fn skip_quantize_tensors() -> Vec<Regex> {
+        vec![]
+    }
+}
+
+/// Falcon [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
+#[derive(Debug, Default, PartialEq, Clone, Copy)]
+pub struct Hyperparameters {
+    /// Size of the model's vocabulary
+    n_vocab: usize,
+    /// Size of the model's embedding layer
+    n_embd: usize,
+    /// n_heads
+    n_head: usize,
+    /// Number of layers in the model
+    n_layer: usize,
+    /// file_type
+    file_type: FileType,
+}
+
+impl llm_base::Hyperparameters for Hyperparameters {
+    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
+        let hyperparameters = Hyperparameters {
+            n_vocab: util::read_i32(reader)?.try_into()?,
+            n_embd: util::read_i32(reader)?.try_into()?,
+            n_head: util::read_i32(reader)?.try_into()?,
+            n_layer: util::read_i32(reader)?.try_into()?,
+            file_type: util::read_filetype(reader)?,
+        };
+
+        Ok(hyperparameters)
+    }
+
+    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
+        util::write_i32(writer, self.n_vocab.try_into()?)?;
+        util::write_i32(writer, self.n_embd.try_into()?)?;
+        util::write_i32(writer, self.n_head.try_into()?)?;
+        util::write_i32(writer, self.n_layer.try_into()?)?;
+        util::write_i32(writer, self.file_type.into())?;
+        Ok(())
+    }
+
+    fn n_vocabulary(&self) -> usize {
+        self.n_vocab
+    }
+
+    fn file_type(&self) -> Option<FileType> {
+        Some(self.file_type)
+    }
+
+    fn file_type_mut(&mut self) -> Option<&mut FileType> {
+        Some(&mut self.file_type)
+    }
+}
+
+struct Layer {
+    // normalization
+    attention_norm: Tensor,
+    attention_norm_b: Tensor,
+
+    // attention
+    query_key_value: Tensor,
+    wo: Tensor,
+
+    // ff
+    ffn_up: Tensor,
+    ffn_down: Tensor,
+}
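
For context, a minimal downstream usage sketch (not part of the patch above). It assumes a consumer crate that enables the non-default `falcon` feature on `llm`, that `load_dynamic` takes `(architecture, path, vocabulary_source, params, load_progress_callback)` as forwarded in the new match arm, and that `VocabularySource::Model` and a `Default` impl for `ModelParameters` exist in this version of the crate; the model path is a placeholder.

```rust
use std::path::Path;

use llm::{load_dynamic, ModelArchitecture, ModelParameters, VocabularySource};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path to quantized Falcon 7B GGML weights.
    let path = Path::new("models/falcon-7b.ggml.bin");

    // `ModelArchitecture::Falcon` and this call only compile when `llm` is
    // built with `--features falcon`, mirroring the new arm in `load_dynamic`.
    let _model = load_dynamic(
        ModelArchitecture::Falcon,
        path,
        VocabularySource::Model,    // assumption: use the model-embedded vocabulary
        ModelParameters::default(), // assumption: `ModelParameters: Default`
        |_progress| {},             // ignore load-progress updates
    )?;

    // Per the `llm_falcon` module docs, any inference session started from this
    // model must be configured with 32-bit memory tensors.
    Ok(())
}
```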