
Commit

MLlama - if f16, load vision model in f32 (#820)
* If f16, load in f32

* Fix set_dtype

* Now it works
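
In short: when the model is requested in F16, the MLlama vision tower and the multi-modal projector are loaded in F32 instead, and their outputs (plus the cross-attention mask) are cast back to the language model's dtype only after the f32 math is done. A minimal sketch of the selection logic — a hypothetical free function for illustration; in the diff below the same check is written inline in MLlamaModel::new using vb.dtype() and set_dtype:

use candle_core::DType;

// Hypothetical helper mirroring the inline check in MLlamaModel::new:
// keep a numerically sensitive sub-model in f32 when f16 was requested.
fn vision_model_dtype(requested: DType) -> DType {
    if requested == DType::F16 {
        DType::F32
    } else {
        requested
    }
}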
EricLBuehler authored Oct 3, 2024
1 parent 9365c76 commit 3e79d85
Showing 5 changed files with 31 additions and 16 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1", optional = true }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
29 changes: 22 additions & 7 deletions mistralrs-core/src/vision_models/mllama/mod.rs
@@ -49,15 +49,15 @@ fn prepare_cross_attention_mask(
let bs = cross_attention_mask.dim(0)?;
let text_total_length = cross_attention_mask.dim(1)?;
let mut cross_attn_mask = repeat_interleave(
-&cross_attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?,
+&cross_attention_mask.to_dtype(DType::F32)?,
num_vision_tokens,
3,
)?;
cross_attn_mask = cross_attn_mask.reshape((bs, text_total_length, ()))?;
cross_attn_mask = cross_attn_mask.unsqueeze(1)?;

// Invert the mask
-let inverted_cross_attn_mask = (1. - cross_attn_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
+let inverted_cross_attn_mask = (1. - cross_attn_mask)?;
const NEG_INF_VALUE: f32 = -1e15;
cross_attn_mask = masked_fill(
&inverted_cross_attn_mask,
@@ -75,7 +75,9 @@ fn prepare_cross_attention_mask(
.unsqueeze(D::Minus1)?;

cross_attn_mask = cross_attn_mask
-.broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?;
+.broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?
+.to_dtype(DType::F32)?
+.to_dtype(dtype)?;

Ok((cross_attn_mask, full_text_row_masked_out_mask))
}
@@ -85,6 +87,7 @@ pub(crate) struct MLlamaModel {
language_model: MLlamaTextModel,
multi_modal_projector: Linear,
hidden_size: usize,
+dtype: DType,
}

impl MLlamaModel {
@@ -96,10 +99,18 @@ impl MLlamaModel {
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
let real_dev = normal_loading_metadata.real_device.clone();
+// This vision model is very sensitive.
+let vision_model_dtype = if vb.dtype() == DType::F16 {
+DType::F32
+} else {
+vb.dtype()
+};
Ok(Self {
vision_model: MLlamaVisionModel::new(
&cfg.vision_config,
vb.pp("vision_model").set_device(real_dev.clone()),
vb.pp("vision_model")
.set_device(real_dev.clone())
.set_dtype(vision_model_dtype),
)?,
language_model: MLlamaTextModel::new(
&cfg.text_config,
@@ -111,9 +122,12 @@ impl MLlamaModel {
multi_modal_projector: linear(
cfg.vision_config.vision_output_dim,
cfg.text_config.hidden_size,
vb.pp("multi_modal_projector").set_device(real_dev.clone()),
vb.pp("multi_modal_projector")
.set_device(real_dev.clone())
.set_dtype(vision_model_dtype),
)?,
hidden_size: cfg.text_config.hidden_size,
+dtype: vb.dtype(),
})
}

@@ -142,7 +156,8 @@ impl MLlamaModel {
let cross_attention_states = self
.multi_modal_projector
.forward(&vision_outputs.flatten(0, 1)?)?
-.reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?;
+.reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?
+.to_dtype(self.dtype)?;
Some(cross_attention_states)
} else {
None
@@ -153,7 +168,7 @@ impl MLlamaModel {
let (cmask, fmask) = prepare_cross_attention_mask(
cross_attn_mask,
self.vision_model.num_patches,
-self.multi_modal_projector.weight().dtype(),
+self.dtype,
)?;
(Some(cmask), Some(fmask))
} else {
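A note on why the mask math above is kept in f32 until the final cast — this is an inference from the diff and the "very sensitive" comment, not something the commit states explicitly. NEG_INF_VALUE is -1e15, which is finite in f32 but overflows to -inf in f16 (the largest finite f16 magnitude is roughly 65504), and multiplying -inf by the zero rows of full_text_row_masked_out_mask would yield NaN. Performing the fill and the broadcast_mul in f32 keeps every value finite; only the finished mask and the projected vision states are converted to self.dtype. A small self-contained sketch of the failure mode, using only candle tensor ops that appear in this file (plus Tensor::new for test data):

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let fill = Tensor::new(&[-1e15f32], &dev)?; // NEG_INF_VALUE from the diff
    let zero = Tensor::new(&[0f32], &dev)?;

    // In f32 the product is a clean 0.0.
    println!("f32: {:?}", fill.mul(&zero)?.to_vec1::<f32>()?); // [0.0]

    // Cast to f16 first: -1e15 overflows to -inf, and -inf * 0.0 is NaN.
    let bad = fill.to_dtype(DType::F16)?.mul(&zero.to_dtype(DType::F16)?)?;
    println!("f16: {:?}", bad.to_dtype(DType::F32)?.to_vec1::<f32>()?); // [NaN]
    Ok(())
}
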
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo_template.toml
@@ -20,7 +20,7 @@ pyo3.workspace = true
mistralrs-core = { version = "0.3.1", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1", features=["$feature_name"] }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", features=["$feature_name"] }
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
