Skip to content

Commit

Permalink
Better parsing of the image path (#906)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Nov 10, 2024
1 parent 84a66b9 commit 3fdf496
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 13 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
float8 = "0.1.1"
regex = "1.10.6"
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ mistralrs-quant = { version = "0.3.2", path = "../mistralrs-quant" }
uuid = { version = "1.10.0", features = ["v4"] }
schemars = "0.8.21"
serde_yaml = "0.9.34"
regex = "1.10.6"
regex.workspace = true
safetensors = "0.4.5"
serde_plain = "1.0.2"
as-any = "0.3.1"
Expand Down
1 change: 1 addition & 0 deletions mistralrs-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ reqwest.workspace = true
image.workspace = true
url.workspace = true
data-url.workspace = true
regex.workspace = true

[features]
cuda = ["mistralrs-core/cuda"]
Expand Down
135 changes: 123 additions & 12 deletions mistralrs-server/src/interactive_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use mistralrs_core::{
ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP,
};
use once_cell::sync::Lazy;
use regex::Regex;
use std::{
io::{self, Write},
sync::{atomic::Ordering, Arc, Mutex},
Expand Down Expand Up @@ -229,6 +230,32 @@ async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
}
}

fn parse_image_path_and_message(input: &str) -> Option<(String, String)> {
// Regex to capture the image path and the following message
let re = Regex::new(r#"\\image\s+"([^"]+)"\s*(.*)|\\image\s+(\S+)\s*(.*)"#).unwrap();

if let Some(captures) = re.captures(input) {
// Capture either the quoted or unquoted path and the message
if let Some(path) = captures.get(1) {
if let Some(message) = captures.get(2) {
return Some((
path.as_str().trim().to_string(),
message.as_str().trim().to_string(),
));
}
} else if let Some(path) = captures.get(3) {
if let Some(message) = captures.get(4) {
return Some((
path.as_str().trim().to_string(),
message.as_str().trim().to_string(),
));
}
}
}

None
}

async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
Expand Down Expand Up @@ -312,22 +339,13 @@ async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
continue;
}
prompt if prompt.trim().starts_with(IMAGE_CMD) => {
let mut parts = prompt.trim().strip_prefix(IMAGE_CMD).unwrap().split(' ');
// No space??
if !parts.next().unwrap().is_empty() {
let Some((url, message)) = parse_image_path_and_message(prompt.trim()) else {
println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
}
let url = match parts.next() {
Some(p) => p.trim(),
None => {
println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
continue;
}
continue;
};
let message = parts.collect::<Vec<_>>().join(" ");
let message = prefixer.prefix_image(images.len(), &message);

let image = util::parse_image_url(url)
let image = util::parse_image_url(&url)
.await
.expect("Failed to read image from URL/path");
images.push(image);
Expand Down Expand Up @@ -504,3 +522,96 @@ async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>) {
println!();
}
}

#[cfg(test)]
mod tests {
use super::parse_image_path_and_message;

#[test]
fn test_parse_image_with_unquoted_path_and_message() {
let input = r#"\image image.jpg What is this"#;
let result = parse_image_path_and_message(input);
assert_eq!(
result,
Some(("image.jpg".to_string(), "What is this".to_string()))
);
}

#[test]
fn test_parse_image_with_quoted_path_and_message() {
let input = r#"\image "image name.jpg" What is this?"#;
let result = parse_image_path_and_message(input);
assert_eq!(
result,
Some(("image name.jpg".to_string(), "What is this?".to_string()))
);
}

#[test]
fn test_parse_image_with_only_unquoted_path() {
let input = r#"\image image.jpg"#;
let result = parse_image_path_and_message(input);
assert_eq!(result, Some(("image.jpg".to_string(), "".to_string())));
}

#[test]
fn test_parse_image_with_only_quoted_path() {
let input = r#"\image "image name.jpg""#;
let result = parse_image_path_and_message(input);
assert_eq!(result, Some(("image name.jpg".to_string(), "".to_string())));
}

#[test]
fn test_parse_image_with_extra_spaces() {
let input = r#"\image "image with spaces.jpg" This is a test message with spaces "#;
let result = parse_image_path_and_message(input);
assert_eq!(
result,
Some((
"image with spaces.jpg".to_string(),
"This is a test message with spaces".to_string()
))
);
}

#[test]
fn test_parse_image_with_no_message() {
let input = r#"\image "image.jpg""#;
let result = parse_image_path_and_message(input);
assert_eq!(result, Some(("image.jpg".to_string(), "".to_string())));
}

#[test]
fn test_parse_image_missing_path() {
let input = r#"\image"#;
let result = parse_image_path_and_message(input);
assert_eq!(result, None);
}

#[test]
fn test_parse_image_invalid_command() {
let input = r#"\img "image.jpg" This should fail"#;
let result = parse_image_path_and_message(input);
assert_eq!(result, None);
}

#[test]
fn test_parse_image_with_non_image_text() {
let input = r#"Some random text without command"#;
let result = parse_image_path_and_message(input);
assert_eq!(result, None);
}

#[test]
fn test_parse_image_with_path_and_message_special_chars() {
let input = r#"\image "path with special chars @#$%^&*().jpg" This is a message with special chars !@#$%^&*()"#;
let result = parse_image_path_and_message(input);
assert_eq!(
result,
Some((
"path with special chars @#$%^&*().jpg".to_string(),
"This is a message with special chars !@#$%^&*()".to_string()
))
);
}
}

0 comments on commit 3fdf496

Please sign in to comment.