diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 7c4d27e4..4383e561 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -120,6 +120,28 @@ pub struct Document { pub additional_props: HashMap, } +impl std::fmt::Display for Document { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + concat!("\n", "{}\n", "\n"), + self.id, + if self.additional_props.is_empty() { + self.text.clone() + } else { + let mut sorted_props = self.additional_props.iter().collect::>(); + sorted_props.sort_by(|a, b| a.0.cmp(b.0)); + let metadata = sorted_props + .iter() + .map(|(k, v)| format!("{}: {:?}", k, v)) + .collect::>() + .join(" "); + format!("\n{}", metadata, self.text) + } + ) + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ToolDefinition { pub name: String, @@ -243,6 +265,24 @@ pub struct CompletionRequest { pub additional_params: Option, } +impl CompletionRequest { + pub fn prompt_with_context(&self) -> String { + if !self.documents.is_empty() { + format!( + "\n{}\n\n{}", + self.documents + .iter() + .map(|doc| doc.to_string()) + .collect::>() + .join(""), + self.prompt + ) + } else { + self.prompt.clone() + } + } +} + /// Builder struct for constructing a completion request. /// /// Example usage: @@ -432,3 +472,78 @@ impl CompletionRequestBuilder { model.completion(self.build()).await } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_document_display_without_metadata() { + let doc = Document { + id: "123".to_string(), + text: "This is a test document.".to_string(), + additional_props: HashMap::new(), + }; + + let expected = "\nThis is a test document.\n\n"; + assert_eq!(format!("{}", doc), expected); + } + + #[test] + fn test_document_display_with_metadata() { + let mut additional_props = HashMap::new(); + additional_props.insert("author".to_string(), "John Doe".to_string()); + additional_props.insert("length".to_string(), "42".to_string()); + + let doc = Document { + id: "123".to_string(), + text: "This is a test document.".to_string(), + additional_props, + }; + + let expected = concat!( + "\n", + "\n", + "This is a test document.\n", + "\n" + ); + assert_eq!(format!("{}", doc), expected); + } + + #[test] + fn test_prompt_with_context_with_documents() { + let doc1 = Document { + id: "doc1".to_string(), + text: "Document 1 text.".to_string(), + additional_props: HashMap::new(), + }; + + let doc2 = Document { + id: "doc2".to_string(), + text: "Document 2 text.".to_string(), + additional_props: HashMap::new(), + }; + + let request = CompletionRequest { + prompt: "What is the capital of France?".to_string(), + preamble: None, + chat_history: Vec::new(), + documents: vec![doc1, doc2], + tools: Vec::new(), + temperature: None, + max_tokens: None, + additional_params: None, + }; + + let expected = concat!( + "\n", + "\nDocument 1 text.\n\n", + "\nDocument 2 text.\n\n", + "\n\n", + "What is the capital of France?" + ) + .to_string(); + + assert_eq!(request.prompt_with_context(), expected); + } +} diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index ce4ce967..fab42544 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -157,19 +157,17 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let prompt_with_context = completion_request.prompt_with_context(); + let request = json!({ "model": self.model, "messages": completion_request .chat_history .into_iter() .map(Message::from) - .chain(completion_request.documents.into_iter().map(|doc| Message { - role: "system".to_owned(), - content: serde_json::to_string(&doc).expect("Document should serialize"), - })) .chain(iter::once(Message { role: "user".to_owned(), - content: completion_request.prompt, + content: prompt_with_context, })) .collect::>(), "max_tokens": completion_request.max_tokens, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index a7312a9c..6a5119a5 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -443,26 +443,16 @@ impl completion::CompletionModel for CompletionModel { vec![] }; - // Add context documents to chat history - full_history.append( - completion_request - .documents - .into_iter() - .map(|doc| completion::Message { - role: "system".into(), - content: serde_json::to_string(&doc).expect("Document should serialize"), - }) - .collect::>() - .as_mut(), - ); + // Extend existing chat history + full_history.append(&mut completion_request.chat_history); // Add context documents to chat history - full_history.append(&mut completion_request.chat_history); + let prompt_with_context = completion_request.prompt_with_context(); // Add context documents to chat history full_history.push(completion::Message { role: "user".into(), - content: completion_request.prompt, + content: prompt_with_context, }); let request = if completion_request.tools.is_empty() { diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 7f9509bd..e072c9b6 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -203,17 +203,7 @@ impl completion::CompletionModel for CompletionModel { }; // Add context documents to chat history - messages.append( - completion_request - .documents - .into_iter() - .map(|doc| completion::Message { - role: "system".into(), - content: serde_json::to_string(&doc).expect("Document should serialize"), - }) - .collect::>() - .as_mut(), - ); + let prompt_with_context = completion_request.prompt_with_context(); // Add chat history to messages messages.extend(completion_request.chat_history); @@ -221,7 +211,7 @@ impl completion::CompletionModel for CompletionModel { // Add user prompt to messages messages.push(completion::Message { role: "user".to_string(), - content: completion_request.prompt, + content: prompt_with_context, }); let request = json!({