Improve Llama 3.2 Vision accuracy (#877)
* Rename kv cache

* Almost works, found cross attn mask bug

* Still working

* It works now

* Clippy
EricLBuehler authored Oct 23, 2024
1 parent 32e8945 commit ff76a5a
Showing 25 changed files with 421 additions and 202 deletions.
18 changes: 12 additions & 6 deletions examples/python/cookbook.ipynb
@@ -36,7 +36,9 @@
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
-" messages=[{\"role\":\"user\", \"content\":\"Tell me a story about the Rust type system.\"}],\n",
+" messages=[\n",
+" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
+" ],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
@@ -106,7 +108,9 @@
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
-" messages=[{\"role\":\"user\", \"content\":\"Tell me a story about the Rust type system.\"}],\n",
+" messages=[\n",
+" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
+" ],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
@@ -205,14 +209,14 @@
"\n",
"runner = Runner(\n",
" which=Which.XLoraGGUF(\n",
-" tok_model_id=None, # Automatically determine from ordering file\n",
+" tok_model_id=None,  # Automatically determine from ordering file\n",
" quantized_model_id=\"TheBloke/zephyr-7B-beta-GGUF\",\n",
" quantized_filename=\"zephyr-7b-beta.Q4_0.gguf\",\n",
" xlora_model_id=\"lamm-mit/x-lora\",\n",
" order=\"orderings/xlora-paper-ordering.json\",\n",
" tgt_non_granular_index=None,\n",
" )\n",
-")\n"
+")"
]
},
{
@@ -233,15 +237,17 @@
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
-" messages=[{\"role\":\"user\", \"content\":\"Tell me a story about the Rust type system.\"}],\n",
+" messages=[\n",
+" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
+" ],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
" temperature=0.1,\n",
" )\n",
")\n",
"print(res.choices[0].message.content)\n",
-"print(res.usage)\n"
+"print(res.usage)"
]
}
],
1 change: 1 addition & 0 deletions examples/python/simple_tool_calling.ipynb
@@ -51,6 +51,7 @@
" )\n",
"]\n",
"\n",
+"\n",
"def add_2_numbers(x, y):\n",
" return x + y\n",
"\n",
1 change: 1 addition & 0 deletions examples/python/tool_calling.ipynb
@@ -58,6 +58,7 @@
" res = None\n",
" return res\n",
"\n",
+"\n",
"def run_python(code: str) -> str:\n",
" lcls = dict()\n",
" # No opening of files\n",
7 changes: 5 additions & 2 deletions examples/server/phi3_duckduckgo_mistral.rs.ipynb
@@ -111,7 +111,10 @@
" [\n",
" (\"system\", system_message_content),\n",
" (\"human\", human_message_content),\n",
-" (\"human\", \"Use `duckduckgo_results_json` to gather information before answering the question.\"),\n",
+" (\n",
+" \"human\",\n",
+" \"Use `duckduckgo_results_json` to gather information before answering the question.\",\n",
+" ),\n",
" MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
" ]\n",
")\n",
@@ -123,7 +126,7 @@
" agent=agent, tools=tools, verbose=True, handle_parsing_errors=True\n",
")\n",
"\n",
-"ch = {\"input\": RunnablePassthrough()} | agent_executor | (lambda x: x[\"output\"])\n"
+"ch = {\"input\": RunnablePassthrough()} | agent_executor | (lambda x: x[\"output\"])"
]
},
{
4 changes: 2 additions & 2 deletions mistralrs-core/src/diffusion_models/processor.rs
@@ -8,7 +8,7 @@ use tokenizers::Tokenizer;
use crate::{
pipeline::{
text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
-InputsProcessorType, MessagesAction, Processor,
+InputsProcessorType, MessagesAction, ProcessingConfig, Processor,
},
sequence::Sequence,
MessageContent, Pipeline,
@@ -23,7 +23,7 @@ impl Processor for DiffusionProcessor {
&self,
_pipeline: &dyn Pipeline,
_messages: Vec<IndexMap<String, MessageContent>>,
-_add_generation_prompt: bool,
+_cfg: ProcessingConfig,
_tools: Vec<crate::Tool>,
) -> Result<(Vec<u32>, String)> {
anyhow::bail!(
2 changes: 1 addition & 1 deletion mistralrs-core/src/engine/mod.rs
@@ -562,7 +562,7 @@ impl Engine {
let template = pipeline.get_processor().process(
pipeline,
messages,
-true,
+pipeline.get_processing_cfg(),
request.tools.unwrap_or_default(),
);
handle_seq_error!(template, request.response)
12 changes: 11 additions & 1 deletion mistralrs-core/src/layers.rs
@@ -721,7 +721,17 @@ impl Module for FusedBiasLinear {
)?
.t()
} else {
-x.matmul(&w.t()?)? + b
+// x.matmul(&w.t()?)? + b
+let dtype = x.dtype();
+let mut out = b.contiguous()?.to_dtype(DType::F32)?;
+x.to_dtype(DType::F32)?
+.contiguous()?
+.matmul_with_alpha_beta(
+&w.t()?.to_dtype(DType::F32)?.contiguous()?,
+&mut out,
+None,
+)?;
+Ok(out.to_dtype(dtype)?)
}
}
}
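
What the new else-branch is doing, in brief: the old path computed x.matmul(&w.t()?)? + b in the tensor's native dtype (often bf16/f16 for these models), so both the product and the bias add rounded at low precision; the new path materializes the bias as an f32 output buffer and lets `matmul_with_alpha_beta` accumulate the product into it. A minimal sketch of the intended numerics in stock candle ops, assuming (from the call site above, not from any documented API) that `a.matmul_with_alpha_beta(&b, &mut c, alpha)` behaves like a BLAS GEMM with beta = 1, i.e. c = alpha * (a @ b) + c:

    // Sketch only: reference semantics for the fused call, using plain candle ops.
    use candle_core::{DType, Result, Tensor};

    fn fused_bias_linear_ref(x: &Tensor, w: &Tensor, b: &Tensor) -> Result<Tensor> {
        let dtype = x.dtype();
        // Upcast so neither the product nor the bias add rounds through bf16/f16.
        let x32 = x.to_dtype(DType::F32)?.contiguous()?;
        let wt32 = w.t()?.to_dtype(DType::F32)?.contiguous()?;
        let b32 = b.to_dtype(DType::F32)?;
        // out = x @ w^T + b, accumulated in f32, then cast back to the input dtype.
        x32.matmul(&wt32)?.broadcast_add(&b32)?.to_dtype(dtype)
    }

The fused in-place call also avoids allocating a separate product tensor, but the accuracy gain comes from the f32 accumulation rather than the fusion itself.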
6 changes: 5 additions & 1 deletion mistralrs-core/src/pipeline/amoe.rs
@@ -20,6 +20,7 @@ use tracing::{info, warn};
use crate::{
amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
get_mut_arcmutex,
+pipeline::processing::ProcessingConfig,
prefix_cacher::PrefixCacheManager,
sampler::Sampler,
sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
@@ -215,6 +216,9 @@
fn get_processor(&self) -> Arc<dyn super::Processor> {
get_mut_arcmutex!(self.target).get_processor()
}
+fn get_processing_cfg(&self) -> ProcessingConfig {
+get_mut_arcmutex!(self.target).get_processing_cfg()
+}
}

impl MetadataMixin for AnyMoePipeline {
@@ -398,7 +402,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
("role".to_string(), Either::Left("user".to_string())),
("content".to_string(), Either::Left(prompt.clone())),
])],
-true,
+ProcessingConfig::default(),
Vec::new(),
)
.map_err(candle_core::Error::msg)?;
37 changes: 31 additions & 6 deletions mistralrs-core/src/pipeline/chat_template.rs
@@ -11,6 +11,8 @@ use tracing::info;

use crate::{MessageContent, Tool};

+use super::processing::ProcessingConfig;
+
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
"<|im_end|>", // Handle ChatML case
"<end_of_turn>", // Handle Gemma2 chat case
@@ -216,7 +218,7 @@ fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {

pub fn apply_chat_template_to(
messages: Vec<IndexMap<String, MessageContent>>,
-add_generation_prompt: bool,
+cfg: ProcessingConfig,
template: &ChatTemplateValue,
bos_tok: Option<String>,
eos_tok: Option<String>,
@@ -268,27 +270,50 @@
env.add_filter("tojson", tojson);
let tmpl = env.get_template("chat_template").unwrap();

-let date = chrono::Utc::now();
-let date_string = date.format("%d, %B, %Y").to_string();
+let date_string = if cfg.add_date_string {
+let date = chrono::Utc::now();
+Some(date.format("%d, %B, %Y").to_string())
+} else {
+None
+};

if tools.is_empty() {
+if let Some(date_string) = date_string {
+Ok(tmpl.render(context! {
+messages => new_messages,
+add_generation_prompt => cfg.add_generation_prompt,
+bos_token => bos_tok,
+eos_token => eos_tok,
+unk_token => unk_tok,
+date_string => date_string,
+})?)
+} else {
+Ok(tmpl.render(context! {
+messages => new_messages,
+add_generation_prompt => cfg.add_generation_prompt,
+bos_token => bos_tok,
+eos_token => eos_tok,
+unk_token => unk_tok,
+})?)
+}
+} else if let Some(date_string) = date_string {
Ok(tmpl.render(context! {
messages => new_messages,
-add_generation_prompt => add_generation_prompt,
+add_generation_prompt => cfg.add_generation_prompt,
bos_token => bos_tok,
eos_token => eos_tok,
unk_token => unk_tok,
+tools => tools,
date_string => date_string,
})?)
} else {
Ok(tmpl.render(context! {
messages => new_messages,
-add_generation_prompt => add_generation_prompt,
+add_generation_prompt => cfg.add_generation_prompt,
bos_token => bos_tok,
eos_token => eos_tok,
unk_token => unk_tok,
tools => tools,
-date_string => date_string,
})?)
}
}
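
The effect of the new branches: when `cfg.add_date_string` is set, the template is rendered with a `date_string` variable (formatted like "23, October, 2024"), which Llama 3.x-style chat templates splice into their system preamble; when it is unset, the key is omitted entirely, so templates that guard on `date_string` fall back to their own default. A standalone sketch of the rendering side; the template text below is an illustrative stand-in, not any model's real chat template:

    // Sketch: rendering a chat template with the optional `date_string` variable.
    use minijinja::{context, Environment};

    fn main() -> Result<(), minijinja::Error> {
        let mut env = Environment::new();
        env.add_template(
            "chat_template",
            "{% if date_string %}Today: {{ date_string }}\n{% endif %}\
             {% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}\
             {% if add_generation_prompt %}assistant:{% endif %}",
        )?;
        let rendered = env.get_template("chat_template")?.render(context! {
            messages => vec![context! { role => "user", content => "What day is it?" }],
            add_generation_prompt => true,
            // When add_date_string is false, this key is simply not passed.
            date_string => "23, October, 2024",
        })?;
        println!("{rendered}");
        Ok(())
    }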
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/diffusion.rs
@@ -3,7 +3,7 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType,
DiffusionModel, DiffusionModelLoader, FluxLoader, ForwardInputsResult, GeneralMetadata,
IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
-PreProcessingMixin, Processor, TokenSource,
+PreProcessingMixin, ProcessingConfig, Processor, TokenSource,
};
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
@@ -251,6 +251,9 @@
fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
None
}
+fn get_processing_cfg(&self) -> ProcessingConfig {
+ProcessingConfig::default()
+}
}

impl IsqPipelineMixin for DiffusionPipeline {
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/ggml.rs
@@ -6,7 +6,7 @@ use super::{
};
use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, ForwardInputsResult,
-IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
+IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin, ProcessingConfig,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
@@ -443,6 +443,9 @@
fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
None
}
+fn get_processing_cfg(&self) -> ProcessingConfig {
+ProcessingConfig::default()
+}
}

impl IsqPipelineMixin for GGMLPipeline {
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/gguf.rs
@@ -6,7 +6,7 @@ use super::{
};
use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, ForwardInputsResult,
-IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
+IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin, ProcessingConfig,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
@@ -557,6 +557,9 @@
fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
None
}
+fn get_processing_cfg(&self) -> ProcessingConfig {
+ProcessingConfig::default()
+}
}

impl IsqPipelineMixin for GGUFPipeline {
8 changes: 7 additions & 1 deletion mistralrs-core/src/pipeline/loaders/normal_loaders.rs
@@ -13,7 +13,7 @@ use crate::{
pipeline::{
isq::IsqModelLoader,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
-Cache, IsqModel,
+Cache, IsqModel, ProcessingConfig,
},
serde_default_fn,
utils::log::once_log_info,
@@ -111,6 +111,9 @@
fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
/// Get total num_hidden_layers for the layers which will be device mapped.
fn get_total_device_mapping_num_layers(&self, config: &str) -> Result<usize>;
+fn processing_cfg(&self) -> ProcessingConfig {
+ProcessingConfig::default()
+}
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
@@ -638,6 +641,9 @@
fn get_total_device_mapping_num_layers(&self, config: &str) -> Result<usize> {
Ok(LlamaBasicConfig::deserialize(config, false)?.num_hidden_layers)
}
+fn processing_cfg(&self) -> ProcessingConfig {
+ProcessingConfig::default().with_add_date_string(true)
+}
}

impl IsqModelLoader for LlamaLoader {
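
`ProcessingConfig` itself is defined in `pipeline/processing.rs`, which is not among the files shown on this page. Its shape can be inferred from the call sites, though: it is `Default`-constructible, returned from pipelines by value, offers a `with_add_date_string` builder, and carries the two flags consumed by `apply_chat_template_to`. A hypothetical reconstruction, for orientation only:

    // Hypothetical reconstruction inferred from usage in this diff; the real
    // definition lives in mistralrs-core/src/pipeline/processing.rs.
    #[derive(Clone, Copy, Debug)]
    pub struct ProcessingConfig {
        pub add_generation_prompt: bool,
        pub add_date_string: bool,
    }

    impl Default for ProcessingConfig {
        fn default() -> Self {
            Self {
                // Preserves the `true` that call sites previously hardcoded.
                add_generation_prompt: true,
                add_date_string: false,
            }
        }
    }

    impl ProcessingConfig {
        pub fn with_add_date_string(mut self, add_date_string: bool) -> Self {
            self.add_date_string = add_date_string;
            self
        }
    }

With this, `LlamaLoader::processing_cfg` opting into the date string is what routes the current date into Llama 3.2's prompt, one of the accuracy-relevant changes visible on this page; the commit message also mentions a cross-attention mask fix, which lives in files the page did not load.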
8 changes: 5 additions & 3 deletions mistralrs-core/src/pipeline/mod.rs
@@ -41,7 +41,8 @@ use mistralrs_quant::IsqType;
pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths, XLoraPaths};
pub(crate) use processing::{
-apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
+apply_chat_template, BasicProcessor, MessagesAction, ProcessingConfig, Processor,
+ProcessorCreator,
};
use rand_isaac::Isaac64Rng;
pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
@@ -100,6 +101,7 @@ pub trait PreProcessingMixin: MetadataMixin {
fn get_processor(&self) -> Arc<dyn Processor> {
Arc::new(BasicProcessor)
}
+fn get_processing_cfg(&self) -> ProcessingConfig;
/// Only None if it doesn't make sense for the model
fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
@@ -564,7 +566,7 @@
expected_outputs: &[&str],
inputs: Vec<IndexMap<String, MessageContent>>,
) {
-use crate::pipeline::chat_template::ChatTemplateValue;
+use crate::pipeline::{chat_template::ChatTemplateValue, processing::ProcessingConfig};

use super::chat_template::apply_chat_template_to;
let mut failed = Vec::new();
@@ -578,7 +580,7 @@
} else {
inputs.clone()
},
-true,
+ProcessingConfig::default(),
&ChatTemplateValue(Either::Left(template.to_string())),
Some(bos.to_string()),
Some(eos.to_string()),
6 changes: 6 additions & 0 deletions mistralrs-core/src/pipeline/normal.rs
@@ -7,6 +7,7 @@ use super::{
use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, ForwardInputsResult,
IsqOrganization, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
+ProcessingConfig,
};
use super::{
AutoLoader, Gemma2Loader, GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader,
@@ -64,6 +65,7 @@ pub struct NormalPipeline {
template_filename: Option<PathBuf>,
generation_config: Option<PathBuf>,
config: String,
+processing_cfg: ProcessingConfig,
}

/// A loader for a "normal" (non-quantized) model.
@@ -462,6 +464,7 @@
template_filename: paths.get_template_filename().clone(),
generation_config: paths.get_gen_conf_filename().cloned(),
config,
+processing_cfg: self.inner.processing_cfg(),
})))
}

@@ -484,6 +487,9 @@
fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
None
}
+fn get_processing_cfg(&self) -> ProcessingConfig {
+self.processing_cfg
+}
}

impl IsqPipelineMixin for NormalPipeline {