Skip to content

Commit

Permalink
Context size consistency & fixes
Browse files (browse the repository at this point in the history)
Fairly certain this fixes rustformers#167
  • Loading branch information
danforbes committed May 20, 2023
1 parent 8c1acc0 commit 8048012
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl InferenceSession {
let vocab = model.vocabulary();
let prompt_tokens = prompt.into().to_tokens(vocab, beginning_of_sentence)?;

if self.n_past + prompt_tokens.len() >= model.n_context_tokens() {
if self.n_past + prompt_tokens.len() >= model.context_size() {
return Err(InferenceError::ContextFull);
}

Expand Down Expand Up @@ -119,7 +119,7 @@ impl InferenceSession {
output_request: &mut OutputRequest,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
if self.n_past + 1 >= model.n_context_tokens() {
if self.n_past + 1 >= model.context_size() {
return Err(InferenceError::ContextFull);
}

Expand Down Expand Up @@ -241,7 +241,7 @@ impl InferenceSession {
let mut count = 0;

// TODO: make this handle <n_ctx tokens
let n_ctx = model.n_context_tokens();
let n_ctx = model.context_size();
let n_chunk = tokens.len() / n_ctx;
let n_vocab = model.vocabulary().len();
let n_batch = parameters.n_batch;
Expand Down
4 changes: 2 additions & 2 deletions crates/llm-base/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ pub trait Model: Send + Sync {

/// Get the context size (configured with [ModelParameters::context_size]) used by
/// this model.
fn n_context_tokens(&self) -> usize;
fn context_size(&self) -> usize;

/// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers.
fn bot_token_id(&self) -> Option<TokenId>;
Expand Down Expand Up @@ -225,7 +225,7 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
KnownModel::vocabulary(self)
}

fn n_context_tokens(&self) -> usize {
fn context_size(&self) -> usize {
KnownModel::context_size(self)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ impl KnownModel for Gpt2 {
}

fn context_size(&self) -> usize {
self.hyperparameters.n_ctx
self.context_size
}

fn bot_token_id(&self) -> Option<TokenId> {
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ impl KnownModel for GptJ {
}

fn context_size(&self) -> usize {
self.hyperparameters.n_ctx
self.context_size
}

fn bot_token_id(&self) -> Option<TokenId> {
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ impl KnownModel for GptNeoX {
}

fn context_size(&self) -> usize {
self.hyperparameters.n_ctx
self.context_size
}

fn bot_token_id(&self) -> Option<TokenId> {
Expand Down

0 comments on commit 8048012

Please sign in to comment.