From 741a77d537e2a04a5a7748e796f954f453b87cfd Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Tue, 4 Jun 2024 21:07:02 -0400 Subject: [PATCH 1/2] Organize normal loading metadata --- mistralrs-core/src/models/gemma.rs | 19 +-- mistralrs-core/src/models/llama.rs | 21 +-- mistralrs-core/src/models/mistral.rs | 35 ++--- mistralrs-core/src/models/mixtral.rs | 21 +-- mistralrs-core/src/models/phi2.rs | 21 +-- mistralrs-core/src/models/phi3.rs | 21 +-- mistralrs-core/src/models/qwen2.rs | 21 +-- mistralrs-core/src/pipeline/macros.rs | 24 ++-- mistralrs-core/src/pipeline/mod.rs | 4 +- mistralrs-core/src/pipeline/normal_loaders.rs | 126 ++++++------------ mistralrs-core/src/xlora_models/gemma.rs | 19 +-- mistralrs-core/src/xlora_models/llama.rs | 21 +-- mistralrs-core/src/xlora_models/mistral.rs | 21 +-- mistralrs-core/src/xlora_models/mixtral.rs | 21 +-- mistralrs-core/src/xlora_models/phi2.rs | 21 +-- mistralrs-core/src/xlora_models/phi3.rs | 21 +-- 16 files changed, 199 insertions(+), 238 deletions(-) diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index a74b1558cc..89871580f5 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -8,8 +8,7 @@ use candle_nn::{linear_b as linear, Activation, RotaryEmbedding, VarBuilder}; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, QLinear, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; fn default_max_position_embeddings() -> usize { @@ -319,11 +318,11 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding( cfg.vocab_size, @@ -337,7 +336,9 @@ impl Model { cfg.rope_theta as f32, cfg.head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -347,7 +348,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -362,7 +363,7 @@ impl Model { layers, norm, lm_head, - device: real_device, + device: normal_loading_metadata.real_device, hidden_size: cfg.hidden_size, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: default_max_position_embeddings(), diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 55554a4679..ab2b1df5a9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -10,8 +10,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, IsqModel, NormalLoadingMetadata, NormalModel}, }; #[derive(Debug, Clone, Deserialize)] @@ -294,11 +293,11 @@ impl Llama { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + 
normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let wte = embedding( cfg.vocab_size, cfg.hidden_size, @@ -307,7 +306,7 @@ impl Llama { let lm_head = linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; let ln_f = RmsNorm::new( cfg.hidden_size, @@ -322,7 +321,9 @@ impl Llama { cfg.rope_theta, head_dim, cfg.max_position_embeddings, - mapper.device_for(i, false).unwrap_or(&real_device), + mapper + .device_for(i, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), ) @@ -333,7 +334,7 @@ impl Llama { cfg, &*mapper, i, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, ) .expect("Failed to load block.") @@ -346,7 +347,7 @@ impl Llama { ln_f, lm_head: QMatMul::Tensor(lm_head.weight().clone()), kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false), - device: real_device, + device: normal_loading_metadata.real_device, mapper, }) } diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index d7f2c8b099..f41dd11c1c 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -8,8 +8,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; #[derive(Debug, Clone, PartialEq)] @@ -280,21 +279,11 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); let vb_lm_head = vb.pp("lm_head"); - Self::new_inner( - cfg, - vb_m, - vb_lm_head, - is_gptx, - mapper, - loading_isq, - real_device, - ) + Self::new_inner(cfg, vb_m, vb_lm_head, is_gptx, normal_loading_metadata) } pub fn new_inner( @@ -302,11 +291,11 @@ impl Model { vb_m: VarBuilder, vb_lm_head: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -320,7 +309,9 @@ impl Model { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb_m.dtype(), )?); @@ -330,7 +321,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -342,7 +333,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb_lm_head, loading_isq), + mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -350,7 +341,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), 
sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 11cb332942..0299037d0c 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -11,8 +11,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; /// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 @@ -383,12 +382,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -402,7 +401,9 @@ impl Model { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -412,7 +413,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -424,7 +425,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -432,7 +433,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index d1a8a89cd2..187e1fcf83 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -14,8 +14,7 @@ use serde::Deserialize; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, QLinear, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; // https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py @@ -287,12 +286,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = embedding( cfg.vocab_size, cfg.hidden_size, @@ -312,7 +311,9 @@ impl Model { cfg.head_dim(), 
(cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?; @@ -321,7 +322,7 @@ impl Model { vb_m.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, )?; layers.push(layer) @@ -329,7 +330,7 @@ impl Model { let lm_head = linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -337,7 +338,7 @@ impl Model { final_layernorm, lm_head: QLinear::from_linear(lm_head), cache: Cache::new(cfg.num_hidden_layers, false), - device: real_device, + device: normal_loading_metadata.real_device, max_seq_len: cfg.max_position_embeddings, mapper, }) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index c00e60ea75..1fbc1218ba 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -13,8 +13,7 @@ use crate::{ repeat_kv, CausalMasker, MatMul, PhiRopeConfig, PhiRotaryEmbedding, RmsNorm, ScaledDotProductAttention, }, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json @@ -291,12 +290,12 @@ impl Model { cfg: &Config, vb: VarBuilder, _is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -308,7 +307,9 @@ impl Model { let rotary_emb = Arc::new(PhiRotaryEmbedding::new( vb.dtype(), cfg.clone(), - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), )?); let layer = DecoderLayer::new( rotary_emb.clone(), @@ -316,7 +317,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -328,14 +329,14 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, layers, norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index eae06e4aba..a6103fb507 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -7,8 +7,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, QLinear, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; 
#[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -269,12 +268,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -288,7 +287,9 @@ impl Model { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -298,7 +299,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -310,7 +311,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -318,7 +319,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 46de131673..d3795c4137 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -314,9 +314,11 @@ macro_rules! normal_model_loader { &$config, $use_flash_attn, vb, - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, )? }}; } @@ -372,9 +374,11 @@ macro_rules! xlora_model_loader { $paths.get_adapter_configs().as_ref().unwrap(), Some($paths.get_classifier_config().as_ref().unwrap().clone()), $paths.get_ordering().as_ref().unwrap().clone(), - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, &$crate::utils::varbuilder_utils::load_preload_adapters( $paths.get_lora_preload_adapter_info(), $dtype.unwrap_or($default_dtype), @@ -413,9 +417,11 @@ macro_rules! 
lora_model_loader { $paths.get_adapter_configs().as_ref().unwrap(), None, $paths.get_ordering().as_ref().unwrap().clone(), - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, &$crate::utils::varbuilder_utils::load_preload_adapters( $paths.get_lora_preload_adapter_info(), $dtype.unwrap_or($default_dtype), diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 250987be19..262ec0663d 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -24,8 +24,8 @@ pub use gguf::{GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConf pub use isq::IsqModel; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; pub use normal_loaders::{ - GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalModelLoader, - Phi2Loader, Phi3Loader, Qwen2Loader, + GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, + NormalLoadingMetadata, NormalModelLoader, Phi2Loader, Phi3Loader, Qwen2Loader, }; pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths, XLoraPaths}; pub(crate) use processing::{BasicProcessor, Processor, ProcessorCreator}; diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs index 09e0140a98..80add70c4d 100644 --- a/mistralrs-core/src/pipeline/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/normal_loaders.rs @@ -11,15 +11,23 @@ use pyo3::pyclass; use serde::Deserialize; +/// Metadata for loading a model with ISQ or device mapping. +pub struct NormalLoadingMetadata { + // Device mapping metadata which can be used to contstruct a concrete device mapper + pub mapper: DeviceMapMetadata, + // Flag to check if loading in ISQ + pub loading_isq: bool, + // Device mapping target device (the one that is not the cpu) + pub real_device: Device, +} + pub trait NormalModelLoader { fn load( &self, config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result>; #[allow(clippy::too_many_arguments)] fn load_xlora( @@ -30,9 +38,7 @@ pub trait NormalModelLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result>; fn is_gptx(&self) -> bool; @@ -127,17 +133,13 @@ impl NormalModelLoader for MistralLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::mistral::Model::new( &MistralBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -148,9 +150,7 @@ impl NormalModelLoader for MistralLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraMistral::new( @@ -160,9 +160,7 @@ impl NormalModelLoader for MistralLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - 
loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -233,17 +231,13 @@ impl NormalModelLoader for GemmaLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::gemma::Model::new( &GemmaBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -254,9 +248,7 @@ impl NormalModelLoader for GemmaLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraGemma::new( @@ -266,9 +258,7 @@ impl NormalModelLoader for GemmaLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -331,17 +321,13 @@ impl NormalModelLoader for LlamaLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::llama::Llama::new( &LlamaBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -352,9 +338,7 @@ impl NormalModelLoader for LlamaLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraLlama::new( @@ -364,9 +348,7 @@ impl NormalModelLoader for LlamaLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -430,17 +412,13 @@ impl NormalModelLoader for MixtralLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::mixtral::Model::new( &MixtralBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -451,9 +429,7 @@ impl NormalModelLoader for MixtralLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraMixtral::new( @@ -463,9 +439,7 @@ impl NormalModelLoader for MixtralLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -529,17 +503,13 @@ impl NormalModelLoader for Phi2Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::phi2::Model::new( &Phi2BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -550,9 
+520,7 @@ impl NormalModelLoader for Phi2Loader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraPhi2::new( @@ -562,9 +530,7 @@ impl NormalModelLoader for Phi2Loader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -639,17 +605,13 @@ impl NormalModelLoader for Phi3Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::phi3::Model::new( &Phi3BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -660,9 +622,7 @@ impl NormalModelLoader for Phi3Loader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraPhi3::new( @@ -672,9 +632,7 @@ impl NormalModelLoader for Phi3Loader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -740,17 +698,13 @@ impl NormalModelLoader for Qwen2Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::qwen2::Model::new( &Qwen2BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -761,9 +715,7 @@ impl NormalModelLoader for Qwen2Loader { _lora_config: &[((String, String), LoraConfig)], _xlora_config: Option, _xlora_ordering: Ordering, - _mapper: DeviceMapMetadata, - _loading_isq: bool, - _device: Device, + _normal_loading_metadata: NormalLoadingMetadata, _preload_adapters: &Option>, ) -> Result> { todo!() diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index bde161cee8..57bce33492 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ layers::ScaledDotProductAttention, lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D}; use candle_nn::{RotaryEmbedding, VarBuilder}; @@ -17,7 +17,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear}, models::gemma::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -471,13 +470,13 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let 
mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -491,7 +490,9 @@ impl XLoraModel { cfg.rope_theta as f32, cfg.head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -504,7 +505,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -548,7 +549,7 @@ impl XLoraModel { layers, norm, lm_head: QLinear::from_linear(lm_head), - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), hidden_size: cfg.hidden_size, cache: Cache::new(cfg.num_hidden_layers, true), diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index db1e6f92de..55046a40bb 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -15,8 +15,7 @@ use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, QLinear, RmsNorm}, models::llama::Config, - pipeline::{self, extract_logits, LayerCaches, NormalModel}, - DeviceMapMetadata, + pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel}, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -552,13 +551,13 @@ impl XLoraLlama { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let dtype = vb.dtype(); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let wte = embedding( cfg.vocab_size, cfg.hidden_size, @@ -567,7 +566,7 @@ impl XLoraLlama { let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; let ln_f = RmsNorm::new( cfg.hidden_size, @@ -583,7 +582,9 @@ impl XLoraLlama { cfg.rope_theta, head_dim, cfg.max_position_embeddings, - mapper.device_for(i, false).unwrap_or(&real_device), + mapper + .device_for(i, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), ) @@ -597,7 +598,7 @@ impl XLoraLlama { &xlora_ordering, &*mapper, i, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, preload_adapters, ) @@ -639,7 +640,7 @@ impl XLoraLlama { ln_f, lm_head: QLinear::from_linear(lm_head), kv_cache: pipeline::Cache::new(cfg.num_hidden_layers, true), - device: real_device, + device: normal_loading_metadata.real_device, xlora_classifier: xlora_config.map(|xlora_config| { XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap() }), diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index 4d240a7e54..620cd6fc66 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -3,7 +3,7 @@ use crate::{ layers::ScaledDotProductAttention, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, 
}; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; @@ -17,7 +17,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear, RmsNorm}, models::mistral::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker}; @@ -438,12 +437,12 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding( cfg.vocab_size, @@ -459,7 +458,9 @@ impl XLoraModel { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -472,7 +473,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -513,7 +514,7 @@ impl XLoraModel { let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -521,7 +522,7 @@ impl XLoraModel { norm, lm_head: QLinear::from_linear(lm_head), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), cache: Cache::new(cfg.num_hidden_layers, true), max_seq_len: cfg.max_position_embeddings, diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 65e715dfee..223bedd72f 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -3,7 +3,7 @@ use crate::{ layers::{MatMul, ScaledDotProductAttention}, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py @@ -19,7 +19,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, RmsNorm}, models::mixtral::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -576,13 +575,13 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -597,7 +596,9 @@ impl XLoraModel { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, 
false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -610,7 +611,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -650,7 +651,7 @@ impl XLoraModel { let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -658,7 +659,7 @@ impl XLoraModel { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs index 327ca8fdeb..a648522748 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ layers::ScaledDotProductAttention, lora::{linear, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; /// Phi model. /// https://huggingface.co/microsoft/phi-2 @@ -24,7 +24,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear}, models::phi2::Config, pipeline::{extract_logits, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -427,13 +426,13 @@ impl Model { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = embedding( cfg.vocab_size, cfg.hidden_size, @@ -454,7 +453,9 @@ impl Model { cfg.head_dim(), (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?; @@ -466,7 +467,7 @@ impl Model { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, preload_adapters, )?; @@ -496,7 +497,7 @@ impl Model { let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -504,7 +505,7 @@ impl Model { final_layernorm, lm_head: QLinear::from_linear(lm_head), cache: Cache::new(cfg.num_hidden_layers, true), - device: real_device, + device: normal_loading_metadata.real_device, max_seq_len: cfg.max_position_embeddings, dtype: vb.dtype(), xlora_classifier: xlora_config.map(|xlora_config| { diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 1fabcf2860..8754531405 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -5,7 +5,7 @@ use crate::{ 
     layers::ScaledDotProductAttention,
     lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
-    pipeline::IsqModel,
+    pipeline::{IsqModel, NormalLoadingMetadata},
 };
 use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
@@ -18,7 +18,6 @@ use crate::{
     layers::{repeat_kv, CausalMasker, PhiRotaryEmbedding, QLinear, RmsNorm},
     models::phi3::Config,
     pipeline::{extract_logits, NormalModel},
-    DeviceMapMetadata,
 };

 use crate::pipeline::Cache;
@@ -387,13 +386,13 @@ impl Model {
         xlora_config: Option,
         xlora_ordering: Ordering,
         _is_gptx: bool,
-        mapper: DeviceMapMetadata,
-        loading_isq: bool,
-        real_device: Device,
+        normal_loading_metadata: NormalLoadingMetadata,
         preload_adapters: &Option>,
     ) -> Result {
         let vb_m = vb.pp("model");
-        let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?;
+        let mapper = normal_loading_metadata
+            .mapper
+            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
         let embed_tokens = candle_nn::embedding(
             cfg.vocab_size,
             cfg.hidden_size,
@@ -406,7 +405,9 @@ impl Model {
             let rotary_emb = Arc::new(PhiRotaryEmbedding::new(
                 vb.dtype(),
                 cfg.clone(),
-                mapper.device_for(layer_idx, false).unwrap_or(&real_device),
+                mapper
+                    .device_for(layer_idx, false)
+                    .unwrap_or(&normal_loading_metadata.real_device),
             )?);
             let layer = DecoderLayer::new(
                 rotary_emb.clone(),
@@ -417,7 +418,7 @@ impl Model {
                 &xlora_ordering,
                 &*mapper,
                 layer_idx,
-                loading_isq,
+                normal_loading_metadata.loading_isq,
                 preload_adapters,
             )?;
             layers.push(layer)
@@ -449,14 +450,14 @@ impl Model {
         let lm_head = candle_nn::linear_no_bias(
             cfg.hidden_size,
             cfg.vocab_size,
-            mapper.set_nm_device(vb.pp("lm_head"), loading_isq),
+            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
         )?;
         Ok(Self {
             embed_tokens,
             layers,
             norm,
             lm_head: QLinear::from_linear(lm_head),
-            device: real_device,
+            device: normal_loading_metadata.real_device,
             dtype: vb.dtype(),
             cache: Cache::new(cfg.num_hidden_layers, true),
             max_seq_len: cfg.max_position_embeddings,

From c08f7e5a21cd6e4e93385102ad266ef30486c5fa Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Tue, 4 Jun 2024 21:10:16 -0400
Subject: [PATCH 2/2] Fix

---
 mistralrs-core/src/pipeline/normal_loaders.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs
index 80add70c4d..5237834f29 100644
--- a/mistralrs-core/src/pipeline/normal_loaders.rs
+++ b/mistralrs-core/src/pipeline/normal_loaders.rs
@@ -13,7 +13,7 @@ use serde::Deserialize;

 /// Metadata for loading a model with ISQ or device mapping.
 pub struct NormalLoadingMetadata {
-    // Device mapping metadata which can be used to contstruct a concrete device mapper
+    // Device mapping metadata which can be used to construct a concrete device mapper
     pub mapper: DeviceMapMetadata,
     // Flag to check if loading in ISQ
     pub loading_isq: bool,
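
The core of patch 1 is a parameter-object refactor: the three loose loading arguments (mapper, loading_isq, real_device) are grouped into the new NormalLoadingMetadata struct in normal_loaders.rs, the pipeline macros in macros.rs build that struct once at the call site, and every model and X-LoRA constructor takes it as a single argument. The following is a minimal, self-contained sketch of the same pattern; Device and DeviceMapMetadata are simplified stand-ins here, not the real candle_core / mistralrs-core types, and Model::new is reduced to the parts relevant to the metadata.

#[derive(Debug, Clone)]
pub struct Device(pub String); // stand-in for candle_core::Device

#[derive(Debug, Clone, Default)]
pub struct DeviceMapMetadata; // stand-in for crate::DeviceMapMetadata

// Field names mirror the struct added in normal_loaders.rs.
pub struct NormalLoadingMetadata {
    pub mapper: DeviceMapMetadata,
    pub loading_isq: bool,
    pub real_device: Device,
}

pub struct Model {
    device: Device,
}

impl Model {
    // Before the patch: new(cfg, vb, is_gptx, mapper, loading_isq, real_device).
    // After the patch: one metadata argument; fields are read off the struct.
    pub fn new(normal_loading_metadata: NormalLoadingMetadata) -> Self {
        let _mapper = &normal_loading_metadata.mapper; // would build the concrete device mapper
        let _loading_isq = normal_loading_metadata.loading_isq; // would gate ISQ-specific handling
        Self {
            device: normal_loading_metadata.real_device,
        }
    }
}

fn main() {
    // Call sites (the normal/xlora/lora loader macros in the real code) construct the struct once.
    let model = Model::new(NormalLoadingMetadata {
        mapper: DeviceMapMetadata::default(),
        loading_isq: false,
        real_device: Device("cpu".to_string()),
    });
    println!("loaded on {:?}", model.device);
}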
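The apparent benefit of this grouping is that further loading options can later be added to NormalLoadingMetadata without changing the signature of every model constructor and loader implementation again.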