Skip to content

Commit

Permalink
add model name to response payloads
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Oct 3, 2024
1 parent 936a9b6 commit 6997c8f
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dkn-compute"
version = "0.2.7"
version = "0.2.8"
edition = "2021"
license = "Apache-2.0"
readme = "README.md"
Expand Down
13 changes: 9 additions & 4 deletions src/config/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use eyre::{eyre, Result};
use eyre::{eyre, Context, Result};
use ollama_workflows::{
ollama_rs::{
generation::{
Expand Down Expand Up @@ -63,9 +63,10 @@ impl OllamaConfig {
// Ollama workflows may require specific models to be loaded regardless of the choices
let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect();

// auto-pull, it's true by default
let auto_pull = std::env::var("OLLAMA_AUTO_PULL")
.map(|s| s == "true")
.unwrap_or_default();
.unwrap_or(true);

Self {
host,
Expand Down Expand Up @@ -109,7 +110,9 @@ impl OllamaConfig {
// we don't check workflows for hardcoded models
for model in &self.hardcoded_models {
if !local_models.contains(model) {
self.try_pull(&ollama, model.to_owned()).await?;
self.try_pull(&ollama, model.to_owned())
.await
.wrap_err("Could not pull model")?;
}
}

Expand All @@ -118,7 +121,9 @@ impl OllamaConfig {
let mut good_models = Vec::new();
for model in external_models {
if !local_models.contains(&model.to_string()) {
self.try_pull(&ollama, model.to_string()).await?;
self.try_pull(&ollama, model.to_string())
.await
.wrap_err("Could not pull model")?;
}

if self
Expand Down
3 changes: 1 addition & 2 deletions src/handlers/pingpong.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::ComputeHandler;
use crate::{
utils::{get_current_time_nanos, DKNMessage},
DriaComputeNode,
Expand All @@ -8,8 +9,6 @@ use libp2p::gossipsub::MessageAcceptance;
use ollama_workflows::{Model, ModelProvider};
use serde::{Deserialize, Serialize};

use super::ComputeHandler;

pub struct PingpongHandler;

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
15 changes: 9 additions & 6 deletions src/handlers/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct WorkflowHandler;

#[derive(Debug, Deserialize)]
struct WorkflowPayload {
/// Workflow object to be parsed.
/// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed.
pub(crate) workflow: Workflow,
/// A list of models (that can be parsed into `Model`) or model provider names.
/// If model provider is given, the first matching model in the node config is used for that.
Expand Down Expand Up @@ -69,7 +69,8 @@ impl ComputeHandler for WorkflowHandler {
let (model_provider, model) = config
.model_config
.get_any_matching_model(task.input.model)?;
log::info!("Using model {} for task {}", model, task.task_id);
let model_name = model.to_string(); // get model name, we will pass it in payload
log::info!("Using model {} for task {}", model_name, task.task_id);

// prepare workflow executor
let executor = if model_provider == ModelProvider::Ollama {
Expand Down Expand Up @@ -108,9 +109,10 @@ impl ComputeHandler for WorkflowHandler {
&task.task_id,
&task_public_key,
&config.secret_key,
model_name,
)?;
let payload_str =
serde_json::to_string(&payload).wrap_err("Could not serialize payload")?;
let payload_str = serde_json::to_string(&payload)
.wrap_err("Could not serialize response payload")?;

// publish the result
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
Expand All @@ -125,8 +127,9 @@ impl ComputeHandler for WorkflowHandler {
log::error!("Task {} failed: {}", task.task_id, err_string);

// prepare error payload
let error_payload = TaskErrorPayload::new(task.task_id, err_string);
let error_payload_str = serde_json::to_string(&error_payload)?;
let error_payload = TaskErrorPayload::new(task.task_id, err_string, model_name);
let error_payload_str = serde_json::to_string(&error_payload)
.wrap_err("Could not serialize error payload")?;

// publish the error result for diagnostics
let message = DKNMessage::new(error_payload_str, Self::RESPONSE_TOPIC);
Expand Down
12 changes: 9 additions & 3 deletions src/payloads/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ pub struct TaskErrorPayload {
/// The unique identifier of the task.
pub task_id: String,
/// The stringified error object
pub(crate) error: String,
pub error: String,
/// Name of the model that caused the error.
pub model: String,
}

impl TaskErrorPayload {
pub fn new(task_id: String, error: String) -> Self {
Self { task_id, error }
pub fn new(task_id: String, error: String, model: String) -> Self {
Self {
task_id,
error,
model,
}
}
}
10 changes: 8 additions & 2 deletions src/payloads/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct TaskResponsePayload {
pub signature: String,
/// Result encrypted with the public key of the task, Hexadecimally encoded.
pub ciphertext: String,
/// Name of the model used for this task.
pub model: String,
}

impl TaskResponsePayload {
Expand All @@ -29,6 +31,7 @@ impl TaskResponsePayload {
task_id: &str,
encrypting_public_key: &PublicKey,
signing_secret_key: &SecretKey,
model: String,
) -> Result<Self> {
// create the message `task_id || payload`
let mut preimage = Vec::new();
Expand All @@ -43,6 +46,7 @@ impl TaskResponsePayload {
task_id,
signature,
ciphertext,
model,
})
}
}
Expand All @@ -58,6 +62,7 @@ mod tests {
fn test_task_response_payload() {
// this is the result that we are "sending"
const RESULT: &[u8; 44] = b"hey im an LLM and I came up with this output";
const MODEL: &str = "gpt-4-turbo";

// the signer will sign the payload, and it will be verified
let signer_sk = SecretKey::random(&mut thread_rng());
Expand All @@ -69,8 +74,9 @@ mod tests {
let task_id = uuid::Uuid::new_v4().to_string();

// creates a signed and encrypted payload
let payload = TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk)
.expect("Should create payload");
let payload =
TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk, MODEL.to_string())
.expect("Should create payload");

// decrypt result and compare it to plaintext
let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap();
Expand Down

0 comments on commit 6997c8f

Please sign in to comment.