diff --git a/Cargo.lock b/Cargo.lock index af448df..87a778b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -980,7 +980,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.7" +version = "0.2.8" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index 126c5fc..7d468b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dkn-compute" -version = "0.2.7" +version = "0.2.8" edition = "2021" license = "Apache-2.0" readme = "README.md" diff --git a/src/config/ollama.rs b/src/config/ollama.rs index e720224..53ce2a3 100644 --- a/src/config/ollama.rs +++ b/src/config/ollama.rs @@ -1,4 +1,4 @@ -use eyre::{eyre, Result}; +use eyre::{eyre, Context, Result}; use ollama_workflows::{ ollama_rs::{ generation::{ @@ -63,9 +63,10 @@ impl OllamaConfig { // Ollama workflows may require specific models to be loaded regardless of the choices let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect(); + // auto-pull, it's true by default let auto_pull = std::env::var("OLLAMA_AUTO_PULL") .map(|s| s == "true") - .unwrap_or_default(); + .unwrap_or(true); Self { host, @@ -109,7 +110,9 @@ impl OllamaConfig { // we dont check workflows for hardcoded models for model in &self.hardcoded_models { if !local_models.contains(model) { - self.try_pull(&ollama, model.to_owned()).await?; + self.try_pull(&ollama, model.to_owned()) + .await + .wrap_err("Could not pull model")?; } } @@ -118,7 +121,9 @@ impl OllamaConfig { let mut good_models = Vec::new(); for model in external_models { if !local_models.contains(&model.to_string()) { - self.try_pull(&ollama, model.to_string()).await?; + self.try_pull(&ollama, model.to_string()) + .await + .wrap_err("Could not pull model")?; } if self diff --git a/src/handlers/pingpong.rs b/src/handlers/pingpong.rs index 28e9caf..53ab332 100644 --- a/src/handlers/pingpong.rs +++ b/src/handlers/pingpong.rs @@ -1,3 +1,4 @@ +use super::ComputeHandler; use crate::{ 
utils::{get_current_time_nanos, DKNMessage}, DriaComputeNode, @@ -8,8 +9,6 @@ use libp2p::gossipsub::MessageAcceptance; use ollama_workflows::{Model, ModelProvider}; use serde::{Deserialize, Serialize}; -use super::ComputeHandler; - pub struct PingpongHandler; #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/handlers/workflow.rs b/src/handlers/workflow.rs index 392afa2..41f2789 100644 --- a/src/handlers/workflow.rs +++ b/src/handlers/workflow.rs @@ -15,7 +15,7 @@ pub struct WorkflowHandler; #[derive(Debug, Deserialize)] struct WorkflowPayload { - /// Workflow object to be parsed. + /// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed. pub(crate) workflow: Workflow, /// A lıst of model (that can be parsed into `Model`) or model provider names. /// If model provider is given, the first matching model in the node config is used for that. @@ -69,7 +69,8 @@ impl ComputeHandler for WorkflowHandler { let (model_provider, model) = config .model_config .get_any_matching_model(task.input.model)?; - log::info!("Using model {} for task {}", model, task.task_id); + let model_name = model.to_string(); // get model name, we will pass it in payload + log::info!("Using model {} for task {}", model_name, task.task_id); // prepare workflow executor let executor = if model_provider == ModelProvider::Ollama { @@ -108,9 +109,10 @@ impl ComputeHandler for WorkflowHandler { &task.task_id, &task_public_key, &config.secret_key, + model_name, )?; - let payload_str = - serde_json::to_string(&payload).wrap_err("Could not serialize payload")?; + let payload_str = serde_json::to_string(&payload) + .wrap_err("Could not serialize response payload")?; // publish the result let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC); @@ -125,8 +127,9 @@ impl ComputeHandler for WorkflowHandler { log::error!("Task {} failed: {}", task.task_id, err_string); // prepare error payload - let error_payload = TaskErrorPayload::new(task.task_id, err_string); 
- let error_payload_str = serde_json::to_string(&error_payload)?; + let error_payload = TaskErrorPayload::new(task.task_id, err_string, model_name); + let error_payload_str = serde_json::to_string(&error_payload) + .wrap_err("Could not serialize error payload")?; // publish the error result for diagnostics let message = DKNMessage::new(error_payload_str, Self::RESPONSE_TOPIC); diff --git a/src/payloads/error.rs b/src/payloads/error.rs index 0fbaccf..1b4bfa1 100644 --- a/src/payloads/error.rs +++ b/src/payloads/error.rs @@ -8,11 +8,17 @@ pub struct TaskErrorPayload { /// The unique identifier of the task. pub task_id: String, /// The stringified error object - pub(crate) error: String, + pub error: String, + /// Name of the model that caused the error. + pub model: String, } impl TaskErrorPayload { - pub fn new(task_id: String, error: String) -> Self { - Self { task_id, error } + pub fn new(task_id: String, error: String, model: String) -> Self { + Self { + task_id, + error, + model, + } } } diff --git a/src/payloads/response.rs b/src/payloads/response.rs index 362648f..907596d 100644 --- a/src/payloads/response.rs +++ b/src/payloads/response.rs @@ -17,6 +17,8 @@ pub struct TaskResponsePayload { pub signature: String, /// Result encrypted with the public key of the task, Hexadecimally encoded. pub ciphertext: String, + /// Name of the model used for this task. 
+ pub model: String, } impl TaskResponsePayload { @@ -29,6 +31,7 @@ impl TaskResponsePayload { task_id: &str, encrypting_public_key: &PublicKey, signing_secret_key: &SecretKey, + model: String, ) -> Result { // create the message `task_id || payload` let mut preimage = Vec::new(); @@ -43,6 +46,7 @@ impl TaskResponsePayload { task_id, signature, ciphertext, + model, }) } } @@ -58,6 +62,7 @@ mod tests { fn test_task_response_payload() { // this is the result that we are "sending" const RESULT: &[u8; 44] = b"hey im an LLM and I came up with this output"; + const MODEL: &str = "gpt-4-turbo"; // the signer will sign the payload, and it will be verified let signer_sk = SecretKey::random(&mut thread_rng()); @@ -69,8 +74,9 @@ mod tests { let task_id = uuid::Uuid::new_v4().to_string(); // creates a signed and encrypted payload - let payload = TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk) - .expect("Should create payload"); + let payload = + TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk, MODEL.to_string()) + .expect("Should create payload"); // decrypt result and compare it to plaintext let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap();