Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Custom RoPE Scaling #389

Merged
merged 3 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use clap::{Parser, ValueEnum};
use color_eyre::eyre::{self, WrapErr};
use llm::{
ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, TokenizerSource,
LoadProgress, Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias,
TokenizerSource,
};
use rand::SeedableRng;

Expand Down Expand Up @@ -430,6 +431,29 @@ impl ModelAndTokenizer {
}
}

/// CLI flags for overriding the model's RoPE (rotary positional encoding)
/// parameters. Intended to be flattened into a larger argument struct via
/// `#[command(flatten)]`.
///
/// NOTE(review): field notes below deliberately use `//` rather than `///` —
/// clap turns doc comments on `#[arg]` fields into `--help` text, and adding
/// them here would change the CLI's visible output.
#[derive(Parser, Debug)]
pub struct RoPEScaling {
    // Override for the RoPE frequency base; when unset, the library default
    // (see `RoPEOverrides::default()`) is used.
    #[arg(long)]
    pub rope_freq_base: Option<usize>,

    // Override for the RoPE frequency scale factor; when unset, the library
    // default is used.
    #[arg(long)]
    pub rope_freq_scale: Option<f32>,
}

impl RoPEScaling {
    /// Converts the CLI flags into a [`RoPEOverrides`], or returns `None`
    /// when neither flag was supplied so the model keeps its built-in RoPE
    /// parameters. A flag that was supplied wins; one left unset falls back
    /// to the corresponding field of `RoPEOverrides::default()`.
    pub fn to_rope_arguments(&self) -> Option<RoPEOverrides> {
        match (self.rope_freq_base, self.rope_freq_scale) {
            // Neither flag given: no override at all.
            (None, None) => None,
            (base, scale) => {
                let defaults = RoPEOverrides::default();
                Some(RoPEOverrides {
                    frequency_base: base.unwrap_or(defaults.frequency_base),
                    frequency_scale: scale.unwrap_or(defaults.frequency_scale),
                })
            }
        }
    }
}

#[derive(Parser, Debug)]
pub struct ModelLoad {
#[command(flatten)]
Expand Down Expand Up @@ -460,7 +484,11 @@ pub struct ModelLoad {
/// Number of layers to run on the GPU. If not specified, all layers will be run on the GPU.
#[arg(long)]
pub gpu_layers: Option<usize>,

#[command(flatten)]
pub rope_scaling: RoPEScaling,
}

impl ModelLoad {
pub fn load(&self, use_gpu: bool) -> eyre::Result<Box<dyn Model>> {
let params = ModelParameters {
Expand All @@ -469,6 +497,7 @@ impl ModelLoad {
lora_adapters: self.lora_paths.clone(),
use_gpu,
gpu_layers: self.gpu_layers,
rope_overrides: self.rope_scaling.to_rope_arguments(),
};

let mut sp = Some(spinoff::Spinner::new(
Expand Down
45 changes: 34 additions & 11 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use std::{

use memmap2::Mmap;

use crate::{accelerator::Backend, sys, usize_to_i32, usize_to_i64, Buffer, Tensor, Type};
use crate::{
accelerator::Backend, sys, usize_to_i32, usize_to_i64, Buffer, RoPEOverrides, Tensor, Type,
};

/// Acts as a RAII-guard over a `sys::ggml_context`, allocating via
/// `ggml_init` and dropping via `ggml_free`.
Expand Down Expand Up @@ -267,7 +269,8 @@ impl Context {

/// Creates a new tensor with the values of `a`, but normalized using RMSNorm.
pub fn op_rms_norm(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_rms_norm(self.as_ptr(), a.ptr.as_ptr()) };
let tensor =
unsafe { sys::ggml_rms_norm(self.as_ptr(), a.ptr.as_ptr(), crate::DEFAULT_EPS) };
self.new_tensor_raw(tensor)
}

Expand Down Expand Up @@ -527,16 +530,36 @@ impl Context {
}

/// In-place; applies ROtary Positional Encoding.
pub fn op_rope_inplace(&self, a: &Tensor, npast: usize, ndims: usize, mode: i32) -> Tensor {
pub fn op_rope_inplace(
&self,
a: &Tensor,
npast: usize,
ndims: usize,
mode: i32,
overrides: Option<&RoPEOverrides>,
) -> Tensor {
let tensor = unsafe {
sys::ggml_rope_inplace(
self.as_ptr(),
a.ptr.as_ptr(),
usize_to_i32(npast),
usize_to_i32(ndims),
mode,
0,
)
if let Some(custom_args) = overrides {
sys::ggml_rope_custom_inplace(
self.as_ptr(),
a.ptr.as_ptr(),
usize_to_i32(npast),
usize_to_i32(ndims),
mode,
1,
custom_args.frequency_base as f32,
custom_args.frequency_scale,
)
} else {
sys::ggml_rope_inplace(
self.as_ptr(),
a.ptr.as_ptr(),
usize_to_i32(npast),
usize_to_i32(ndims),
mode,
0,
)
}
};
self.new_tensor_raw(tensor)
}
Expand Down
23 changes: 23 additions & 0 deletions crates/ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ pub const OBJECT_SIZE: usize = sys::GGML_OBJECT_SIZE;
/// The maximum length of a `ggml` tensor-name.
pub const MAX_NAME_LENGTH: usize = sys::GGML_MAX_NAME as usize;

/// Default epsilon to use for RMS computation, taken from llama.cpp's
/// `LLAMA_DEFAULT_RMS_EPS`.
pub const DEFAULT_EPS: f32 = sys::llama::LLAMA_DEFAULT_RMS_EPS as f32;

/// Value overrides to use for RoPE (rotary positional encoding).
///
/// Formula: `theta_i = scale * base^(−2(i−1)/d), for i in [1, 2, ..., d/2]`
#[derive(Debug, Clone)]
pub struct RoPEOverrides {
    /// The frequency scale to use (the `Default` impl uses `1.0`).
    pub frequency_scale: f32,
    /// The frequency base value to use (the `Default` impl uses `10_000`,
    /// the conventional RoPE base).
    pub frequency_base: usize,
}

impl Default for RoPEOverrides {
fn default() -> Self {
Self {
frequency_scale: 1.0,
frequency_base: 10_000,
}
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
/// The type of a value in `ggml`.
pub enum Type {
Expand Down
2 changes: 1 addition & 1 deletion crates/ggml/sys/llama-cpp
Submodule llama-cpp updated 61 files
+2 −0 .github/workflows/build.yml
+19 −1 .gitignore
+26 −0 CMakeLists.txt
+74 −18 Makefile
+18 −1 README.md
+25 −0 ci/README.md
+409 −0 ci/run.sh
+1 −0 convert-lora-to-ggml.py
+92 −44 convert.py
+2 −0 examples/CMakeLists.txt
+9 −8 examples/Miku.sh
+1 −0 examples/baby-llama/CMakeLists.txt
+15 −9 examples/baby-llama/baby-llama.cpp
+1 −0 examples/benchmark/CMakeLists.txt
+145 −96 examples/common.cpp
+16 −11 examples/common.h
+2 −0 examples/embd-input/CMakeLists.txt
+1 −1 examples/embd-input/minigpt4.py
+1 −0 examples/embedding/CMakeLists.txt
+423 −0 examples/grammar-parser.cpp
+29 −0 examples/grammar-parser.h
+18 −0 examples/llama2-13b.sh
+18 −0 examples/llama2.sh
+23 −0 examples/llm.vim
+1 −0 examples/main/CMakeLists.txt
+86 −27 examples/main/main.cpp
+92 −0 examples/make-ggml.py
+1 −0 examples/metal/CMakeLists.txt
+1 −0 examples/perplexity/CMakeLists.txt
+81 −3 examples/perplexity/perplexity.cpp
+1 −0 examples/quantize-stats/CMakeLists.txt
+1 −0 examples/quantize/CMakeLists.txt
+19 −95 examples/quantize/quantize.cpp
+1 −0 examples/save-load-state/CMakeLists.txt
+4 −0 examples/server/CMakeLists.txt
+668 −414 examples/server/index.html.hpp
+97 −28 examples/server/public/index.html
+56 −35 examples/server/server.cpp
+1 −0 examples/simple/CMakeLists.txt
+1 −0 examples/train-text-from-scratch/CMakeLists.txt
+20 −18 examples/train-text-from-scratch/train-text-from-scratch.cpp
+55 −43 flake.nix
+371 −108 ggml-cuda.cu
+7 −0 ggml-metal.h
+227 −91 ggml-metal.m
+605 −508 ggml-metal.metal
+723 −984 ggml.c
+86 −20 ggml.h
+6 −0 grammars/arithmetic.gbnf
+13 −0 grammars/chess.gbnf
+7 −0 grammars/japanese.gbnf
+29 −0 grammars/json.gbnf
+4 −0 grammars/list.gbnf
+327 −3 k_quants.c
+547 −139 llama.cpp
+65 −9 llama.h
+3 −2 scripts/build-info.sh
+1 −0 tests/CMakeLists.txt
+395 −86 tests/test-grad0.c
+3 −3 tests/test-opt.c
+2 −0 tests/test-sampling.cpp
Loading