diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f1de726e..92441208 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -491,32 +491,39 @@ fn num_cuda_devices() -> Option<usize> {
     let n_devices = devices.split(',').count();
     Some(n_devices)
 }
+
 /// Finds a max sequence length for the model. In priority order:
 /// 1. MAX_SEQUENCE_LENGTH set by user
 /// 2. The sequence length specified in config.json
-/// 3. A default of 2048
+/// 3. A default of 2048
+/// ### Arguments
+/// * `max_sequence_length` - Optional user-defined maximum sequence length.
+/// * `config_path` - Path to the model configuration file.
+/// ### Returns
+/// The effective maximum sequence length to be used.
 fn get_max_sequence_length(max_sequence_length: Option<usize>, config_path: &Path) -> usize {
-    if let Some(max_sequence_length) = max_sequence_length {
-        info!(
-            "Using configured max_sequence_length: {}",
-            max_sequence_length
-        );
-        return max_sequence_length;
-    }
+    let mut length: Option<usize> = max_sequence_length;
+    let mut source = "user-defined";
+
     if let Ok(model_config) = get_config_json(config_path) {
-        if let Some(length) = get_max_sequence_length_from_config(&model_config) {
-            info!(
-                "Loaded max_sequence_length from model config.json: {}",
-                length
-            );
-            return length;
+        if let Some(model_length) = get_max_sequence_length_from_config(&model_config) {
+            if length.is_some_and(|length| length > model_length) {
+                warn!("User-defined max_sequence_length ({}) is greater than the model's max_sequence_length ({})",
+                    length.unwrap(), model_length
+                );
+            }
+            length.get_or_insert_with(|| {
+                source = "model";
+                model_length
+            });
         }
     }
-    info!(
-        "Using default max_sequence_length: {}",
+    let result = length.unwrap_or_else(|| {
+        source = "default";
         DEFAULT_MAX_SEQUENCE_LENGTH
-    );
-    DEFAULT_MAX_SEQUENCE_LENGTH
+    });
+    info!("Using {} max_sequence_length: {}", source, result);
+    return result;
 }
 
 /// Opens the model's config.json file and reads into serde_json value