This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Add HuggingFace's Tokenizer #271

Merged
merged 22 commits into from
May 29, 2023

Changes from 16 commits
1,651 changes: 1,591 additions & 60 deletions Cargo.lock

Large diffs are not rendered by default.

70 changes: 63 additions & 7 deletions binaries/llm-cli/src/cli_args.rs
@@ -1,10 +1,10 @@
use std::{fmt, ops::Deref, path::PathBuf};

use clap::{Parser, Subcommand, ValueEnum};
use color_eyre::eyre::{Result, WrapErr};
use color_eyre::eyre::{eyre, Result, WrapErr};
use llm::{
ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias,
LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, VocabularySource,
};
use rand::SeedableRng;

@@ -141,9 +141,8 @@ pub struct Perplexity {

#[derive(Parser, Debug)]
pub struct Info {
/// The model to inspect.
#[arg(long, short = 'm')]
pub model_path: PathBuf,
#[command(flatten)]
pub model_and_vocabulary: ModelAndVocabulary,

/// Show all of the tensors in the model, including their names, formats and shapes.
#[arg(long, short = 't')]
@@ -331,11 +330,62 @@ fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
}

#[derive(Parser, Debug)]
pub struct ModelLoad {
pub struct ModelVocabulary {
/// Local path to vocabulary
#[arg(long, short = 'v')]
pub vocabulary_path: Option<PathBuf>,

/// Remote HuggingFace repository containing vocabulary
#[arg(long, short = 'r')]
pub vocabulary_repository: Option<String>,
}
impl ModelVocabulary {
pub fn to_source(&self, sp: &mut Option<spinoff::Spinner>) -> Result<VocabularySource> {
Ok(match (&self.vocabulary_path, &self.vocabulary_repository) {
(Some(_), Some(_)) => {
if let Some(sp) = sp.take() {
sp.fail("Invalid arguments");
};

return Err(eyre!(
"Cannot specify both --vocabulary-path and --vocabulary-repository"
));
}
(Some(path), None) => VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
(None, Some(repo)) => VocabularySource::HuggingFaceRemote(repo.to_owned()),
(None, None) => VocabularySource::Model,
})
}
}

#[derive(Parser, Debug)]
pub struct ModelAndVocabulary {
/// Where to load the model from
#[arg(long, short = 'm')]
pub model_path: PathBuf,

#[command(flatten)]
pub vocabulary: ModelVocabulary,

/// Local path to vocabulary
#[arg(long, short = 'v')]
pub vocabulary_path: Option<PathBuf>,

/// Remote HuggingFace repository containing vocabulary
#[arg(long, short = 'r')]
pub vocabulary_repository: Option<String>,
}
impl ModelAndVocabulary {
pub fn to_source(&self, sp: &mut Option<spinoff::Spinner>) -> Result<VocabularySource> {
self.vocabulary.to_source(sp)
}
}

#[derive(Parser, Debug)]
pub struct ModelLoad {
#[command(flatten)]
pub model_and_vocabulary: ModelAndVocabulary,

/// Sets the size of the context (in tokens). Allows feeding longer prompts.
/// Note that this affects memory.
///
@@ -377,8 +427,11 @@ impl ModelLoad {
let now = std::time::Instant::now();
let mut prev_load_time = now;

let vocabulary_source = self.model_and_vocabulary.to_source(&mut sp)?;

let model = llm::load::<M>(
&self.model_path,
&self.model_and_vocabulary.model_path,
vocabulary_source,
params,
overrides,
|progress| match progress {
@@ -492,6 +545,9 @@ pub struct Quantize {
#[arg()]
pub destination: PathBuf,

#[command(flatten)]
pub vocabulary: ModelVocabulary,

/// The GGML container type to target.
///
/// Note that using GGML requires the original model to have
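
For readers skimming the diff: the new -v/--vocabulary-path and -r/--vocabulary-repository flags are mutually exclusive, and omitting both falls back to the vocabulary embedded in the GGML model file. Below is a condensed, illustrative sketch of the resolution logic in ModelVocabulary::to_source (spinner handling omitted; the standalone resolve function is not part of the PR):

use std::path::PathBuf;

// Mirrors the match in ModelVocabulary::to_source above (illustrative only).
enum VocabularySource {
    Model,                             // neither flag: use the model's embedded vocabulary
    HuggingFaceTokenizerFile(PathBuf), // --vocabulary-path: local tokenizer file
    HuggingFaceRemote(String),         // --vocabulary-repository: HuggingFace repo id
}

fn resolve(path: Option<PathBuf>, repo: Option<String>) -> Result<VocabularySource, String> {
    match (path, repo) {
        (Some(_), Some(_)) => Err(
            "Cannot specify both --vocabulary-path and --vocabulary-repository".to_string(),
        ),
        (Some(p), None) => Ok(VocabularySource::HuggingFaceTokenizerFile(p)),
        (None, Some(r)) => Ok(VocabularySource::HuggingFaceRemote(r)),
        (None, None) => Ok(VocabularySource::Model),
    }
}
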
28 changes: 17 additions & 11 deletions binaries/llm-cli/src/main.rs
@@ -56,6 +56,7 @@ fn infer<M: llm::KnownModel + 'static>(
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>(overrides)?;

let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
args.persist_session.as_deref(),
@@ -147,28 +148,28 @@ fn perplexity<M: llm::KnownModel + 'static>(
}

fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
let file = File::open(&args.model_path)?;
let model_path = &args.model_and_vocabulary.model_path;
let vocabulary = args
.model_and_vocabulary
.to_source(&mut None)?
.retrieve(model_path)?;

let file = File::open(model_path)?;
let mut reader = BufReader::new(&file);
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(|_| {
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(vocabulary, |_| {
// We purposely do not print progress here, as we are only interested in the metadata
});

llm::ggml_format::load(&mut reader, &mut loader)?;

log::info!("Container type: {:?}", loader.container_type);
log::info!("Hyperparameters: {:?}", loader.hyperparameters);
log::info!("Vocabulary size: {}", loader.vocabulary.id_to_token.len());
log::info!("Vocabulary size: {}", loader.vocabulary.len());

if args.vocabulary {
log::info!("Vocabulary:");
for (tid, (token, score)) in loader
.vocabulary
.id_to_token
.iter()
.zip(loader.vocabulary.id_to_token_score.iter())
.enumerate()
{
log::info!("- {}: {} ({})", tid, utf8_or_array(token), score);
for i in 0..loader.vocabulary.len() {
log::info!("- {}: {}", i, utf8_or_array(&loader.vocabulary.token(i)));
}
}

@@ -320,10 +321,15 @@ fn quantize<M: llm::KnownModel + 'static>(args: &cli_args::Quantize) -> Result<(

let mut source = BufReader::new(std::fs::File::open(&args.source)?);
let mut destination = BufWriter::new(std::fs::File::create(&args.destination)?);
let vocabulary = args
.vocabulary
.to_source(&mut None)?
.retrieve(&args.source)?;

llm::quantize::<M, _, _>(
&mut source,
&mut destination,
vocabulary,
args.container_type.into(),
args.target.into(),
|progress| match progress {
1 change: 1 addition & 0 deletions crates/llm-base/Cargo.toml
@@ -22,4 +22,5 @@ partial_sort = "0.2.0"
serde_bytes = "0.11"
memmap2 = "0.5.10"
half = "2.2.1"
tokenizers = "0.13.3"
regex = "1.8"
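
The new tokenizers = "0.13.3" dependency is Hugging Face's Rust tokenizers crate, which backs the HuggingFaceTokenizerFile and HuggingFaceRemote vocabulary sources. A minimal, self-contained sketch of how that crate is typically used, independent of this PR's wrapper types (the file path and prompt are placeholders):

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Load a HuggingFace tokenizer.json from disk (placeholder path).
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Encode a prompt into token ids, then decode the ids back into text.
    let encoding = tokenizer.encode("Hello, world!", false)?;
    let ids = encoding.get_ids().to_vec();
    let text = tokenizer.decode(ids, true)?;
    println!("{text}");
    Ok(())
}
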
46 changes: 39 additions & 7 deletions crates/llm-base/src/inference_session.rs
@@ -55,6 +55,9 @@ pub struct InferenceSession {
/// All tokens generated by this inference session
pub(crate) tokens: Vec<TokenId>,

// All decoded tokens generated by this inference session
pub(crate) decoded_tokens: Vec<u8>,

/// The logits that were last predicted by the network. Zeroed out otherwise.
#[doc(hidden)]
pub last_logits: Vec<f32>,
@@ -91,10 +94,23 @@ impl InferenceSession {
for &tk in batch {
let should_call_callback = Some(tk) != model.bot_token_id();

let mut token = match model.vocabulary() {
crate::Vocabulary::Model(_) => model.vocabulary().token(tk as usize).to_vec(),
crate::Vocabulary::External(_) => {
let mut previous_tokens = self.tokens.clone();
previous_tokens.push(tk);

let all_tokens = model.vocabulary().decode(previous_tokens, true);
let splitted = all_tokens.split_at(self.decoded_tokens.len());

splitted.1.to_vec()
}
};

if should_call_callback {
// NOTE: No string ever tokenizes to the end of sentence. So we
// can just return the id here.
match callback(vocab.token(tk as usize)) {
match callback(&token) {
Err(e) => return Err(InferenceError::UserCallback(Box::new(e))),
Ok(f) => match f {
InferenceFeedback::Continue => (),
@@ -105,20 +121,21 @@

// Update the tokens for this session
self.tokens.push(tk);
self.decoded_tokens.append(&mut token);
}
}

Ok(())
}

/// Infer the next token for this session.
pub fn infer_next_token<'v>(
pub fn infer_next_token(
&mut self,
model: &'v dyn Model,
model: &dyn Model,
params: &InferenceParameters,
output_request: &mut OutputRequest,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
) -> Result<Vec<u8>, InferenceError> {
if self.n_past + 1 >= model.context_size() {
return Err(InferenceError::ContextFull);
}
@@ -136,7 +153,20 @@
if next_token as TokenId == model.eot_token_id() {
Err(InferenceError::EndOfText)
} else {
Ok(model.vocabulary().token(next_token as usize))
let res = match model.vocabulary() {
crate::Vocabulary::Model(_) => {
model.vocabulary().token(next_token as usize).to_vec()
}
crate::Vocabulary::External(_) => {
let all_tokens = model.vocabulary().decode(self.tokens.clone(), true);
let splitted = all_tokens.split_at(self.decoded_tokens.len());

splitted.1.to_vec()
}
};

self.decoded_tokens.append(&mut res.clone());
Ok(res)
}
}

@@ -163,7 +193,7 @@
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
token_utf8_buf.push(&model.vocabulary().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Box::new(e)));
@@ -204,7 +234,7 @@ impl InferenceSession {
};

// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(token) {
if let Some(tokens) = token_utf8_buf.push(&token) {
match callback(InferenceResponse::InferredToken(tokens)) {
Err(e) => return Err(InferenceError::UserCallback(Box::new(e))),
Ok(f) => match f {
@@ -493,6 +523,7 @@ impl InferenceSession {
n_past: 0,
mem_per_token: 0,
tokens: vec![],
decoded_tokens: vec![],
last_logits: vec![0.0; n_vocab],
scratch: scratch_buffers(),
}
@@ -513,6 +544,7 @@ impl Clone for InferenceSession {
n_past: self.n_past,
mem_per_token: self.mem_per_token,
tokens: self.tokens.clone(),
decoded_tokens: self.decoded_tokens.clone(),
last_logits: self.last_logits.clone(),
scratch: scratch_buffers(),
}
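
The changes above implement incremental detokenization for external (HuggingFace) vocabularies: because such tokenizers cannot reliably map a single token id to standalone bytes, the session keeps a decoded_tokens buffer, re-decodes the full token history on each step, and emits only the suffix that has not yet been returned. A small illustrative sketch of that pattern (the names are not from the PR; decode_all stands in for the vocabulary's decode):

// Decode the whole token history, then return only the newly produced bytes.
struct IncrementalDecoder {
    tokens: Vec<u32>,
    decoded: Vec<u8>, // mirrors InferenceSession::decoded_tokens
}

impl IncrementalDecoder {
    fn push(&mut self, token: u32, decode_all: impl Fn(&[u32]) -> Vec<u8>) -> Vec<u8> {
        self.tokens.push(token);
        let all = decode_all(&self.tokens);
        // Everything past the previously decoded prefix is new output.
        let new = all[self.decoded.len()..].to_vec();
        self.decoded.extend_from_slice(&new);
        new
    }
}
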
6 changes: 5 additions & 1 deletion crates/llm-base/src/lib.rs
@@ -37,7 +37,11 @@ pub use model::{
pub use quantize::{quantize, QuantizeError, QuantizeProgress};
pub use regex::Regex;
pub use util::TokenUtf8Buffer;
pub use vocabulary::{InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Vocabulary};
pub(crate) use vocabulary::ModelVocabulary;
pub use vocabulary::{
InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Vocabulary,
VocabularyLoadError, VocabularySource,
};

#[derive(Clone, Debug, PartialEq)]
/// The parameters for text generation.