Skip to content

Commit

Permalink
Merge pull request #43 from 0xPlaygrounds/fix/context-documents
Browse files Browse the repository at this point in the history
fix: move context documents to user prompt message
  • Loading branch information
cvauclair authored Oct 1, 2024
2 parents 2f1d553 + 9a42d45 commit 78ec4cb
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 31 deletions.
115 changes: 115 additions & 0 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,28 @@ pub struct Document {
pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
concat!("<file id: {}>\n", "{}\n", "</file>\n"),
self.id,
if self.additional_props.is_empty() {
self.text.clone()
} else {
let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
sorted_props.sort_by(|a, b| a.0.cmp(b.0));
let metadata = sorted_props
.iter()
.map(|(k, v)| format!("{}: {:?}", k, v))
.collect::<Vec<_>>()
.join(" ");
format!("<metadata {} />\n{}", metadata, self.text)
}
)
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
pub name: String,
Expand Down Expand Up @@ -243,6 +265,24 @@ pub struct CompletionRequest {
pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
pub fn prompt_with_context(&self) -> String {
if !self.documents.is_empty() {
format!(
"<attachments>\n{}</attachments>\n\n{}",
self.documents
.iter()
.map(|doc| doc.to_string())
.collect::<Vec<_>>()
.join(""),
self.prompt
)
} else {
self.prompt.clone()
}
}
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
Expand Down Expand Up @@ -432,3 +472,78 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
model.completion(self.build()).await
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_document_display_without_metadata() {
let doc = Document {
id: "123".to_string(),
text: "This is a test document.".to_string(),
additional_props: HashMap::new(),
};

let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
assert_eq!(format!("{}", doc), expected);
}

#[test]
fn test_document_display_with_metadata() {
let mut additional_props = HashMap::new();
additional_props.insert("author".to_string(), "John Doe".to_string());
additional_props.insert("length".to_string(), "42".to_string());

let doc = Document {
id: "123".to_string(),
text: "This is a test document.".to_string(),
additional_props,
};

let expected = concat!(
"<file id: 123>\n",
"<metadata author: \"John Doe\" length: \"42\" />\n",
"This is a test document.\n",
"</file>\n"
);
assert_eq!(format!("{}", doc), expected);
}

#[test]
fn test_prompt_with_context_with_documents() {
let doc1 = Document {
id: "doc1".to_string(),
text: "Document 1 text.".to_string(),
additional_props: HashMap::new(),
};

let doc2 = Document {
id: "doc2".to_string(),
text: "Document 2 text.".to_string(),
additional_props: HashMap::new(),
};

let request = CompletionRequest {
prompt: "What is the capital of France?".to_string(),
preamble: None,
chat_history: Vec::new(),
documents: vec![doc1, doc2],
tools: Vec::new(),
temperature: None,
max_tokens: None,
additional_params: None,
};

let expected = concat!(
"<attachments>\n",
"<file id: doc1>\nDocument 1 text.\n</file>\n",
"<file id: doc2>\nDocument 2 text.\n</file>\n",
"</attachments>\n\n",
"What is the capital of France?"
)
.to_string();

assert_eq!(request.prompt_with_context(), expected);
}
}
8 changes: 3 additions & 5 deletions rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,17 @@ impl completion::CompletionModel for CompletionModel {
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let prompt_with_context = completion_request.prompt_with_context();

let request = json!({
"model": self.model,
"messages": completion_request
.chat_history
.into_iter()
.map(Message::from)
.chain(completion_request.documents.into_iter().map(|doc| Message {
role: "system".to_owned(),
content: serde_json::to_string(&doc).expect("Document should serialize"),
}))
.chain(iter::once(Message {
role: "user".to_owned(),
content: completion_request.prompt,
content: prompt_with_context,
}))
.collect::<Vec<_>>(),
"max_tokens": completion_request.max_tokens,
Expand Down
18 changes: 4 additions & 14 deletions rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,26 +443,16 @@ impl completion::CompletionModel for CompletionModel {
vec![]
};

// Add context documents to chat history
full_history.append(
completion_request
.documents
.into_iter()
.map(|doc| completion::Message {
role: "system".into(),
content: serde_json::to_string(&doc).expect("Document should serialize"),
})
.collect::<Vec<_>>()
.as_mut(),
);
// Extend existing chat history
full_history.append(&mut completion_request.chat_history);

// Add context documents to chat history
full_history.append(&mut completion_request.chat_history);
let prompt_with_context = completion_request.prompt_with_context();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: completion_request.prompt,
content: prompt_with_context,
});

let request = if completion_request.tools.is_empty() {
Expand Down
14 changes: 2 additions & 12 deletions rig-core/src/providers/perplexity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,15 @@ impl completion::CompletionModel for CompletionModel {
};

// Add context documents to chat history
messages.append(
completion_request
.documents
.into_iter()
.map(|doc| completion::Message {
role: "system".into(),
content: serde_json::to_string(&doc).expect("Document should serialize"),
})
.collect::<Vec<_>>()
.as_mut(),
);
let prompt_with_context = completion_request.prompt_with_context();

// Add chat history to messages
messages.extend(completion_request.chat_history);

// Add user prompt to messages
messages.push(completion::Message {
role: "user".to_string(),
content: completion_request.prompt,
content: prompt_with_context,
});

let request = json!({
Expand Down

0 comments on commit 78ec4cb

Please sign in to comment.