Skip to content

Commit

Permalink
Add --cpu flag to mistralrs-server
Browse files Browse the repository at this point in the history
  • Loading branch information
cdoko authored Dec 19, 2024
1 parent 0b4532c commit cceceac
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions mistralrs-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ struct Args {
/// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance.
#[arg(long = "prompt-batchsize")]
prompt_batchsize: Option<usize>,

/// Use CPU only
#[arg(long)]
cpu: bool,
}

#[utoipa::path(
Expand Down Expand Up @@ -314,7 +318,12 @@ async fn main() -> Result<()> {
#[cfg(feature = "metal")]
let device = Device::new_metal(0)?;
#[cfg(not(feature = "metal"))]
let device = Device::cuda_if_available(0)?;
let device = if args.cpu {
args.no_paged_attn = true;
Device::Cpu
} else {
Device::cuda_if_available(0)?
};

if let Some(seed) = args.seed {
device.set_seed(seed)?;
Expand Down Expand Up @@ -464,7 +473,8 @@ async fn main() -> Result<()> {
.with_opt_log(args.log)
.with_truncate_sequence(args.truncate_sequence)
.with_no_kv_cache(args.no_kv_cache)
.with_prefix_cache_n(args.prefix_cache_n);
.with_prefix_cache_n(args.prefix_cache_n)
.with_gemm_full_precision_f16(args.cpu);

if args.interactive_mode {
interactive_mode(builder.build(), args.throughput_log).await;
Expand Down

0 comments on commit cceceac

Please sign in to comment.