Skip to content

Commit

Permalink
Support uqff load/save for idefics3 (#1023)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Jan 2, 2025
1 parent 0875194 commit 27ab495
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 21 deletions.
26 changes: 14 additions & 12 deletions mistralrs-core/src/models/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,19 @@ impl Llama {
let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
extract_logits(&xs, context_lens)
}

pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
    // Register this model's residual (non-ISQ) tensors under the caller-supplied
    // prefix builder, then serialize. Lets wrapping models (e.g. Idefics3) nest
    // these tensors under their own prefix such as `model.text_model`.
    uvb_m.pp("embed_tokens").add(&self.wte);
    uvb_m.pp("norm").add(&self.ln_f);

    let uvb_layers = uvb_m.pp("layers");
    for (idx, block) in self.blocks.iter().enumerate() {
        let uvb_block = uvb_layers.pp(idx);
        uvb_block.pp("input_layernorm").add(&block.rms_1);
        uvb_block.pp("post_attention_layernorm").add(&block.rms_2);
    }

    uvb_m.to_safetensors()
}
}

impl IsqModel for Llama {
Expand Down Expand Up @@ -650,18 +663,7 @@ impl IsqModel for Llama {

fn residual_tensors(&self) -> Vec<(String, Tensor)> {
let uvb = UnVarBuilder::new();

let uvb_m = uvb.pp("model");
uvb_m.pp("embed_tokens").add(&self.wte);
uvb_m.pp("norm").add(&self.ln_f);

for (layer_idx, layer) in self.blocks.iter().enumerate() {
let uvb_l = uvb_m.pp("layers").pp(layer_idx);
uvb_l.pp("input_layernorm").add(&layer.rms_1);
uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
}

uvb.to_safetensors()
self.residual_tensors_m(uvb.pp("model"))
}

fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
Expand Down
17 changes: 13 additions & 4 deletions mistralrs-core/src/pipeline/loaders/vision_loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -732,9 +732,18 @@ impl VisionModelLoader for Idefics3Loader {
}

impl IsqModelLoader for Idefics3Loader {
fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
let config: Idefics3Config = serde_json::from_str(config)?;
let text_cfg = serde_json::to_string(&config.text_config)?;
super::LlamaLoader.isq_layer_regexes(&text_cfg)
fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
    // Patterns selecting the text-model weights that ISQ may quantize; the
    // vision tower and connector are deliberately excluded.
    //
    // Fix: the dots in the `model.text_model.layers` prefix are now escaped
    // (`model\.text_model\.layers`) — an unescaped `.` is a regex wildcard
    // matching any character, which is inconsistent with the escaped dots in
    // the rest of each pattern and could over-match unintended tensor names.
    Ok(vec![
        Regex::new(r"lm_head\.(weight|bias)$")?,
        // Attention projections
        Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
        Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
        Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
        Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
        // MLP projections
        Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
        Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
        Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
    ])
}
}
1 change: 1 addition & 0 deletions mistralrs-core/src/utils/varbuilder_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ trait LoadTensors {
.get_names()
.into_iter()
.filter(|x| predicate(x.to_string()));
// NOTE(review): `dbg!` left in committed code — it clones the iterator and
// collects every tensor name on each load just to print them; remove before release.
dbg!(&names_only.clone().collect::<Vec<_>>());
let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();

// Take the filtered list of tensors to load, store with derived lookup key:
Expand Down
16 changes: 15 additions & 1 deletion mistralrs-core/src/vision_models/idefics3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
},
utils::unvarbuilder::UnVarBuilder,
AnyMoeConfig, AnyMoeExpertType,
};

Expand Down Expand Up @@ -246,7 +247,20 @@ impl IsqModel for Idefics3Model {
}

fn residual_tensors(&self) -> Vec<(String, Tensor)> {
self.text_model.residual_tensors()
let uvb = UnVarBuilder::new();

let uvb_m = uvb.pp("model");
uvb_m
.pp("connector")
.pp("modality_projection")
.pp("proj")
.add(&self.connector.modality_projection.proj);
uvb.extend(self.text_model.residual_tensors_m(uvb_m.pp("text_model")));
uvb_m
.pp("vision_model")
.extend(self.vision.residual_tensors());

uvb.to_safetensors()
}
}

Expand Down
60 changes: 56 additions & 4 deletions mistralrs-core/src/vision_models/idefics3/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ use candle_nn::{
};
use std::ops::Mul;

use crate::layers::{Activation, CausalMasker};
use crate::{
layers::{Activation, CausalMasker},
utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

struct Idefics3SimpleMLP {
proj: Linear,
pub(crate) struct Idefics3SimpleMLP {
pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
Expand All @@ -29,7 +32,7 @@ impl Idefics3SimpleMLP {

pub struct Idefics3Connector {
scale_factor: usize,
modality_projection: Idefics3SimpleMLP,
pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
Expand Down Expand Up @@ -192,6 +195,15 @@ impl VisionEmbeddings {
let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
}

fn residual_tensors(&self) -> Vec<(String, Tensor)> {
    // Collect this module's residual tensors under their serialized names.
    let builder = UnVarBuilder::new();

    builder.pp("patch_embedding").add(&self.patch_embedding);
    builder.pp("position_embedding").add(&self.position_embedding);

    builder.to_safetensors()
}
}

struct Attention {
Expand Down Expand Up @@ -264,6 +276,17 @@ impl Attention {
.reshape((b_sz, q_len, self.embed_dim))?
.apply(&self.o_proj)
}

fn residual_tensors(&self) -> Vec<(String, Tensor)> {
    // Export the four attention projections under their serialized names.
    let builder = UnVarBuilder::new();

    builder.pp("q_proj").add(&self.q_proj);
    builder.pp("k_proj").add(&self.k_proj);
    builder.pp("v_proj").add(&self.v_proj);
    // Serialized as `out_proj` even though the field is named `o_proj`.
    builder.pp("out_proj").add(&self.o_proj);

    builder.to_safetensors()
}
}

struct VisionMLP {
Expand All @@ -288,6 +311,15 @@ impl VisionMLP {
x = self.activation.forward(&x)?;
self.fc2.forward(&x)
}

fn residual_tensors(&self) -> Vec<(String, Tensor)> {
    // Export both MLP projections under their serialized names.
    let builder = UnVarBuilder::new();

    builder.pp("fc1").add(&self.fc1);
    builder.pp("fc2").add(&self.fc2);

    builder.to_safetensors()
}
}

struct EncoderLayer {
Expand Down Expand Up @@ -417,4 +449,24 @@ impl Idefics3VisionTransformer {
.forward(&hidden_states, attention_mask.as_ref())?;
hidden_states.apply(&self.post_layernorm)
}

pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
    // Gather the transformer's residual tensors: the final layer norm, the
    // embedding module, and every encoder layer's norms/MLP/attention.
    let uvb = UnVarBuilder::new();

    uvb.pp("post_layernorm").add(&self.post_layernorm);
    uvb.pp("embeddings")
        .extend(self.embeddings.residual_tensors());

    let uvb_layers = uvb.pp("encoder").pp("layers");
    for (layer_idx, layer) in self.encoder.layers.iter().enumerate() {
        let uvb_layer = uvb_layers.pp(layer_idx);

        uvb_layer.pp("layer_norm1").add(&layer.layer_norm_1);
        uvb_layer.pp("layer_norm2").add(&layer.layer_norm_2);
        uvb_layer.pp("mlp").extend(layer.mlp.residual_tensors());
        uvb_layer.pp("self_attn").extend(layer.attn.residual_tensors());
    }

    uvb.to_safetensors()
}
}

0 comments on commit 27ab495

Please sign in to comment.