From 25174b495643e7a2f14ae7c2d514bdbcad310afd Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:12:35 -0400 Subject: [PATCH] Patch UQFF metal generation (#857) * Patch metal bug * Improve ISQ and loading speed on metal --- mistralrs-core/src/pipeline/isq.rs | 435 ++++++------------- mistralrs-core/src/utils/progress.rs | 83 ---- mistralrs-core/src/utils/varbuilder_utils.rs | 28 +- 3 files changed, 137 insertions(+), 409 deletions(-) diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs index 5bcdc5b9f2..129e7fd607 100644 --- a/mistralrs-core/src/pipeline/isq.rs +++ b/mistralrs-core/src/pipeline/isq.rs @@ -240,224 +240,39 @@ pub trait IsqModel { } let t_start = Instant::now(); - #[cfg(not(feature = "metal"))] - { - use rayon::iter::IntoParallelRefIterator; - - let current_rayon_threads = rayon::current_num_threads(); - // Get the MINIMUM of the max isq threads the quant method allows - let minimum_max_threads = tensors - .iter() - .map(|(q, _)| { - if let Some(dtype) = dtype { - q.get_max_isq_cpu_threads(dtype) - .map(usize::from) - .unwrap_or(current_rayon_threads) - } else { - current_rayon_threads - } - }) - .min() - .unwrap_or(current_rayon_threads); - - info!("Applying ISQ on {minimum_max_threads} threads."); - let pool = rayon::ThreadPoolBuilder::new() - .num_threads(minimum_max_threads) - .build() - .map_err(candle_core::Error::msg)?; - - pool.install(|| { - use indicatif::ParallelProgressIterator; - use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, - }; - if silent { - tensors.par_iter_mut().zip(devices_and_dtypes).for_each( - |((tensor, _), (device, dtype))| { - **tensor = tensor - .clone() - .apply_isq(dtype, device.clone(), &n_quantized) - .unwrap(); - device.synchronize().unwrap(); - }, - ); + use rayon::iter::IntoParallelRefIterator; + + let current_rayon_threads = rayon::current_num_threads(); + // Get the MINIMUM of the max isq threads the quant method allows + let minimum_max_threads = tensors + .iter() + .map(|(q, _)| { + if let Some(dtype) = dtype { + q.get_max_isq_cpu_threads(dtype) + .map(usize::from) + .unwrap_or(current_rayon_threads) } else { - tensors - .par_iter_mut() - .zip(devices_and_dtypes) - .progress_with(bar) - .for_each(|((tensor, _), (device, dtype))| { - **tensor = tensor - .clone() - .apply_isq(dtype, device.clone(), &n_quantized) - .unwrap(); - device.synchronize().unwrap(); - }); - } - }); - - if let Some(serialized) = write_artifacts { - info!( - "Serializing {total_tensors} ISQ tensors to `{}`.", - serialized.display() - ); - - if !serialized.extension().is_some_and(|ext| ext == "uqff") { - candle_core::bail!("UQFF output path extension must be `.uqff`",); + current_rayon_threads } + }) + .min() + .unwrap_or(current_rayon_threads); - let bar = ProgressBar::new(total_tensors as u64); - bar.set_style( - ProgressStyle::default_bar() - .template( - "[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})", - ) - .unwrap() - .progress_chars("#>-"), - ); - - let pool = rayon::ThreadPoolBuilder::new() - .num_threads(2) - .build() - .map_err(candle_core::Error::msg)?; - - let quantized_values = pool.install(|| { - if silent { - tensors - .par_iter() - .enumerate() - .filter(|(_, (layer, _))| layer.isq_serde_supported()) - .map(|(i, (layer, _))| { - Ok(( - i.to_string(), - Tensor::new( - Cow::into_owned(layer.serialize()?), - &Device::Cpu, - )?, - )) - }) - .collect::>>() - } else { - tensors - .par_iter() - .enumerate() - .progress_with(bar) - .filter(|(_, (layer, _))| layer.isq_serde_supported()) - .map(|(i, (layer, _))| { - Ok(( - i.to_string(), - Tensor::new( - Cow::into_owned(layer.serialize()?), - &Device::Cpu, - )?, - )) - }) - .collect::>>() - } - }); - - let parent = serialized - .parent() - .context("Target UQFF path must have a filename!")?; - - std::fs::create_dir_all(parent)?; - - safetensors::serialize_to_file(quantized_values?, &None, serialized)?; - - let residual = match organization { - IsqOrganization::Default => self.residual_tensors(), - IsqOrganization::MoeExpertsOnly => self - .residual_tensors_moe_experts_only() - .unwrap_or(self.residual_tensors()), - }; - - let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS); - let config_out = parent.join("config.json"); - let tokenizer_out = parent.join("tokenizer.json"); - let tokenizer_cfg_out = parent.join("tokenizer_config.json"); - let gen_cfg_out = parent.join("generation_config.json"); - let processor_out = parent.join("processor_config.json"); - let preprocessor_out = parent.join("preprocessor_config.json"); - - info!( - "Serializing {} residual tensors to `{}`.", - residual.len(), - residual_out.display() - ); - - safetensors::serialize_to_file(residual, &None, &residual_out)?; - - let UqffFullSer { - tokenizer, - template_filename, - generation_config, - config, - processor_filename, - preprocessor_filename, - } = full_ser; - - info!("Serializing configuration to `{}`.", config_out.display()); - - std::fs::write(config_out, config)?; - - info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); - - serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) - .map_err(candle_core::Error::msg)?; - - if let Some(template_filename) = template_filename { - info!( - "Serializing tokenizer config to `{}`.", - tokenizer_cfg_out.display() - ); - - let template = - std::fs::read(template_filename).map_err(candle_core::Error::msg)?; - std::fs::write(&tokenizer_cfg_out, template) - .map_err(candle_core::Error::msg)?; - } - - if let Some(generation_config) = generation_config { - info!( - "Serializing generation config to `{}`.", - gen_cfg_out.display() - ); - - let cfg = - std::fs::read(generation_config).map_err(candle_core::Error::msg)?; - std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; - } - - if let Some(processor_config) = processor_filename { - info!( - "Serializing processor config to `{}`.", - processor_out.display() - ); - - let cfg = - std::fs::read(processor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; - } + info!("Applying ISQ on {minimum_max_threads} threads."); - if let Some(preprocessor_config) = preprocessor_filename { - info!( - "Serializing preprocessor config to `{}`.", - preprocessor_out.display() - ); + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(minimum_max_threads) + .build() + .map_err(candle_core::Error::msg)?; - let cfg = - std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; - } - } - } - - #[cfg(feature = "metal")] - { - use indicatif::ProgressIterator; + pool.install(|| { + use indicatif::ParallelProgressIterator; + use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + }; if silent { - tensors.iter_mut().zip(devices_and_dtypes).for_each( + tensors.par_iter_mut().zip(devices_and_dtypes).for_each( |((tensor, _), (device, dtype))| { **tensor = tensor .clone() @@ -468,7 +283,7 @@ pub trait IsqModel { ); } else { tensors - .iter_mut() + .par_iter_mut() .zip(devices_and_dtypes) .progress_with(bar) .for_each(|((tensor, _), (device, dtype))| { @@ -479,33 +294,35 @@ pub trait IsqModel { device.synchronize().unwrap(); }); } + }); - if let Some(serialized) = write_artifacts { - info!( - "Serializing {total_tensors} ISQ tensors to `{}`.", - serialized.display() - ); + if let Some(serialized) = write_artifacts { + info!( + "Serializing {total_tensors} ISQ tensors to `{}`.", + serialized.display() + ); - if !serialized.extension().is_some_and(|ext| ext == "uqff") { - candle_core::bail!( - "UQFF output path extension must be {:?}", - serialized.extension().as_ref().unwrap() - ); - } + if !serialized.extension().is_some_and(|ext| ext == "uqff") { + candle_core::bail!("UQFF output path extension must be `.uqff`",); + } - let bar = ProgressBar::new(total_tensors as u64); - bar.set_style( - ProgressStyle::default_bar() - .template( - "[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})", - ) - .unwrap() - .progress_chars("#>-"), - ); + let bar = ProgressBar::new(total_tensors as u64); + bar.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})") + .unwrap() + .progress_chars("#>-"), + ); - let quantized_values = if silent { + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(2) + .build() + .map_err(candle_core::Error::msg)?; + + let quantized_values = pool.install(|| { + if silent { tensors - .iter() + .par_iter() .enumerate() .filter(|(_, (layer, _))| layer.isq_serde_supported()) .map(|(i, (layer, _))| { @@ -517,7 +334,7 @@ pub trait IsqModel { .collect::>>() } else { tensors - .iter() + .par_iter() .enumerate() .progress_with(bar) .filter(|(_, (layer, _))| layer.isq_serde_supported()) @@ -528,99 +345,99 @@ pub trait IsqModel { )) }) .collect::>>() - }; + } + }); - let parent = serialized - .parent() - .context("Target UQFF path must have a filename!")?; + let parent = serialized + .parent() + .context("Target UQFF path must have a filename!")?; - std::fs::create_dir_all(parent)?; + std::fs::create_dir_all(parent)?; - let residual = match organization { - IsqOrganization::Default => self.residual_tensors(), - IsqOrganization::MoeExpertsOnly => self - .residual_tensors_moe_experts_only() - .unwrap_or(self.residual_tensors()), - }; + safetensors::serialize_to_file(quantized_values?, &None, serialized)?; - let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS); - let config_out = parent.join("config.json"); - let tokenizer_out = parent.join("tokenizer.json"); - let tokenizer_cfg_out = parent.join("tokenizer_config.json"); - let gen_cfg_out = parent.join("generation_config.json"); - let processor_out = parent.join("processor_config.json"); - let preprocessor_out = parent.join("preprocessor_config.json"); + let residual = match organization { + IsqOrganization::Default => self.residual_tensors(), + IsqOrganization::MoeExpertsOnly => self + .residual_tensors_moe_experts_only() + .unwrap_or(self.residual_tensors()), + }; - info!( - "Serializing {} residual tensors to `{}`.", - residual.len(), - residual_out.display() - ); + let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS); + let config_out = parent.join("config.json"); + let tokenizer_out = parent.join("tokenizer.json"); + let tokenizer_cfg_out = parent.join("tokenizer_config.json"); + let gen_cfg_out = parent.join("generation_config.json"); + let processor_out = parent.join("processor_config.json"); + let preprocessor_out = parent.join("preprocessor_config.json"); - safetensors::serialize_to_file(residual, &None, &residual_out)?; + info!( + "Serializing {} residual tensors to `{}`.", + residual.len(), + residual_out.display() + ); - let UqffFullSer { - tokenizer, - template_filename, - generation_config, - config, - processor_filename, - preprocessor_filename, - } = full_ser; + safetensors::serialize_to_file(residual, &None, &residual_out)?; - info!("Serializing configuration to `{}`.", config_out.display()); + let UqffFullSer { + tokenizer, + template_filename, + generation_config, + config, + processor_filename, + preprocessor_filename, + } = full_ser; - std::fs::write(config_out, config)?; + info!("Serializing configuration to `{}`.", config_out.display()); - info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); + std::fs::write(config_out, config)?; - serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) - .map_err(candle_core::Error::msg)?; + info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); - if let Some(template_filename) = template_filename { - info!( - "Serializing tokenizer config to `{}`.", - tokenizer_cfg_out.display() - ); + serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) + .map_err(candle_core::Error::msg)?; - let template = - std::fs::read(template_filename).map_err(candle_core::Error::msg)?; - std::fs::write(&tokenizer_cfg_out, template) - .map_err(candle_core::Error::msg)?; - } + if let Some(template_filename) = template_filename { + info!( + "Serializing tokenizer config to `{}`.", + tokenizer_cfg_out.display() + ); - if let Some(generation_config) = generation_config { - info!( - "Serializing generation config to `{}`.", - gen_cfg_out.display() - ); + let template = + std::fs::read(template_filename).map_err(candle_core::Error::msg)?; + std::fs::write(&tokenizer_cfg_out, template) + .map_err(candle_core::Error::msg)?; + } - let cfg = - std::fs::read(generation_config).map_err(candle_core::Error::msg)?; - std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; - } + if let Some(generation_config) = generation_config { + info!( + "Serializing generation config to `{}`.", + gen_cfg_out.display() + ); - if let Some(processor_config) = processor_filename { - info!( - "Serializing processor config to `{}`.", - processor_out.display() - ); + let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?; + std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; + } - let cfg = - std::fs::read(processor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; - } + if let Some(processor_config) = processor_filename { + info!( + "Serializing processor config to `{}`.", + processor_out.display() + ); - if let Some(preprocessor_config) = preprocessor_filename { - info!( - "Serializing preprocessor config to `{}`.", - preprocessor_out.display() - ); + let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; + } - let cfg = - std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; - } + if let Some(preprocessor_config) = preprocessor_filename { + info!( + "Serializing preprocessor config to `{}`.", + preprocessor_out.display() + ); + + let cfg = + std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; } } let delta = Instant::now().duration_since(t_start).as_secs_f32(); diff --git a/mistralrs-core/src/utils/progress.rs b/mistralrs-core/src/utils/progress.rs index 9e23201a12..85b4256e74 100644 --- a/mistralrs-core/src/utils/progress.rs +++ b/mistralrs-core/src/utils/progress.rs @@ -1,6 +1,3 @@ -use std::thread::JoinHandle; - -use either::Either; use indicatif::{ProgressBar, ProgressBarIter, ProgressIterator, ProgressStyle}; use tqdm::Iter; @@ -23,86 +20,6 @@ pub trait IterWithProgress<'a, T>: Iterator + 'a { impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {} -/// Choose between threading or non-threading depending on if the `metal` -/// feature is enabled. -pub struct Parellelize; - -/// A handle which does not do threading. Instead, it always reports that is is -/// finished and executes the closure lazily. This is used for Metal -/// where the command buffer cannot be used concurrently. -pub struct NonThreadingHandle -where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, -{ - f: F, -} - -impl NonThreadingHandle -where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, -{ - fn join(self) -> std::thread::Result { - std::thread::Result::Ok((self.f)()) - } - fn is_finished(&self) -> bool { - true - } -} - -/// A trait representing a joinable handle. -pub trait Joinable { - fn join(self) -> std::thread::Result; - fn is_finished(&self) -> bool; -} - -impl Joinable for Either, NonThreadingHandle> -where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, -{ - fn is_finished(&self) -> bool { - match self { - Self::Left(l) => l.is_finished(), - Self::Right(r) => r.is_finished(), - } - } - fn join(self) -> std::thread::Result { - match self { - Self::Left(l) => l.join(), - Self::Right(r) => r.join(), - } - } -} - -#[cfg(not(feature = "metal"))] -impl Parellelize { - pub fn spawn(f: F) -> Either, NonThreadingHandle> - where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, - { - Either::Left(std::thread::spawn(f)) - } -} - -#[cfg(feature = "metal")] -impl Parellelize { - pub fn spawn(f: F) -> Either, NonThreadingHandle> - where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, - { - Either::Right(NonThreadingHandle { f }) - } -} - /// Nice progress bar with over an iterator and a message. /// COLOR is one of r,g,b pub struct NiceProgressBar(pub T, pub &'static str); diff --git a/mistralrs-core/src/utils/varbuilder_utils.rs b/mistralrs-core/src/utils/varbuilder_utils.rs index 7502313ccb..67511d2039 100644 --- a/mistralrs-core/src/utils/varbuilder_utils.rs +++ b/mistralrs-core/src/utils/varbuilder_utils.rs @@ -1,6 +1,11 @@ //! Utilities for creating a VarBuilder from a VarMap loaded from tensor storage formats. -use std::{collections::HashMap, path::PathBuf, sync::Arc, thread::JoinHandle}; +use std::{ + collections::HashMap, + path::PathBuf, + sync::Arc, + thread::{self, JoinHandle}, +}; use candle_core::{ pickle::PthTensors, safetensors::MmapedSafetensors, DType, Device, Result, Tensor, @@ -9,15 +14,12 @@ use candle_nn::{ var_builder::{SimpleBackend, VarBuilderArgs}, VarBuilder, }; -use either::Either; use regex::Regex; use crate::lora::LoraConfig; use crate::utils::progress::IterWithProgress; use derive_new::new; -use super::progress::{Joinable, NonThreadingHandle, Parellelize}; - trait TensorLoaderBackend { fn get_names(&self) -> Vec; fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result; @@ -89,21 +91,13 @@ pub(crate) fn from_mmaped_safetensors<'a>( predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static, ) -> Result>> { #[allow(clippy::type_complexity)] - let mut handles: Vec< - Either< - JoinHandle>>, - NonThreadingHandle< - Result>, - Box Result> + Send + 'static>, - >, - >, - > = Vec::new(); + let mut handles: Vec>>> = Vec::new(); for path in paths { let device = device.clone(); if let Some(regexes) = make_dummy_regexes.clone() { let predicate = predicate.clone(); - handles.push(Parellelize::spawn(Box::new(move || { + handles.push(thread::spawn(Box::new(move || { let loader = Common::new(); loader.load_tensors_from_path(&path, &device, dtype, silent, predicate, |key| { regexes.iter().any(|r| r.is_match(key)) @@ -111,7 +105,7 @@ pub(crate) fn from_mmaped_safetensors<'a>( }))); } else { let predicate = predicate.clone(); - handles.push(Parellelize::spawn(Box::new(move || { + handles.push(thread::spawn(Box::new(move || { let loader = Common::new(); loader.load_tensors_from_path(&path, &device, dtype, silent, predicate, |_| false) }))); @@ -121,7 +115,7 @@ pub(crate) fn from_mmaped_safetensors<'a>( let device = device.clone(); if let Some(regexes) = make_dummy_regexes.clone() { let predicate = predicate.clone(); - handles.push(Parellelize::spawn(Box::new(move || { + handles.push(thread::spawn(Box::new(move || { let loader = XLora::new(i + 1); loader.load_tensors_from_path(&path, &device, dtype, silent, predicate, |key| { regexes.iter().any(|r| r.is_match(key)) @@ -129,7 +123,7 @@ pub(crate) fn from_mmaped_safetensors<'a>( }))); } else { let predicate = predicate.clone(); - handles.push(Parellelize::spawn(Box::new(move || { + handles.push(thread::spawn(Box::new(move || { let loader = XLora::new(i + 1); loader.load_tensors_from_path(&path, &device, dtype, silent, predicate, |_| false) })));