diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 1f3ca4332..971bbbb32 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -149,6 +149,10 @@ struct Args { /// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance. #[arg(long = "prompt-batchsize")] prompt_batchsize: Option, + + /// Use CPU only + #[arg(long)] + cpu: bool, } #[utoipa::path( @@ -314,7 +318,12 @@ async fn main() -> Result<()> { #[cfg(feature = "metal")] let device = Device::new_metal(0)?; #[cfg(not(feature = "metal"))] - let device = Device::cuda_if_available(0)?; + let device = if args.cpu { + args.no_paged_attn = true; + Device::Cpu + } else { + Device::cuda_if_available(0)? + }; if let Some(seed) = args.seed { device.set_seed(seed)?; @@ -464,7 +473,8 @@ async fn main() -> Result<()> { .with_opt_log(args.log) .with_truncate_sequence(args.truncate_sequence) .with_no_kv_cache(args.no_kv_cache) - .with_prefix_cache_n(args.prefix_cache_n); + .with_prefix_cache_n(args.prefix_cache_n) + .with_gemm_full_precision_f16(args.cpu); if args.interactive_mode { interactive_mode(builder.build(), args.throughput_log).await;