Skip to content

Commit

Permalink
Merge branch 'main' into improve_seq_len_messages
Browse files Browse the repository at this point in the history
  • Loading branch information
maxdebayser authored Jun 28, 2024
2 parents cdbcfa7 + 009a2ba commit 48ed1c9
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,32 +491,39 @@ fn num_cuda_devices() -> Option<usize> {
let n_devices = devices.split(',').count();
Some(n_devices)
}

/// Finds a max sequence length for the model. In priority order:
/// 1. MAX_SEQUENCE_LENGTH set by user
/// 2. The sequence length specified in config.json
/// 3. A default of 2048
/// 3. A default of 2048
/// ### Arguments
/// * `max_sequence_length` - Optional user-defined maximum sequence length.
/// * `config_path` - Path to the model configuration file.
/// ### Returns
/// The effective maximum sequence length to be used.
fn get_max_sequence_length(max_sequence_length: Option<usize>, config_path: &Path) -> usize {
if let Some(max_sequence_length) = max_sequence_length {
info!(
"Using configured max_sequence_length: {}",
max_sequence_length
);
return max_sequence_length;
}
let mut length: Option<usize> = max_sequence_length;
let mut source = "user-defined";

if let Ok(model_config) = get_config_json(config_path) {
if let Some(length) = get_max_sequence_length_from_config(&model_config) {
info!(
"Loaded max_sequence_length from model config.json: {}",
length
);
return length;
if let Some(model_length) = get_max_sequence_length_from_config(&model_config) {
if length.is_some_and(|length| length > model_length) {
warn!("User-defined max_sequence_length ({}) is greater than the model's max_sequence_length ({})",
length.unwrap(), model_length
);
}
length.get_or_insert_with(|| {
source = "model";
model_length
});
}
}
info!(
"Using default max_sequence_length: {}",
let result = length.unwrap_or_else(|| {
source = "default";
DEFAULT_MAX_SEQUENCE_LENGTH
);
DEFAULT_MAX_SEQUENCE_LENGTH
});
info!("Using {} max_sequence_length: {}", source, result);
return result;
}

/// Opens the model's config.json file and reads into serde_json value
Expand Down

0 comments on commit 48ed1c9

Please sign in to comment.