
Add Falcon Support #313

Merged: 12 commits, Jun 28, 2023
9 changes: 9 additions & 0 deletions Cargo.lock


3 changes: 3 additions & 0 deletions binaries/llm-cli/Cargo.toml
@@ -35,3 +35,6 @@ rusty-hook = "^0.11.2"
 cublas = ["llm/cublas"]
 clblast = ["llm/clblast"]
 metal = ["llm/metal"]
+
+# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
+falcon = ["llm/falcon"]
7 changes: 7 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -44,6 +44,13 @@ pub enum Args {
         #[command(subcommand)]
         args: BaseArgs,
     },
+    /// Use a Falcon model
+    #[clap(id = "falcon")]
+    #[cfg(feature = "falcon")]
+    Falcon {
+        #[command(subcommand)]
+        args: BaseArgs,
+    },
 }
 
 #[derive(Subcommand, Debug)]
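
An aside on the pattern used in this file: because the new variant is gated with #[cfg(feature = "falcon")], the subcommand and its help text are compiled out entirely unless the feature is enabled. Below is a minimal standalone sketch of that pattern, assuming clap 4 with the `derive` feature; it is illustrative only and not part of the PR.

use clap::{Parser, Subcommand};

#[derive(Parser)]
struct Cli {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    /// Always compiled in
    Llama,
    /// Present only when built with `--features falcon`
    #[cfg(feature = "falcon")]
    Falcon,
}

fn main() {
    match Cli::parse().command {
        Command::Llama => println!("llama selected"),
        #[cfg(feature = "falcon")]
        Command::Falcon => println!("falcon selected"),
    }
}

Built without the feature, a hypothetical `cli falcon` invocation is rejected at argument parsing with an "unrecognized subcommand" error rather than failing at runtime.
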
2 changes: 2 additions & 0 deletions binaries/llm-cli/src/main.rs
@@ -33,6 +33,8 @@ fn main() -> Result<()> {
         Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
         Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
         Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
+        #[cfg(feature = "falcon")]
+        Args::Falcon { args } => handle_args::<llm::models::Falcon>(args),
     }
 }
 
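
Note that each arm monomorphizes handle_args over a concrete model type, so the Falcon path is resolved entirely at compile time; without the feature, the llm::models::Falcon type itself is not exported (see the cfg-gated re-export in crates/llm/src/lib.rs below), and the match arm disappears along with it.
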
23 changes: 15 additions & 8 deletions binaries/precommit-check/src/main.rs
@@ -1,16 +1,23 @@
 fn main() {
     // Ensure that these match `.github/workflows/rust.yml`.
-    cmd("cargo", &["check"]);
-    cmd("cargo", &["test", "--all"]);
-    cmd("cargo", &["fmt", "--check", "--all"]);
-    cmd("cargo", &["doc", "--workspace", "--exclude", "llm-cli"]);
-    cmd("cargo", &["clippy", "--", "-Dclippy::all"]);
+    cmd("cargo", &["check"], &[]);
+    cmd("cargo", &["test", "--all"], &[]);
+    cmd("cargo", &["fmt", "--check", "--all"], &[]);
+    cmd(
+        "cargo",
+        &["doc", "--workspace", "--exclude", "llm-cli"],
+        &[("RUSTDOCFLAGS", "-Dwarnings")],
+    );
+    cmd("cargo", &["clippy", "--", "-Dclippy::all"], &[]);
 }
 
-fn cmd(cmd: &str, args: &[&str]) {
+fn cmd(cmd: &str, args: &[&str], env: &[(&str, &str)]) {
     println!("=== Running command: {cmd} {args:?}");
-    let mut child = std::process::Command::new(cmd).args(args).spawn().unwrap();
+    let mut builder = std::process::Command::new(cmd);
+    builder.args(args);
+    builder.envs(env.iter().copied());
+    let mut child = builder.spawn().unwrap();
     if !child.wait().unwrap().success() {
-        panic!("Failed to run command: {} {:?}", cmd, args);
+        panic!("Failed to run command: {} {:?}", cmd, builder);
     }
 }
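
The new env parameter exists so that the doc step can run with RUSTDOCFLAGS set to -Dwarnings, the in-process equivalent of the shell invocation RUSTDOCFLAGS="-Dwarnings" cargo doc --workspace --exclude llm-cli. Rustdoc warnings now fail the pre-commit check, presumably mirroring the CI workflow that the comment at the top of main asks to keep in sync.
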
5 changes: 5 additions & 0 deletions crates/llm/Cargo.toml
@@ -15,6 +15,7 @@ llm-gptj = { path = "../models/gptj", optional = true, version = "0.2.0-dev" }
 llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" }
 llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" }
 llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" }
+llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" }
 
 serde = { workspace = true }
 
@@ -29,12 +30,16 @@ clap = { workspace = true }
 
 [features]
 default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
+
 llama = ["dep:llm-llama"]
 gpt2 = ["dep:llm-gpt2"]
 gptj = ["dep:llm-gptj"]
 bloom = ["dep:llm-bloom"]
 gptneox = ["dep:llm-gptneox"]
 mpt = ["dep:llm-mpt"]
+# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
+falcon = ["dep:llm-falcon"]
+
 cublas = ["llm-base/cublas"]
 clblast = ["llm-base/clblast"]
 metal = ["llm-base/metal"]
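
Since falcon is deliberately left out of the default feature list, downstream users must opt in explicitly, e.g. cargo build --features falcon when depending on the llm crate directly; the CLI's falcon feature added earlier in binaries/llm-cli/Cargo.toml simply forwards to this one via llm/falcon.
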
16 changes: 16 additions & 0 deletions crates/llm/src/lib.rs
@@ -7,6 +7,7 @@
 //! - [GPT-NeoX](llm_gptneox)
 //! - [LLaMA](llm_llama)
 //! - [MPT](llm_mpt)
+//! - Falcon (currently disabled due to incompleteness)
 //!
 //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
 //! change in the future.
@@ -92,6 +93,8 @@ use serde::Serialize;
 pub mod models {
     #[cfg(feature = "bloom")]
     pub use llm_bloom::{self as bloom, Bloom};
+    #[cfg(feature = "falcon")]
+    pub use llm_falcon::{self as falcon, Falcon};
     #[cfg(feature = "gpt2")]
     pub use llm_gpt2::{self as gpt2, Gpt2};
     #[cfg(feature = "gptj")]
@@ -125,6 +128,9 @@ pub enum ModelArchitecture {
     #[cfg(feature = "mpt")]
     /// [MPT](llm_mpt)
     Mpt,
+    #[cfg(feature = "falcon")]
+    /// [Falcon](llm_falcon)
+    Falcon,
 }
 
 impl ModelArchitecture {
@@ -142,6 +148,8 @@ impl ModelArchitecture {
         Self::Llama,
         #[cfg(feature = "mpt")]
         Self::Mpt,
+        #[cfg(feature = "falcon")]
+        Self::Falcon,
     ];
 }
 
@@ -185,6 +193,8 @@ impl FromStr for ModelArchitecture {
             "llama" => Ok(Llama),
             #[cfg(feature = "mpt")]
             "mpt" => Ok(Mpt),
+            #[cfg(feature = "falcon")]
+            "falcon" => Ok(Falcon),
 
             _ => Err(UnsupportedModelArchitecture(format!(
                 "{s} is not a supported model architecture"
@@ -210,6 +220,8 @@ impl Display for ModelArchitecture {
             Llama => write!(f, "LLaMA"),
             #[cfg(feature = "mpt")]
             Mpt => write!(f, "MPT"),
+            #[cfg(feature = "falcon")]
+            Falcon => write!(f, "Falcon"),
         }
     }
 }
@@ -264,6 +276,10 @@ pub fn load_dynamic(
         }
         #[cfg(feature = "mpt")]
         Mpt => load_model::<models::Mpt>(path, vocabulary_source, params, load_progress_callback)?,
+        #[cfg(feature = "falcon")]
+        Falcon => {
+            load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?
+        }
     };
 
     Ok(model)
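
To show how a caller reaches the new match arm, here is a hedged sketch of loading a Falcon model through load_dynamic. The architecture-first parameter order and names such as VocabularySource::Model are assumptions based on the surrounding diff, not verbatim API from this PR:

fn load_falcon(path: &std::path::Path) -> Result<Box<dyn llm::Model>, llm::LoadError> {
    llm::load_dynamic(
        llm::ModelArchitecture::Falcon,  // only exists when built with `--features falcon`
        path,
        llm::VocabularySource::Model,    // assumed: use the vocabulary embedded in the model file
        llm::ModelParameters::default(),
        |_progress| {},                  // no-op load-progress callback
    )
}

The string form round-trips through the FromStr and Display impls above: "falcon".parse::<llm::ModelArchitecture>() yields Falcon, which displays as "Falcon".
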
13 changes: 13 additions & 0 deletions crates/models/falcon/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "llm-falcon"
+version = "0.2.0-dev"
+license = { workspace = true }
+repository = { workspace = true }
+description = "An implementation of Falcon for the `llm` ecosystem."
+edition = "2021"
+readme = "../../../README.md"
+
+[dependencies]
+llm-base = { path = "../../llm-base", version = "0.2.0-dev" }
+
+bytemuck = { workspace = true }