From 1f2f15dda0964893505a5ece89ce5f77f163092d Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 11:11:53 +0300 Subject: [PATCH 01/14] subcrates, p2p done, todo models --- Cargo.lock | 30 +- Cargo.toml | 71 +--- compute/Cargo.toml | 64 ++++ Cross.toml => compute/Cross.toml | 0 Dockerfile => compute/Dockerfile | 0 compose.yml => compute/compose.yml | 0 {src => compute/src}/config/mod.rs | 2 +- {src => compute/src}/config/models.rs | 0 {src => compute/src}/config/ollama.rs | 0 {src => compute/src}/config/openai.rs | 10 +- {src => compute/src}/handlers/mod.rs | 2 +- {src => compute/src}/handlers/pingpong.rs | 2 +- {src => compute/src}/handlers/workflow.rs | 2 +- {src => compute/src}/lib.rs | 1 - {src => compute/src}/main.rs | 0 {src => compute/src}/node.rs | 17 +- {src => compute/src}/payloads/error.rs | 0 {src => compute/src}/payloads/mod.rs | 0 {src => compute/src}/payloads/request.rs | 0 {src => compute/src}/payloads/response.rs | 0 {src => compute/src}/utils/available_nodes.rs | 2 +- {src => compute/src}/utils/crypto.rs | 6 +- {src => compute/src}/utils/filter.rs | 0 {src => compute/src}/utils/message.rs | 4 +- {src => compute/src}/utils/mod.rs | 2 +- p2p/Cargo.toml | 31 ++ {src/p2p => p2p/src}/behaviour.rs | 2 +- {src/p2p => p2p/src}/client.rs | 18 +- {src/p2p => p2p/src}/data_transform.rs | 0 src/p2p/mod.rs => p2p/src/lib.rs | 5 + {src/p2p => p2p/src}/versioning.rs | 0 workflows/Cargo.toml | 22 ++ workflows/README.md | 7 + workflows/src/lib.rs | 22 ++ workflows/src/models.rs | 251 ++++++++++++++ workflows/src/ollama.rs | 309 ++++++++++++++++++ workflows/src/openai.rs | 111 +++++++ 37 files changed, 894 insertions(+), 99 deletions(-) create mode 100644 compute/Cargo.toml rename Cross.toml => compute/Cross.toml (100%) rename Dockerfile => compute/Dockerfile (100%) rename compose.yml => compute/compose.yml (100%) rename {src => compute/src}/config/mod.rs (99%) rename {src => compute/src}/config/models.rs (100%) rename {src => compute/src}/config/ollama.rs (100%) rename {src => compute/src}/config/openai.rs (95%) rename {src => compute/src}/handlers/mod.rs (95%) rename {src => compute/src}/handlers/pingpong.rs (97%) rename {src => compute/src}/handlers/workflow.rs (99%) rename {src => compute/src}/lib.rs (95%) rename {src => compute/src}/main.rs (100%) rename {src => compute/src}/node.rs (96%) rename {src => compute/src}/payloads/error.rs (100%) rename {src => compute/src}/payloads/mod.rs (100%) rename {src => compute/src}/payloads/request.rs (100%) rename {src => compute/src}/payloads/response.rs (100%) rename {src => compute/src}/utils/available_nodes.rs (99%) rename {src => compute/src}/utils/crypto.rs (96%) rename {src => compute/src}/utils/filter.rs (100%) rename {src => compute/src}/utils/message.rs (97%) rename {src => compute/src}/utils/mod.rs (97%) create mode 100644 p2p/Cargo.toml rename {src/p2p => p2p/src}/behaviour.rs (98%) rename {src/p2p => p2p/src}/client.rs (96%) rename {src/p2p => p2p/src}/data_transform.rs (100%) rename src/p2p/mod.rs => p2p/src/lib.rs (75%) rename {src/p2p => p2p/src}/versioning.rs (100%) create mode 100644 workflows/Cargo.toml create mode 100644 workflows/README.md create mode 100644 workflows/src/lib.rs create mode 100644 workflows/src/models.rs create mode 100644 workflows/src/ollama.rs create mode 100644 workflows/src/openai.rs diff --git a/Cargo.lock b/Cargo.lock index 01ccd84..9586fc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -984,6 +984,7 @@ version = "0.2.10" dependencies = [ "async-trait", "base64 0.22.1", + "dkn-p2p", "dotenvy", 
"ecies", "env_logger 0.11.5", @@ -991,8 +992,6 @@ dependencies = [ "fastbloom-rs", "hex", "hex-literal", - "libp2p", - "libp2p-identity", "libsecp256k1", "log", "ollama-workflows", @@ -1014,6 +1013,33 @@ dependencies = [ "uuid", ] +[[package]] +name = "dkn-p2p" +version = "0.1.0" +dependencies = [ + "env_logger 0.11.5", + "eyre", + "libp2p", + "libp2p-identity", + "log", +] + +[[package]] +name = "dkn-workflows" +version = "0.1.0" +dependencies = [ + "async-trait", + "eyre", + "log", + "ollama-workflows", + "parking_lot", + "reqwest 0.12.8", + "serde", + "serde_json", + "tokio 1.40.0", + "tokio-util 0.7.12", +] + [[package]] name = "dotenv" version = "0.15.0" diff --git a/Cargo.toml b/Cargo.toml index 14bfc4f..fe0a39a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,15 @@ -[package] -name = "dkn-compute" -version = "0.2.10" +[workspace] +resolver = "2" +members = ["compute", "p2p", "workflows"] +default-members = ["compute"] + +[workspace.package] edition = "2021" license = "Apache-2.0" -readme = "README.md" -authors = ["Erhan Tezcan "] - -# profiling build for flamegraphs -[profile.profiling] -inherits = "release" -debug = true -[features] -# used by flamegraphs & instruments -profiling = [] +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[dependencies] +[workspace.dependencies] tokio-util = { version = "0.7.10", features = ["rt"] } tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } parking_lot = "0.12.2" @@ -24,55 +18,6 @@ serde_json = "1.0" async-trait = "0.1.81" reqwest = "0.12.5" -# utilities -dotenvy = "0.15.7" -base64 = "0.22.0" -hex = "0.4.3" -hex-literal = "0.4.1" -url = "2.5.0" -urlencoding = "2.1.3" -uuid = { version = "1.8.0", features = ["v4"] } -rand = "0.8.5" - -# logging & errors env_logger = "0.11.3" log = "0.4.21" eyre = "0.6.12" - -# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters -ecies = { version = "0.2", default-features = false, features = ["pure"] } -libsecp256k1 = "0.7.1" -sha2 = "0.10.8" -sha3 = "0.10.8" -fastbloom-rs = "0.5.9" - -# workflows -ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } - -# peer-to-peer -libp2p = { git = "https://github.com/anilaltuner/rust-libp2p.git", rev = "7ce9f9e", features = [ - # libp2p = { version = "0.54.1", features = [ - "dcutr", - "ping", - "relay", - "autonat", - "identify", - "tokio", - "gossipsub", - "mdns", - "noise", - "macros", - "tcp", - "yamux", - "quic", - "kad", -] } -libp2p-identity = { version = "0.2.9", features = ["secp256k1"] } -tracing = { version = "0.1.40" } -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -port_check = "0.2.1" - -# Vendor OpenSSL so that its easier to build cross-platform packages -[dependencies.openssl] -version = "*" -features = ["vendored"] diff --git a/compute/Cargo.toml b/compute/Cargo.toml new file mode 100644 index 0000000..eab1f98 --- /dev/null +++ b/compute/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "dkn-compute" +version = "0.2.10" +edition.workspace = true +license.workspace = true +readme = "README.md" +authors = [ + "Erhan Tezcan ", + "Anil Altuner , + #[allow(unused)] object: String, } @@ -35,7 +35,7 @@ pub struct OpenAIConfig { impl OpenAIConfig { /// Looks at the environment variables for OpenAI API key. 
pub fn new() -> Self { - let api_key = std::env::var(OPENAI_API_KEY).ok(); + let api_key = std::env::var("OPENAI_API_KEY").ok(); Self { api_key } } diff --git a/src/handlers/mod.rs b/compute/src/handlers/mod.rs similarity index 95% rename from src/handlers/mod.rs rename to compute/src/handlers/mod.rs index edb018e..00ccf51 100644 --- a/src/handlers/mod.rs +++ b/compute/src/handlers/mod.rs @@ -1,7 +1,7 @@ use crate::{utils::DKNMessage, DriaComputeNode}; use async_trait::async_trait; +use dkn_p2p::libp2p::gossipsub::MessageAcceptance; use eyre::Result; -use libp2p::gossipsub::MessageAcceptance; mod pingpong; pub use pingpong::PingpongHandler; diff --git a/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs similarity index 97% rename from src/handlers/pingpong.rs rename to compute/src/handlers/pingpong.rs index 53ab332..f328865 100644 --- a/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -4,8 +4,8 @@ use crate::{ DriaComputeNode, }; use async_trait::async_trait; +use dkn_p2p::libp2p::gossipsub::MessageAcceptance; use eyre::{Context, Result}; -use libp2p::gossipsub::MessageAcceptance; use ollama_workflows::{Model, ModelProvider}; use serde::{Deserialize, Serialize}; diff --git a/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs similarity index 99% rename from src/handlers/workflow.rs rename to compute/src/handlers/workflow.rs index 41f2789..c15abbc 100644 --- a/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; +use dkn_p2p::libp2p::gossipsub::MessageAcceptance; use eyre::{eyre, Context, Result}; -use libp2p::gossipsub::MessageAcceptance; use libsecp256k1::PublicKey; use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow}; use serde::Deserialize; diff --git a/src/lib.rs b/compute/src/lib.rs similarity index 95% rename from src/lib.rs rename to compute/src/lib.rs index adc2b65..d924a68 100644 --- a/src/lib.rs +++ b/compute/src/lib.rs @@ -4,7 +4,6 @@ pub(crate) mod config; pub(crate) mod handlers; pub(crate) mod node; -pub(crate) mod p2p; pub(crate) mod payloads; pub(crate) mod utils; diff --git a/src/main.rs b/compute/src/main.rs similarity index 100% rename from src/main.rs rename to compute/src/main.rs diff --git a/src/node.rs b/compute/src/node.rs similarity index 96% rename from src/node.rs rename to compute/src/node.rs index e824f1c..177fde2 100644 --- a/src/node.rs +++ b/compute/src/node.rs @@ -1,12 +1,11 @@ +use dkn_p2p::{libp2p::gossipsub, P2PClient}; use eyre::{eyre, Result}; -use libp2p::gossipsub; use std::time::Duration; use tokio_util::sync::CancellationToken; use crate::{ config::*, handlers::*, - p2p::P2PClient, utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage}, }; @@ -52,7 +51,12 @@ impl DriaComputeNode { ) .sort_dedup(); - let p2p = P2PClient::new(keypair, config.p2p_listen_addr.clone(), &available_nodes)?; + let p2p = P2PClient::new( + keypair, + config.p2p_listen_addr.clone(), + &available_nodes.bootstrap_nodes, + &available_nodes.relay_nodes, + )?; Ok(DriaComputeNode { p2p, @@ -97,7 +101,12 @@ impl DriaComputeNode { /// Returns the list of connected peers. 
#[inline(always)] - pub fn peers(&self) -> Vec<(&libp2p_identity::PeerId, Vec<&gossipsub::TopicHash>)> { + pub fn peers( + &self, + ) -> Vec<( + &dkn_p2p::libp2p_identity::PeerId, + Vec<&gossipsub::TopicHash>, + )> { self.p2p.peers() } diff --git a/src/payloads/error.rs b/compute/src/payloads/error.rs similarity index 100% rename from src/payloads/error.rs rename to compute/src/payloads/error.rs diff --git a/src/payloads/mod.rs b/compute/src/payloads/mod.rs similarity index 100% rename from src/payloads/mod.rs rename to compute/src/payloads/mod.rs diff --git a/src/payloads/request.rs b/compute/src/payloads/request.rs similarity index 100% rename from src/payloads/request.rs rename to compute/src/payloads/request.rs diff --git a/src/payloads/response.rs b/compute/src/payloads/response.rs similarity index 100% rename from src/payloads/response.rs rename to compute/src/payloads/response.rs diff --git a/src/utils/available_nodes.rs b/compute/src/utils/available_nodes.rs similarity index 99% rename from src/utils/available_nodes.rs rename to compute/src/utils/available_nodes.rs index 2215be1..dada7de 100644 --- a/src/utils/available_nodes.rs +++ b/compute/src/utils/available_nodes.rs @@ -1,5 +1,5 @@ +use dkn_p2p::libp2p::{Multiaddr, PeerId}; use eyre::Result; -use libp2p::{Multiaddr, PeerId}; use std::{env, fmt::Debug, str::FromStr}; use crate::utils::split_comma_separated; diff --git a/src/utils/crypto.rs b/compute/src/utils/crypto.rs similarity index 96% rename from src/utils/crypto.rs rename to compute/src/utils/crypto.rs index 86bc376..38eb6ac 100644 --- a/src/utils/crypto.rs +++ b/compute/src/utils/crypto.rs @@ -1,6 +1,6 @@ +use dkn_p2p::libp2p_identity::Keypair; use ecies::PublicKey; use eyre::{Context, Result}; -use libp2p_identity::Keypair; use libsecp256k1::{Message, SecretKey}; use sha2::{Digest, Sha256}; use sha3::Keccak256; @@ -58,9 +58,9 @@ pub fn encrypt_bytes(data: impl AsRef<[u8]>, public_key: &PublicKey) -> Result Keypair { let bytes = secret_key.serialize(); - let secret_key = libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes) + let secret_key = dkn_p2p::libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes) .expect("Failed to create secret key"); - libp2p_identity::secp256k1::Keypair::from(secret_key).into() + dkn_p2p::libp2p_identity::secp256k1::Keypair::from(secret_key).into() } #[cfg(test)] diff --git a/src/utils/filter.rs b/compute/src/utils/filter.rs similarity index 100% rename from src/utils/filter.rs rename to compute/src/utils/filter.rs diff --git a/src/utils/message.rs b/compute/src/utils/message.rs similarity index 97% rename from src/utils/message.rs rename to compute/src/utils/message.rs index c7613e2..f05f026 100644 --- a/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -123,10 +123,10 @@ impl fmt::Display for DKNMessage { } } -impl TryFrom for DKNMessage { +impl TryFrom for DKNMessage { type Error = serde_json::Error; - fn try_from(value: libp2p::gossipsub::Message) -> Result { + fn try_from(value: dkn_p2p::libp2p::gossipsub::Message) -> Result { serde_json::from_slice(&value.data) } } diff --git a/src/utils/mod.rs b/compute/src/utils/mod.rs similarity index 97% rename from src/utils/mod.rs rename to compute/src/utils/mod.rs index 23606ae..f24ffc5 100644 --- a/src/utils/mod.rs +++ b/compute/src/utils/mod.rs @@ -7,7 +7,7 @@ pub use message::DKNMessage; mod available_nodes; pub use available_nodes::AvailableNodes; -use libp2p::{multiaddr::Protocol, Multiaddr}; +use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr}; use 
port_check::is_port_reachable; use std::{ net::{Ipv4Addr, SocketAddrV4}, diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml new file mode 100644 index 0000000..272ec60 --- /dev/null +++ b/p2p/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "dkn-p2p" +version = "0.1.0" +edition.workspace = true +license.workspace = true +readme = "README.md" +authors = ["Erhan Tezcan "] + +[dependencies] +libp2p = { git = "https://github.com/anilaltuner/rust-libp2p.git", rev = "7ce9f9e", features = [ + # libp2p = { version = "0.54.1", features = [ + "dcutr", + "ping", + "relay", + "autonat", + "identify", + "tokio", + "gossipsub", + "mdns", + "noise", + "macros", + "tcp", + "yamux", + "quic", + "kad", +] } +libp2p-identity = { version = "0.2.9", features = ["secp256k1"] } + +env_logger.workspace = true +log.workspace = true +eyre.workspace = true diff --git a/src/p2p/behaviour.rs b/p2p/src/behaviour.rs similarity index 98% rename from src/p2p/behaviour.rs rename to p2p/src/behaviour.rs index b35127e..b28f742 100644 --- a/src/p2p/behaviour.rs +++ b/p2p/src/behaviour.rs @@ -6,7 +6,7 @@ use libp2p::identity::{Keypair, PublicKey}; use libp2p::kad::store::MemoryStore; use libp2p::{autonat, dcutr, gossipsub, identify, kad, relay, swarm::NetworkBehaviour, PeerId}; -use crate::p2p::{P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; +use super::{P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; #[derive(NetworkBehaviour)] pub struct DriaBehaviour { diff --git a/src/p2p/client.rs b/p2p/src/client.rs similarity index 96% rename from src/p2p/client.rs rename to p2p/src/client.rs index e389206..f2e7861 100644 --- a/src/p2p/client.rs +++ b/p2p/src/client.rs @@ -13,7 +13,6 @@ use std::time::Duration; use std::time::Instant; use super::*; -use crate::utils::AvailableNodes; /// P2P client, exposes a simple interface to handle P2P communication. 
pub struct P2PClient { @@ -39,7 +38,8 @@ impl P2PClient { pub fn new( keypair: Keypair, listen_addr: Multiaddr, - available_nodes: &AvailableNodes, + bootstraps: &[Multiaddr], + relays: &[Multiaddr], ) -> Result { // this is our peerId let node_peerid = keypair.public().to_peer_id(); @@ -67,11 +67,8 @@ impl P2PClient { .set_mode(Some(libp2p::kad::Mode::Server)); // initiate bootstrap - log::info!( - "Initiating bootstrap: {:#?}", - available_nodes.bootstrap_nodes - ); - for addr in &available_nodes.bootstrap_nodes { + log::info!("Initiating bootstrap: {:#?}", bootstraps); + for addr in bootstraps { if let Some(peer_id) = addr.iter().find_map(|p| match p { Protocol::P2p(peer_id) => Some(peer_id), _ => None, @@ -101,11 +98,8 @@ impl P2PClient { log::info!("Listening p2p network on: {}", listen_addr); swarm.listen_on(listen_addr)?; - log::info!( - "Listening to relay nodes: {:#?}", - available_nodes.relay_nodes - ); - for addr in &available_nodes.relay_nodes { + log::info!("Listening to relay nodes: {:#?}", relays); + for addr in relays { swarm.listen_on(addr.clone().with(Protocol::P2pCircuit))?; } diff --git a/src/p2p/data_transform.rs b/p2p/src/data_transform.rs similarity index 100% rename from src/p2p/data_transform.rs rename to p2p/src/data_transform.rs diff --git a/src/p2p/mod.rs b/p2p/src/lib.rs similarity index 75% rename from src/p2p/mod.rs rename to p2p/src/lib.rs index 1e4c09c..a28267f 100644 --- a/src/p2p/mod.rs +++ b/p2p/src/lib.rs @@ -8,3 +8,8 @@ mod versioning; pub use versioning::*; mod data_transform; + +// re-exports + +pub use libp2p; +pub use libp2p_identity; diff --git a/src/p2p/versioning.rs b/p2p/src/versioning.rs similarity index 100% rename from src/p2p/versioning.rs rename to p2p/src/versioning.rs diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml new file mode 100644 index 0000000..bf0e686 --- /dev/null +++ b/workflows/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "dkn-workflows" +version = "0.1.0" +edition.workspace = true +license.workspace = true +readme = "README.md" +authors = ["Erhan Tezcan "] + + +[dependencies] +tokio-util.workspace = true +tokio.workspace = true +parking_lot.workspace = true +serde.workspace = true +serde_json.workspace = true +async-trait.workspace = true +reqwest.workspace = true + +log.workspace = true +eyre.workspace = true + +ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } diff --git a/workflows/README.md b/workflows/README.md new file mode 100644 index 0000000..1b9aa76 --- /dev/null +++ b/workflows/README.md @@ -0,0 +1,7 @@ +# DKN Workflows + +We make use of Ollama Workflows in DKN; however, we also want to make sure that the chosen models are valid and is performant enough (i.e. have enough TPS). +This crate handles the configurations of models to be used, and implements various service checks. + +- **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models. +- **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. While calculating TPS, there is also a timeout so that beyond that timeout the TPS is not even considered and the model becomes invalid. 
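To make the TPS rule described in the README above concrete, here is a minimal sketch (not part of the patch) of the acceptance check it describes. The helper name `is_model_acceptable` and the example numbers are illustrative; the formula mirrors the `eval_count` and `eval_duration` (nanoseconds) fields used by the `test_performance` check later in this patch, while the timeout is handled separately (e.g. with `tokio::select!`) in the actual code.

```rust
/// Illustrative only: keep a model if its measured throughput clears a TPS threshold.
/// `eval_count` is the number of generated tokens and `eval_duration_ns` the generation
/// time in nanoseconds, as reported by an Ollama generation response.
fn is_model_acceptable(eval_count: u64, eval_duration_ns: u64, min_tps: f64) -> bool {
    // guard against a zero duration to avoid division by zero
    let duration_ns = eval_duration_ns.max(1) as f64;
    let tps = (eval_count as f64) / duration_ns * 1_000_000_000f64;
    tps >= min_tps
}

fn main() {
    // 300 tokens in 12 s is ~25 TPS, above a 15 TPS threshold
    assert!(is_model_acceptable(300, 12_000_000_000, 15.0));
    // 300 tokens in 30 s is ~10 TPS, below the same threshold
    assert!(!is_model_acceptable(300, 30_000_000_000, 15.0));
}
```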
diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs new file mode 100644 index 0000000..4524aac --- /dev/null +++ b/workflows/src/lib.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; +use eyre::Result; + +mod models; +pub use models::ModelConfig; + +/// Ollama configurations & service checks +mod ollama; +pub(crate) use ollama::OllamaConfig; + +/// OpenAI configurations & service checks +mod openai; +pub(crate) use openai::OpenAIConfig; + +/// Extension trait for model providers to check if they are ready, and describe themselves. +#[async_trait] +pub trait ProvidersExt { + const PROVIDER_NAME: &str; + + /// Ensures that the required provider is online & ready. + async fn check_service(&self) -> Result<()>; +} diff --git a/workflows/src/models.rs b/workflows/src/models.rs new file mode 100644 index 0000000..4b0a7d8 --- /dev/null +++ b/workflows/src/models.rs @@ -0,0 +1,251 @@ +use crate::{utils::split_comma_separated, OllamaConfig, OpenAIConfig}; +use eyre::{eyre, Result}; +use ollama_workflows::{Model, ModelProvider}; +use rand::seq::IteratorRandom; // provides Vec<_>.choose + +#[derive(Debug, Clone)] +pub struct ModelConfig { + pub models: Vec<(ModelProvider, Model)>, + pub ollama: OllamaConfig, + pub openai: OpenAIConfig, +} + +impl std::fmt::Display for ModelConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let models_str = self + .models + .iter() + .map(|(provider, model)| format!("{:?}:{}", provider, model)) + .collect::>() + .join(","); + write!(f, "{}", models_str) + } +} + +impl ModelConfig { + /// Creates a new config with the given list of models. + pub fn new(models: Vec) -> Self { + // map models to (provider, model) pairs + let models_providers = models + .into_iter() + .map(|m| (m.clone().into(), m)) + .collect::>(); + + let mut providers = Vec::new(); + + // get ollama models & config + let ollama_models = models_providers + .iter() + .filter_map(|(p, m)| { + if *p == ModelProvider::Ollama { + Some(m.clone()) + } else { + None + } + }) + .collect::>(); + let ollama_config = if !ollama_models.is_empty() { + providers.push(ModelProvider::Ollama); + Some(OllamaConfig::new(ollama_models)) + } else { + None + }; + + // get openai models & config + let openai_models = models_providers + .iter() + .filter_map(|(p, m)| { + if *p == ModelProvider::OpenAI { + Some(m.clone()) + } else { + None + } + }) + .collect::>(); + let openai_config = if !openai_models.is_empty() { + providers.push(ModelProvider::OpenAI); + Some(OpenAIConfig::new(openai_models)) + } else { + None + }; + + Self { + models_providers, + providers, + ollama_config, + openai_config, + } + } + + /// Parses Ollama-Workflows compatible models from a comma-separated values string. + /// + /// ## Example + /// + /// ``` + /// let config = ModelConfig::new_from_csv("gpt-4-turbo,gpt-4o-mini"); + /// ``` + pub fn new_from_csv(input: Option) -> Self { + let models_str = split_comma_separated(input); + + let models = models_str + .into_iter() + .filter_map(|s| match Model::try_from(s) { + Ok(model) => Some((model.clone().into(), model)), + Err(e) => { + log::warn!("Error parsing model: {}", e); + None + } + }) + .collect::>(); + + Self { models } + } + + /// Returns the models that belong to a given providers from the config. 
+ pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec { + self.models + .iter() + .filter_map(|(p, m)| { + if *p == provider { + Some(m.clone()) + } else { + None + } + }) + .collect() + } + + /// Given a raw model name or provider (as a string), returns the first matching model & provider. + /// + /// If this is a model and is supported by this node, it is returned directly. + /// If this is a provider, the first matching model in the node config is returned. + /// + /// If there are no matching models with this logic, an error is returned. + pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> { + if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) { + // this is a valid provider, return the first matching model in the config + self.models + .iter() + .find(|(p, _)| *p == provider) + .ok_or(eyre!( + "Provider {} is not supported by this node.", + provider + )) + .cloned() + } else if let Ok(model) = Model::try_from(model_or_provider.clone()) { + // this is a valid model, return it if it is supported by the node + self.models + .iter() + .find(|(_, m)| *m == model) + .ok_or(eyre!("Model {} is not supported by this node.", model)) + .cloned() + } else { + // this is neither a valid provider or model for this node + Err(eyre!( + "Given string '{}' is neither a model nor provider.", + model_or_provider + )) + } + } + + /// From a list of model or provider names, return a random matching model & provider. + pub fn get_any_matching_model( + &self, + list_model_or_provider: Vec, + ) -> Result<(ModelProvider, Model)> { + // filter models w.r.t supported ones + let matching_models = list_model_or_provider + .into_iter() + .filter_map(|model_or_provider| { + let result = self.get_matching_model(model_or_provider); + match result { + Ok(result) => Some(result), + Err(e) => { + log::debug!("Ignoring model: {}", e); + None + } + } + }) + .collect::>(); + + // choose random model + matching_models + .into_iter() + .choose(&mut rand::thread_rng()) + .ok_or(eyre!("No matching models found.")) + } + + /// Returns the list of unique providers in the config. 
+ pub fn get_providers(&self) -> Vec { + self.models + .iter() + .fold(Vec::new(), |mut unique, (provider, _)| { + if !unique.contains(provider) { + unique.push(provider.clone()); + } + unique + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_csv_parser() { + let cfg = + ModelConfig::new_from_csv(Some("idontexist,i dont either,i332287648762".to_string())); + assert_eq!(cfg.models.len(), 0); + + let cfg = ModelConfig::new_from_csv(Some( + "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(), + )); + assert_eq!(cfg.models.len(), 2); + } + + #[test] + fn test_model_matching() { + let cfg = ModelConfig::new_from_csv(Some("gpt-4o,llama3.1:latest".to_string())); + assert_eq!( + cfg.get_matching_model("openai".to_string()).unwrap().1, + Model::GPT4o, + "Should find existing model" + ); + + assert_eq!( + cfg.get_matching_model("llama3.1:latest".to_string()) + .unwrap() + .1, + Model::Llama3_1_8B, + "Should find existing model" + ); + + assert!( + cfg.get_matching_model("gpt-4o-mini".to_string()).is_err(), + "Should not find anything for unsupported model" + ); + + assert!( + cfg.get_matching_model("praise the model".to_string()) + .is_err(), + "Should not find anything for inexisting model" + ); + } + + #[test] + fn test_get_any_matching_model() { + let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string())); + let result = cfg.get_any_matching_model(vec![ + "i-dont-exist".to_string(), + "llama3.1:latest".to_string(), + "gpt-4o".to_string(), + "ollama".to_string(), + ]); + assert_eq!( + result.unwrap().1, + Model::Llama3_1_8B, + "Should find existing model" + ); + } +} diff --git a/workflows/src/ollama.rs b/workflows/src/ollama.rs new file mode 100644 index 0000000..22b85be --- /dev/null +++ b/workflows/src/ollama.rs @@ -0,0 +1,309 @@ +use eyre::{eyre, Context, Result}; +use ollama_workflows::{ + ollama_rs::{ + generation::{ + completion::request::GenerationRequest, + embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, + options::GenerationOptions, + }, + Ollama, + }, + Model, +}; +use std::time::Duration; + +const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1"; +const DEFAULT_OLLAMA_PORT: u16 = 11434; + +/// Some models such as small embedding models, are hardcoded into the node. +const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"]; + +/// Prompt to be used to see Ollama performance. +const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; + +/// Ollama-specific configurations. +#[derive(Debug, Clone)] +pub struct OllamaConfig { + /// Host, usually `http://127.0.0.1`. + pub(crate) host: String, + /// Port, usually `11434`. + pub(crate) port: u16, + /// List of hardcoded models that are internally used by Ollama workflows. + hardcoded_models: Vec, + /// List of external models that are picked by the user. + pub(crate) models: Vec, + /// Whether to automatically pull models from Ollama. + /// This is useful for CI/CD workflows. + auto_pull: bool, +} + +impl Default for OllamaConfig { + fn default() -> Self { + Self { + host: DEFAULT_OLLAMA_HOST.to_string(), + port: DEFAULT_OLLAMA_PORT, + hardcoded_models: HARDCODED_MODELS + .into_iter() + .map(|s| s.to_string()) + .collect(), + auto_pull: false, + } + } +} +impl OllamaConfig { + /// Looks at the environment variables for Ollama host and port. + /// + /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`. 
+ pub fn new() -> Self { + let host = std::env::var("OLLAMA_HOST") + .map(|h| h.trim_matches('"').to_string()) + .unwrap_or(DEFAULT_OLLAMA_HOST.to_string()); + let port = std::env::var("OLLAMA_PORT") + .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent)) + .unwrap_or(DEFAULT_OLLAMA_PORT); + + // Ollama workflows may require specific models to be loaded regardless of the choices + let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect(); + + // auto-pull, its true by default + let auto_pull = std::env::var("OLLAMA_AUTO_PULL") + .map(|s| s == "true") + .unwrap_or(true); + + Self { + host, + port, + hardcoded_models, + auto_pull, + } + } + + /// Check if requested models exist in Ollama, and then tests them using a workflow. + pub async fn check( + &self, + external_models: Vec, + timeout: Duration, + min_tps: f64, + ) -> Result> { + log::info!( + "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)", + if self.auto_pull { "on" } else { "off" }, + timeout.as_secs() + ); + + let ollama = Ollama::new(&self.host, self.port); + + // fetch local models + let local_models = match ollama.list_local_models().await { + Ok(models) => models.into_iter().map(|m| m.name).collect::>(), + Err(e) => { + return { + log::error!("Could not fetch local models from Ollama, is it online?"); + Err(e.into()) + } + } + }; + log::info!("Found local Ollama models: {:#?}", local_models); + + // check hardcoded models & pull them if available + // these are not used directly by the user, but are needed for the workflows + log::debug!("Checking hardcoded models: {:#?}", self.hardcoded_models); + // only check if model is contained in local_models + // we dont check workflows for hardcoded models + for model in &self.hardcoded_models { + if !local_models.contains(model) { + self.try_pull(&ollama, model.to_owned()) + .await + .wrap_err("Could not pull model")?; + } + } + + // check external models & pull them if available + // and also run a test workflow for them + let mut good_models = Vec::new(); + for model in external_models { + if !local_models.contains(&model.to_string()) { + self.try_pull(&ollama, model.to_string()) + .await + .wrap_err("Could not pull model")?; + } + + if self + .test_performance(&ollama, &model, timeout, min_tps) + .await + { + good_models.push(model); + } + } + + log::info!( + "Ollama checks are finished, using models: {:#?}", + good_models + ); + Ok(good_models) + } + + /// Pulls a model if `auto_pull` exists, otherwise returns an error. + async fn try_pull(&self, ollama: &Ollama, model: String) -> Result<()> { + log::warn!("Model {} not found in Ollama", model); + if self.auto_pull { + // if auto-pull is enabled, pull the model + log::info!( + "Downloading missing model {} (this may take a while)", + model + ); + let status = ollama.pull_model(model, false).await?; + log::debug!("Pulled model with Ollama, final status: {:#?}", status); + Ok(()) + } else { + // otherwise, give error + log::error!("Please download missing model with: ollama pull {}", model); + log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically."); + Err(eyre!("Required model not pulled in Ollama.")) + } + } + + /// Runs a small workflow to test Ollama Workflows. + /// + /// This is to see if a given system can execute Ollama workflows for their chosen models, + /// e.g. if they have enough RAM/CPU and such. 
+ pub async fn test_performance( + &self, + ollama: &Ollama, + model: &Model, + timeout: Duration, + min_tps: f64, + ) -> bool { + log::info!("Testing model {}", model); + + // first generate a dummy embedding to load the model into memory (warm-up) + let request = GenerateEmbeddingsRequest::new( + model.to_string(), + EmbeddingsInput::Single("embedme".into()), + ); + if let Err(err) = ollama.generate_embeddings(request).await { + log::error!("Failed to generate embedding for model {}: {}", model, err); + return false; + }; + + let mut generation_request = + GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string()); + + // FIXME: temporary workaround, can take num threads from outside + if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") { + generation_request = generation_request.options( + GenerationOptions::default().num_thread( + num_thread + .parse() + .expect("num threads should be a positive integer"), + ), + ); + } + + // then, run a sample generation with timeout and measure tps + tokio::select! { + _ = tokio::time::sleep(timeout) => { + log::warn!("Ignoring model {}: Workflow timed out", model); + }, + result = ollama.generate(generation_request) => { + match result { + Ok(response) => { + let tps = (response.eval_count.unwrap_or_default() as f64) + / (response.eval_duration.unwrap_or(1) as f64) + * 1_000_000_000f64; + + if tps >= min_tps { + log::info!("Model {} passed the test with tps: {}", model, tps); + return true; + } + + log::warn!( + "Ignoring model {}: tps too low ({:.3} < {:.3})", + model, + tps, + min_tps + ); + } + Err(e) => { + log::warn!("Ignoring model {}: Workflow failed with error {}", model, e); + } + } + } + }; + + false + } +} + +#[cfg(test)] +mod tests { + use ollama_workflows::ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; + use ollama_workflows::{Executor, Model, ProgramMemory, Workflow}; + + #[tokio::test] + #[ignore = "run this manually"] + async fn test_ollama_prompt() { + let model = Model::default().to_string(); + let ollama = Ollama::default(); + ollama.pull_model(model.clone(), false).await.unwrap(); + let prompt = "The sky appears blue during the day because of a process called scattering. \ + When sunlight enters the Earth's atmosphere, it collides with air molecules such as oxygen and nitrogen. \ + These collisions cause some of the light to be absorbed or reflected, which makes the colors we see appear more vivid and vibrant. \ + Blue is one of the brightest colors that is scattered the most by the atmosphere, making it visible to our eyes during the day. 
\ + What may be the question this answer?".to_string(); + + let response = ollama + .generate(GenerationRequest::new(model, prompt.clone())) + .await + .expect("Should generate response"); + println!("Prompt: {}\n\nResponse:{}", prompt, response.response); + } + + #[tokio::test] + #[ignore = "run this manually"] + async fn test_ollama_workflow() { + let workflow = r#"{ + "name": "Simple", + "description": "This is a simple workflow", + "config": { + "max_steps": 5, + "max_time": 100, + }, + "tasks":[ + { + "id": "A", + "name": "Random Poem", + "description": "Writes a poem about Kapadokya.", + "prompt": "Please write a poem about Kapadokya.", + "operator": "generation", + "outputs": [ + { + "type": "write", + "key": "final_result", + "value": "__result" + } + ] + }, + { + "id": "__end", + "name": "end", + "description": "End of the task", + "prompt": "End of the task", + "operator": "end", + } + ], + "steps":[ + { + "source":"A", + "target":"end" + } + ] + }"#; + let workflow: Workflow = serde_json::from_str(workflow).unwrap(); + let exe = Executor::new(Model::default()); + let mut memory = ProgramMemory::new(); + + let result = exe.execute(None, workflow, &mut memory).await; + println!("Result: {}", result.unwrap()); + } +} diff --git a/workflows/src/openai.rs b/workflows/src/openai.rs new file mode 100644 index 0000000..75e7e8b --- /dev/null +++ b/workflows/src/openai.rs @@ -0,0 +1,111 @@ +use eyre::{eyre, Context, Result}; +use ollama_workflows::Model; +use serde::Deserialize; + +const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models"; + +/// [Model](https://platform.openai.com/docs/api-reference/models/object) API object. +#[derive(Debug, Clone, Deserialize)] +struct OpenAIModel { + /// The model identifier, which can be referenced in the API endpoints. + id: String, + /// The Unix timestamp (in seconds) when the model was created. + #[allow(unused)] + created: u64, + /// The object type, which is always "model". + #[allow(unused)] + object: String, + /// The organization that owns the model. + #[allow(unused)] + owned_by: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, + #[allow(unused)] + object: String, +} + +#[derive(Debug, Clone, Default)] +pub struct OpenAIConfig { + /// List of external models that are picked by the user. + pub(crate) models: Vec, +} + +impl OpenAIConfig { + /// Looks at the environment variables for OpenAI API key. + pub fn new() -> Self { + let api_key = std::env::var("OPENAI_API_KEY").ok(); + + Self { api_key } + } + + /// Check if requested models exist & are available in the OpenAI account. 
+ pub async fn check(&self, models: Vec) -> Result> { + log::info!("Checking OpenAI requirements"); + + // check API key + let Some(api_key) = &self.api_key else { + return Err(eyre!("OpenAI API key not found")); + }; + + // fetch models + let client = reqwest::Client::new(); + let request = client + .get(OPENAI_MODELS_API) + .header("Authorization", format!("Bearer {}", api_key)) + .build() + .wrap_err("Failed to build request")?; + + let response = client + .execute(request) + .await + .wrap_err("Failed to send request")?; + + // parse response + if response.status().is_client_error() { + return Err(eyre!( + "Failed to fetch OpenAI models:\n{}", + response.text().await.unwrap_or_default() + )); + } + let openai_models = response.json::().await?; + + // check if models exist and select those that are available + let mut available_models = Vec::new(); + for requested_model in models { + if !openai_models + .data + .iter() + .any(|m| m.id == requested_model.to_string()) + { + log::warn!( + "Model {} not found in your OpenAI account, ignoring it.", + requested_model + ); + } else { + available_models.push(requested_model); + } + } + + log::info!( + "OpenAI checks are finished, using models: {:#?}", + available_models + ); + Ok(available_models) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "requires OpenAI API key"] + async fn test_openai_check() { + let config = OpenAIConfig::new(); + let res = config.check(vec![]).await; + println!("Result: {}", res.unwrap_err()); + } +} From e69292e4f799c6b1a803da01c24bc5e5693f82c0 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 12:54:33 +0300 Subject: [PATCH 02/14] better model parsing added as subcrate --- .env.example | 2 +- Cargo.lock | 1 + Cargo.toml | 2 +- workflows/Cargo.toml | 3 +- workflows/src/{models.rs => config.rs} | 169 +++++++++++++----------- workflows/src/lib.rs | 26 +--- workflows/src/providers/mod.rs | 5 + workflows/src/{ => providers}/ollama.rs | 76 +++++------ workflows/src/{ => providers}/openai.rs | 9 +- workflows/src/utils.rs | 34 +++++ workflows/tests/models_test.rs | 14 ++ 11 files changed, 191 insertions(+), 150 deletions(-) rename workflows/src/{models.rs => config.rs} (64%) create mode 100644 workflows/src/providers/mod.rs rename workflows/src/{ => providers}/ollama.rs (84%) rename workflows/src/{ => providers}/openai.rs (94%) create mode 100644 workflows/src/utils.rs create mode 100644 workflows/tests/models_test.rs diff --git a/.env.example b/.env.example index 1f34c69..2b1d1ae 100644 --- a/.env.example +++ b/.env.example @@ -10,7 +10,7 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b DKN_MODELS= ## DRIA (optional) ## -# P2P address, you don't need to change this unless you really want this port. +# P2P address, you don't need to change this unless this port is already in use. 
DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001 # Comma-separated static relay nodes DKN_RELAY_NODES= diff --git a/Cargo.lock b/Cargo.lock index 9586fc5..f4f061a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1033,6 +1033,7 @@ dependencies = [ "log", "ollama-workflows", "parking_lot", + "rand 0.8.5", "reqwest 0.12.8", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index fe0a39a..a453253 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" async-trait = "0.1.81" reqwest = "0.12.5" - +rand = "0.8.5" env_logger = "0.11.3" log = "0.4.21" eyre = "0.6.12" diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index bf0e686..e374969 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -15,8 +15,9 @@ serde.workspace = true serde_json.workspace = true async-trait.workspace = true reqwest.workspace = true - +rand.workspace = true log.workspace = true eyre.workspace = true +# ollama-rs is re-exported from ollama-workflows ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } diff --git a/workflows/src/models.rs b/workflows/src/config.rs similarity index 64% rename from workflows/src/models.rs rename to workflows/src/config.rs index 4b0a7d8..2afb480 100644 --- a/workflows/src/models.rs +++ b/workflows/src/config.rs @@ -1,89 +1,21 @@ -use crate::{utils::split_comma_separated, OllamaConfig, OpenAIConfig}; +use crate::{split_comma_separated, OllamaConfig, OpenAIConfig}; use eyre::{eyre, Result}; use ollama_workflows::{Model, ModelProvider}; use rand::seq::IteratorRandom; // provides Vec<_>.choose #[derive(Debug, Clone)] pub struct ModelConfig { + /// List of models with their providers. pub models: Vec<(ModelProvider, Model)>, + /// Even if Ollama is not used, we store the host & port here. + /// If Ollama is used, this config will be respected during its instantiations. pub ollama: OllamaConfig, + /// OpenAI API key & its service check implementation. pub openai: OpenAIConfig, } -impl std::fmt::Display for ModelConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let models_str = self - .models - .iter() - .map(|(provider, model)| format!("{:?}:{}", provider, model)) - .collect::>() - .join(","); - write!(f, "{}", models_str) - } -} - impl ModelConfig { - /// Creates a new config with the given list of models. - pub fn new(models: Vec) -> Self { - // map models to (provider, model) pairs - let models_providers = models - .into_iter() - .map(|m| (m.clone().into(), m)) - .collect::>(); - - let mut providers = Vec::new(); - - // get ollama models & config - let ollama_models = models_providers - .iter() - .filter_map(|(p, m)| { - if *p == ModelProvider::Ollama { - Some(m.clone()) - } else { - None - } - }) - .collect::>(); - let ollama_config = if !ollama_models.is_empty() { - providers.push(ModelProvider::Ollama); - Some(OllamaConfig::new(ollama_models)) - } else { - None - }; - - // get openai models & config - let openai_models = models_providers - .iter() - .filter_map(|(p, m)| { - if *p == ModelProvider::OpenAI { - Some(m.clone()) - } else { - None - } - }) - .collect::>(); - let openai_config = if !openai_models.is_empty() { - providers.push(ModelProvider::OpenAI); - Some(OpenAIConfig::new(openai_models)) - } else { - None - }; - - Self { - models_providers, - providers, - ollama_config, - openai_config, - } - } - /// Parses Ollama-Workflows compatible models from a comma-separated values string. 
- /// - /// ## Example - /// - /// ``` - /// let config = ModelConfig::new_from_csv("gpt-4-turbo,gpt-4o-mini"); - /// ``` pub fn new_from_csv(input: Option) -> Self { let models_str = split_comma_separated(input); @@ -98,7 +30,11 @@ impl ModelConfig { }) .collect::>(); - Self { models } + Self { + models, + openai: OpenAIConfig::new(), + ollama: OllamaConfig::new(), + } } /// Returns the models that belong to a given providers from the config. @@ -117,12 +53,27 @@ impl ModelConfig { /// Given a raw model name or provider (as a string), returns the first matching model & provider. /// - /// If this is a model and is supported by this node, it is returned directly. - /// If this is a provider, the first matching model in the node config is returned. + /// - If input is `*` or `all`, a random model is returned. + /// - if input is `!` the first model is returned. + /// - If input is a model and is supported by this node, it is returned directly. + /// - If input is a provider, the first matching model in the node config is returned. /// /// If there are no matching models with this logic, an error is returned. pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> { - if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) { + if model_or_provider == "*" { + // return a random model + self.models + .iter() + .choose(&mut rand::thread_rng()) + .ok_or_else(|| eyre!("No models to randomly pick for '*'.")) + .cloned() + } else if model_or_provider == "!" { + // return the first model + self.models + .first() + .ok_or_else(|| eyre!("No models to choose first for '!'.")) + .cloned() + } else if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) { // this is a valid provider, return the first matching model in the config self.models .iter() @@ -186,6 +137,70 @@ impl ModelConfig { unique }) } + + /// Check if the required compute services are running. + /// This has several steps: + /// + /// - If Ollama models are used, hardcoded models are checked locally, and for + /// external models, the workflow is tested with a simple task with timeout. + /// - If OpenAI models are used, the API key is checked and the models are tested + /// + /// If both type of models are used, both services are checked. + /// In the end, bad models are filtered out and we simply check if we are left if any valid models at all. + /// If not, an error is returned. 
+ pub async fn check_services(&mut self) -> Result<()> { + log::info!("Checking configured services."); + + // TODO: can refactor (provider, model) logic here + let unique_providers = self.get_providers(); + + let mut good_models = Vec::new(); + + // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them) + if unique_providers.contains(&ModelProvider::Ollama) { + let ollama_models = self.get_models_for_provider(ModelProvider::Ollama); + + // ensure that the models are pulled / pull them if not + let good_ollama_models = self.ollama.check(ollama_models).await?; + good_models.extend( + good_ollama_models + .into_iter() + .map(|m| (ModelProvider::Ollama, m)), + ); + } + + // if OpenAI is a provider, check that the API key is set + if unique_providers.contains(&ModelProvider::OpenAI) { + let openai_models = self.get_models_for_provider(ModelProvider::OpenAI); + + let good_openai_models = self.openai.check(openai_models).await?; + good_models.extend( + good_openai_models + .into_iter() + .map(|m| (ModelProvider::OpenAI, m)), + ); + } + + // update good models + if good_models.is_empty() { + Err(eyre!("No good models found, please check logs for errors.")) + } else { + self.models = good_models; + Ok(()) + } + } +} + +impl std::fmt::Display for ModelConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let models_str = self + .models + .iter() + .map(|(provider, model)| format!("{:?}:{}", provider, model)) + .collect::>() + .join(","); + write!(f, "{}", models_str) + } } #[cfg(test)] diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs index 4524aac..96f9856 100644 --- a/workflows/src/lib.rs +++ b/workflows/src/lib.rs @@ -1,22 +1,8 @@ -use async_trait::async_trait; -use eyre::Result; +mod utils; +pub use utils::*; -mod models; -pub use models::ModelConfig; +mod providers; +pub use providers::*; -/// Ollama configurations & service checks -mod ollama; -pub(crate) use ollama::OllamaConfig; - -/// OpenAI configurations & service checks -mod openai; -pub(crate) use openai::OpenAIConfig; - -/// Extension trait for model providers to check if they are ready, and describe themselves. -#[async_trait] -pub trait ProvidersExt { - const PROVIDER_NAME: &str; - - /// Ensures that the required provider is online & ready. - async fn check_service(&self) -> Result<()>; -} +mod config; +pub use config::ModelConfig; diff --git a/workflows/src/providers/mod.rs b/workflows/src/providers/mod.rs new file mode 100644 index 0000000..a6ea768 --- /dev/null +++ b/workflows/src/providers/mod.rs @@ -0,0 +1,5 @@ +mod ollama; +pub use ollama::OllamaConfig; + +mod openai; +pub use openai::OpenAIConfig; diff --git a/workflows/src/ollama.rs b/workflows/src/providers/ollama.rs similarity index 84% rename from workflows/src/ollama.rs rename to workflows/src/providers/ollama.rs index 22b85be..deaef6d 100644 --- a/workflows/src/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -10,14 +10,20 @@ use ollama_workflows::{ }, Model, }; +use std::env; use std::time::Duration; const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1"; const DEFAULT_OLLAMA_PORT: u16 = 11434; +/// Automatically pull missing models by default? +const DEFAULT_AUTO_PULL: bool = true; +/// Timeout duration for checking model performance during a generation. +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(80); +/// Minimum tokens per second (TPS) for checking model performance during a generation. 
+const DEFAULT_MIN_TPS: f64 = 15.0; /// Some models such as small embedding models, are hardcoded into the node. const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"]; - /// Prompt to be used to see Ollama performance. const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; @@ -25,16 +31,16 @@ const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; #[derive(Debug, Clone)] pub struct OllamaConfig { /// Host, usually `http://127.0.0.1`. - pub(crate) host: String, + host: String, /// Port, usually `11434`. - pub(crate) port: u16, - /// List of hardcoded models that are internally used by Ollama workflows. - hardcoded_models: Vec, - /// List of external models that are picked by the user. - pub(crate) models: Vec, + port: u16, /// Whether to automatically pull models from Ollama. /// This is useful for CI/CD workflows. auto_pull: bool, + /// Timeout duration for checking model performance during a generation. + timeout: Duration, + /// Minimum tokens per second (TPS) for checking model performance during a generation. + min_tps: f64, } impl Default for OllamaConfig { @@ -42,11 +48,9 @@ impl Default for OllamaConfig { Self { host: DEFAULT_OLLAMA_HOST.to_string(), port: DEFAULT_OLLAMA_PORT, - hardcoded_models: HARDCODED_MODELS - .into_iter() - .map(|s| s.to_string()) - .collect(), - auto_pull: false, + auto_pull: DEFAULT_AUTO_PULL, + timeout: DEFAULT_TIMEOUT, + min_tps: DEFAULT_MIN_TPS, } } } @@ -55,40 +59,32 @@ impl OllamaConfig { /// /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`. pub fn new() -> Self { - let host = std::env::var("OLLAMA_HOST") + let host = env::var("OLLAMA_HOST") .map(|h| h.trim_matches('"').to_string()) .unwrap_or(DEFAULT_OLLAMA_HOST.to_string()); - let port = std::env::var("OLLAMA_PORT") + let port = env::var("OLLAMA_PORT") .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent)) .unwrap_or(DEFAULT_OLLAMA_PORT); - // Ollama workflows may require specific models to be loaded regardless of the choices - let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect(); - // auto-pull, its true by default - let auto_pull = std::env::var("OLLAMA_AUTO_PULL") + let auto_pull = env::var("OLLAMA_AUTO_PULL") .map(|s| s == "true") .unwrap_or(true); Self { host, port, - hardcoded_models, auto_pull, + ..Default::default() } } /// Check if requested models exist in Ollama, and then tests them using a workflow. 
- pub async fn check( - &self, - external_models: Vec, - timeout: Duration, - min_tps: f64, - ) -> Result> { + pub async fn check(&self, external_models: Vec) -> Result> { log::info!( "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)", if self.auto_pull { "on" } else { "off" }, - timeout.as_secs() + self.timeout.as_secs() ); let ollama = Ollama::new(&self.host, self.port); @@ -107,11 +103,10 @@ impl OllamaConfig { // check hardcoded models & pull them if available // these are not used directly by the user, but are needed for the workflows - log::debug!("Checking hardcoded models: {:#?}", self.hardcoded_models); - // only check if model is contained in local_models - // we dont check workflows for hardcoded models - for model in &self.hardcoded_models { - if !local_models.contains(model) { + // we only check if model is contained in local_models, we dont check workflows for these + for model in HARDCODED_MODELS { + // `contains` doesnt work for &str so we equality check instead + if !&local_models.iter().any(|s| s == model) { self.try_pull(&ollama, model.to_owned()) .await .wrap_err("Could not pull model")?; @@ -128,10 +123,7 @@ impl OllamaConfig { .wrap_err("Could not pull model")?; } - if self - .test_performance(&ollama, &model, timeout, min_tps) - .await - { + if self.test_performance(&ollama, &model).await { good_models.push(model); } } @@ -167,13 +159,7 @@ impl OllamaConfig { /// /// This is to see if a given system can execute Ollama workflows for their chosen models, /// e.g. if they have enough RAM/CPU and such. - pub async fn test_performance( - &self, - ollama: &Ollama, - model: &Model, - timeout: Duration, - min_tps: f64, - ) -> bool { + pub async fn test_performance(&self, ollama: &Ollama, model: &Model) -> bool { log::info!("Testing model {}", model); // first generate a dummy embedding to load the model into memory (warm-up) @@ -202,7 +188,7 @@ impl OllamaConfig { // then, run a sample generation with timeout and measure tps tokio::select! { - _ = tokio::time::sleep(timeout) => { + _ = tokio::time::sleep(self.timeout) => { log::warn!("Ignoring model {}: Workflow timed out", model); }, result = ollama.generate(generation_request) => { @@ -212,7 +198,7 @@ impl OllamaConfig { / (response.eval_duration.unwrap_or(1) as f64) * 1_000_000_000f64; - if tps >= min_tps { + if tps >= self.min_tps { log::info!("Model {} passed the test with tps: {}", model, tps); return true; } @@ -221,7 +207,7 @@ impl OllamaConfig { "Ignoring model {}: tps too low ({:.3} < {:.3})", model, tps, - min_tps + self.min_tps ); } Err(e) => { diff --git a/workflows/src/openai.rs b/workflows/src/providers/openai.rs similarity index 94% rename from workflows/src/openai.rs rename to workflows/src/providers/openai.rs index 75e7e8b..4f6950c 100644 --- a/workflows/src/openai.rs +++ b/workflows/src/providers/openai.rs @@ -29,16 +29,15 @@ struct OpenAIModelsResponse { #[derive(Debug, Clone, Default)] pub struct OpenAIConfig { - /// List of external models that are picked by the user. - pub(crate) models: Vec, + pub(crate) api_key: Option, } impl OpenAIConfig { /// Looks at the environment variables for OpenAI API key. pub fn new() -> Self { - let api_key = std::env::var("OPENAI_API_KEY").ok(); - - Self { api_key } + Self { + api_key: std::env::var("OPENAI_API_KEY").ok(), + } } /// Check if requested models exist & are available in the OpenAI account. 
diff --git a/workflows/src/utils.rs b/workflows/src/utils.rs new file mode 100644 index 0000000..7831d6e --- /dev/null +++ b/workflows/src/utils.rs @@ -0,0 +1,34 @@ +/// Utility to parse comma-separated string values, mostly read from the environment. +/// +/// - Trims `"` from both ends at the start +/// - For each item, trims whitespace from both ends +pub fn split_comma_separated(input: Option) -> Vec { + match input { + Some(s) => s + .trim_matches('"') + .split(',') + .filter_map(|s| { + let s = s.trim().to_string(); + if s.is_empty() { + None + } else { + Some(s) + } + }) + .collect::>(), + None => vec![], + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_comma_separated() { + // should ignore whitespaces and `"` at both ends, and ignore empty items + let input = Some("\"a, b , c ,, \"".to_string()); + let expected = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + assert_eq!(split_comma_separated(input), expected); + } +} diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs new file mode 100644 index 0000000..1e24f10 --- /dev/null +++ b/workflows/tests/models_test.rs @@ -0,0 +1,14 @@ +use dkn_workflows::ModelConfig; +use eyre::Result; + +// #[tokio::test] +// async fn test_ollama() -> Result<()> {} + +// #[tokio::test] +// async fn test_openai() -> Result<()> {} + +// #[tokio::test] +// async fn test_empty() -> Result<()> { +// let mut model_config = ModelConfig::default(); +// model_config.check_services().await +// } From 96a5ef35f078a4a90f8b0e27adb4856e63f084f7 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 14:11:01 +0300 Subject: [PATCH 03/14] some fixes and better interface --- Cargo.lock | 1 + workflows/Cargo.toml | 3 +++ workflows/README.md | 12 +++++++++ workflows/src/config.rs | 41 +++++++++++++++--------------- workflows/src/lib.rs | 2 +- workflows/src/providers/ollama.rs | 4 +-- workflows/src/providers/openai.rs | 8 +++--- workflows/src/utils.rs | 38 +++++++++++++++------------- workflows/tests/models_test.rs | 42 ++++++++++++++++++++++++------- 9 files changed, 98 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4f061a..eb48b2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1029,6 +1029,7 @@ name = "dkn-workflows" version = "0.1.0" dependencies = [ "async-trait", + "env_logger 0.11.5", "eyre", "log", "ollama-workflows", diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index e374969..1de7c5e 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -21,3 +21,6 @@ eyre.workspace = true # ollama-rs is re-exported from ollama-workflows ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } + +[dev-dependencies] +env_logger.workspace = true diff --git a/workflows/README.md b/workflows/README.md index 1b9aa76..06a7da4 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -5,3 +5,15 @@ This crate handles the configurations of models to be used, and implements vario - **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models. - **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. While calculating TPS, there is also a timeout so that beyond that timeout the TPS is not even considered and the model becomes invalid. + +## Environment Variables + +DKN Workflows make use of several environment variables, respecting the providers. 
+ +- `OPENAI_API_KEY` is used for OpenAI requests +- `OLLAMA_HOST` is used to connect to Ollama server +- `OLLAMA_PORT` is used to connect to Ollama server +- `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not + +SERPER_API_KEY= +JINA_API_KEY= diff --git a/workflows/src/config.rs b/workflows/src/config.rs index 2afb480..c67c8c3 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -15,27 +15,29 @@ pub struct ModelConfig { } impl ModelConfig { - /// Parses Ollama-Workflows compatible models from a comma-separated values string. - pub fn new_from_csv(input: Option) -> Self { - let models_str = split_comma_separated(input); - - let models = models_str + pub fn new(models: Vec) -> Self { + let models_and_providers = models .into_iter() - .filter_map(|s| match Model::try_from(s) { - Ok(model) => Some((model.clone().into(), model)), - Err(e) => { - log::warn!("Error parsing model: {}", e); - None - } - }) + .map(|model| (model.clone().into(), model)) .collect::>(); Self { - models, + models: models_and_providers, openai: OpenAIConfig::new(), ollama: OllamaConfig::new(), } } + /// Parses Ollama-Workflows compatible models from a comma-separated values string. + pub fn new_from_csv(input: &str) -> Self { + let models_str = split_comma_separated(input); + + let models = models_str + .into_iter() + .filter_map(|s| Model::try_from(s).ok()) + .collect(); + + Self::new(models) + } /// Returns the models that belong to a given providers from the config. pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec { @@ -209,19 +211,18 @@ mod tests { #[test] fn test_csv_parser() { - let cfg = - ModelConfig::new_from_csv(Some("idontexist,i dont either,i332287648762".to_string())); + let cfg = ModelConfig::new_from_csv("idontexist,i dont either,i332287648762"); assert_eq!(cfg.models.len(), 0); - let cfg = ModelConfig::new_from_csv(Some( - "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(), - )); + let cfg = ModelConfig::new_from_csv( + "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl", + ); assert_eq!(cfg.models.len(), 2); } #[test] fn test_model_matching() { - let cfg = ModelConfig::new_from_csv(Some("gpt-4o,llama3.1:latest".to_string())); + let cfg = ModelConfig::new_from_csv("gpt-4o,llama3.1:latest"); assert_eq!( cfg.get_matching_model("openai".to_string()).unwrap().1, Model::GPT4o, @@ -250,7 +251,7 @@ mod tests { #[test] fn test_get_any_matching_model() { - let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string())); + let cfg = ModelConfig::new_from_csv("gpt-3.5-turbo,llama3.1:latest"); let result = cfg.get_any_matching_model(vec![ "i-dont-exist".to_string(), "llama3.1:latest".to_string(), diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs index 96f9856..1791330 100644 --- a/workflows/src/lib.rs +++ b/workflows/src/lib.rs @@ -2,7 +2,7 @@ mod utils; pub use utils::*; mod providers; -pub use providers::*; +use providers::*; mod config; pub use config::ModelConfig; diff --git a/workflows/src/providers/ollama.rs b/workflows/src/providers/ollama.rs index deaef6d..57d2ce5 100644 --- a/workflows/src/providers/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -227,7 +227,7 @@ mod tests { use ollama_workflows::{Executor, Model, ProgramMemory, Workflow}; #[tokio::test] - #[ignore = "run this manually"] + #[ignore = "requires Ollama"] async fn test_ollama_prompt() { let model = Model::default().to_string(); let ollama = Ollama::default(); @@ -246,7 +246,7 @@ mod 
tests { } #[tokio::test] - #[ignore = "run this manually"] + #[ignore = "requires Ollama"] async fn test_ollama_workflow() { let workflow = r#"{ "name": "Simple", diff --git a/workflows/src/providers/openai.rs b/workflows/src/providers/openai.rs index 4f6950c..b48c3e2 100644 --- a/workflows/src/providers/openai.rs +++ b/workflows/src/providers/openai.rs @@ -1,5 +1,6 @@ use eyre::{eyre, Context, Result}; use ollama_workflows::Model; +use reqwest::Client; use serde::Deserialize; const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models"; @@ -29,6 +30,7 @@ struct OpenAIModelsResponse { #[derive(Debug, Clone, Default)] pub struct OpenAIConfig { + /// API key, if available. pub(crate) api_key: Option, } @@ -50,17 +52,17 @@ impl OpenAIConfig { }; // fetch models - let client = reqwest::Client::new(); + let client = Client::new(); let request = client .get(OPENAI_MODELS_API) .header("Authorization", format!("Bearer {}", api_key)) .build() - .wrap_err("Failed to build request")?; + .wrap_err("failed to build request")?; let response = client .execute(request) .await - .wrap_err("Failed to send request")?; + .wrap_err("failed to send request")?; // parse response if response.status().is_client_error() { diff --git a/workflows/src/utils.rs b/workflows/src/utils.rs index 7831d6e..f88b293 100644 --- a/workflows/src/utils.rs +++ b/workflows/src/utils.rs @@ -2,22 +2,19 @@ /// /// - Trims `"` from both ends at the start /// - For each item, trims whitespace from both ends -pub fn split_comma_separated(input: Option) -> Vec { - match input { - Some(s) => s - .trim_matches('"') - .split(',') - .filter_map(|s| { - let s = s.trim().to_string(); - if s.is_empty() { - None - } else { - Some(s) - } - }) - .collect::>(), - None => vec![], - } +pub fn split_comma_separated(input: &str) -> Vec { + input + .trim_matches('"') + .split(',') + .filter_map(|s| { + let s = s.trim().to_string(); + if s.is_empty() { + None + } else { + Some(s) + } + }) + .collect::>() } #[cfg(test)] @@ -25,10 +22,15 @@ mod tests { use super::*; #[test] - fn test_split_comma_separated() { + fn test_example() { // should ignore whitespaces and `"` at both ends, and ignore empty items - let input = Some("\"a, b , c ,, \"".to_string()); + let input = "\"a, b , c ,, \""; let expected = vec!["a".to_string(), "b".to_string(), "c".to_string()]; assert_eq!(split_comma_separated(input), expected); } + + #[test] + fn test_empty() { + assert!(split_comma_separated(Default::default()).is_empty()); + } } diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs index 1e24f10..027457d 100644 --- a/workflows/tests/models_test.rs +++ b/workflows/tests/models_test.rs @@ -1,14 +1,38 @@ +use std::env; + use dkn_workflows::ModelConfig; use eyre::Result; +use ollama_workflows::Model; + +#[tokio::test] +#[ignore = "requires Ollama"] +async fn test_ollama() -> Result<()> { + env::set_var("RUST_LOG", "none,dkn_workflows=debug"); + let _ = env_logger::try_init(); + + let models = vec![Model::Phi3_5Mini]; + let mut model_config = ModelConfig::new(models); + + model_config.check_services().await +} + +#[tokio::test] +async fn test_openai() -> Result<()> { + env::set_var("RUST_LOG", "debug"); + let _ = env_logger::try_init(); + + let models = vec![Model::GPT4Turbo]; + let mut model_config = ModelConfig::new(models); + + model_config.check_services().await +} -// #[tokio::test] -// async fn test_ollama() -> Result<()> {} +#[tokio::test] +async fn test_empty() -> Result<()> { + let mut model_config = ModelConfig::new(vec![]); -// 
#[tokio::test] -// async fn test_openai() -> Result<()> {} + let result = model_config.check_services().await; + assert!(result.is_err()); -// #[tokio::test] -// async fn test_empty() -> Result<()> { -// let mut model_config = ModelConfig::default(); -// model_config.check_services().await -// } + Ok(()) +} From ea74fdfbe3e56cc613ed14451b785ad0211b8316 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 15:31:41 +0300 Subject: [PATCH 04/14] use workflows subcrate in compute --- Cargo.lock | 3 +- Cargo.toml | 1 + Makefile | 6 +- compute/Cargo.toml | 6 +- compute/src/config/mod.rs | 89 +------- compute/src/config/models.rs | 189 ----------------- compute/src/config/ollama.rs | 307 --------------------------- compute/src/config/openai.rs | 110 ---------- compute/src/handlers/pingpong.rs | 2 +- compute/src/handlers/workflow.rs | 8 +- compute/src/main.rs | 4 +- compute/src/utils/available_nodes.rs | 6 +- compute/src/utils/mod.rs | 20 -- workflows/.env.example | 13 ++ workflows/Cargo.toml | 1 + workflows/src/config.rs | 4 +- workflows/src/lib.rs | 7 +- workflows/src/providers/ollama.rs | 4 +- workflows/src/utils.rs | 10 +- workflows/tests/models_test.rs | 29 ++- 20 files changed, 75 insertions(+), 744 deletions(-) delete mode 100644 compute/src/config/models.rs delete mode 100644 compute/src/config/ollama.rs delete mode 100644 compute/src/config/openai.rs create mode 100644 workflows/.env.example diff --git a/Cargo.lock b/Cargo.lock index eb48b2e..0a3efa6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -985,6 +985,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "dkn-p2p", + "dkn-workflows", "dotenvy", "ecies", "env_logger 0.11.5", @@ -994,7 +995,6 @@ dependencies = [ "hex-literal", "libsecp256k1", "log", - "ollama-workflows", "openssl", "parking_lot", "port_check", @@ -1029,6 +1029,7 @@ name = "dkn-workflows" version = "0.1.0" dependencies = [ "async-trait", + "dotenvy", "env_logger 0.11.5", "eyre", "log", diff --git a/Cargo.toml b/Cargo.toml index a453253..7dba9b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ parking_lot = "0.12.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" async-trait = "0.1.81" +dotenvy = "0.15.7" reqwest = "0.12.5" rand = "0.8.5" env_logger = "0.11.3" diff --git a/Makefile b/Makefile index 5ec333b..e46a10e 100644 --- a/Makefile +++ b/Makefile @@ -7,15 +7,15 @@ endif ############################################################################### .PHONY: launch # | Run with INFO logs in release mode launch: - RUST_LOG=none,dkn_compute=info cargo run --release + RUST_LOG=none,dkn_compute=info,dkn_workflows=info cargo run --release .PHONY: run # | Run with INFO logs run: - RUST_LOG=none,dkn_compute=info cargo run + RUST_LOG=none,dkn_compute=info,dkn_workflows=info cargo run .PHONY: debug # | Run with DEBUG logs with INFO log-level workflows debug: - RUST_LOG=warn,dkn_compute=debug,ollama_workflows=info cargo run + RUST_LOG=warn,dkn_compute=debug,,dkn_workflows=debug,ollama_workflows=info cargo run .PHONY: trace # | Run with TRACE logs trace: diff --git a/compute/Cargo.toml b/compute/Cargo.toml index eab1f98..db69e12 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -52,11 +52,9 @@ sha2 = "0.10.8" sha3 = "0.10.8" fastbloom-rs = "0.5.9" -# workflows -ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } - -# p2p +# dria subcrates dkn-p2p = { path = "../p2p" } +dkn-workflows = { path = "../workflows" } # Vendor OpenSSL so that its easier to build cross-platform packages [dependencies.openssl] diff 
--git a/compute/src/config/mod.rs b/compute/src/config/mod.rs index 8bfe346..e111d46 100644 --- a/compute/src/config/mod.rs +++ b/compute/src/config/mod.rs @@ -1,23 +1,10 @@ -mod models; -mod ollama; -mod openai; - use crate::utils::{address_in_use, crypto::to_address}; use dkn_p2p::libp2p::Multiaddr; +use dkn_workflows::ModelConfig; use eyre::{eyre, Result}; use libsecp256k1::{PublicKey, SecretKey}; -use models::ModelConfig; -use ollama::OllamaConfig; -use ollama_workflows::ModelProvider; -use openai::OpenAIConfig; - -use std::{env, str::FromStr, time::Duration}; - -/// Timeout duration for checking model performance during a generation. -const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80); -/// Minimum tokens per second (TPS) for checking model performance during a generation. -const CHECK_TPS: f64 = 15.0; +use std::{env, str::FromStr}; #[derive(Debug, Clone)] pub struct DriaComputeNodeConfig { @@ -33,11 +20,6 @@ pub struct DriaComputeNodeConfig { pub p2p_listen_addr: Multiaddr, /// Available LLM models & providers for the node. pub model_config: ModelConfig, - /// Even if Ollama is not used, we store the host & port here. - /// If Ollama is used, this config will be respected during its instantiations. - pub ollama_config: OllamaConfig, - /// OpenAI API key & its service check implementation. - pub openai_config: OpenAIConfig, } /// The default P2P network listen address. @@ -97,7 +79,7 @@ impl DriaComputeNodeConfig { let address = to_address(&public_key); log::info!("Node Address: 0x{}", hex::encode(address)); - let model_config = ModelConfig::new_from_csv(env::var("DKN_MODELS").ok()); + let model_config = ModelConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default()); #[cfg(not(test))] if model_config.models.is_empty() { log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."); @@ -118,72 +100,11 @@ impl DriaComputeNodeConfig { address, model_config, p2p_listen_addr, - ollama_config: OllamaConfig::new(), - openai_config: OpenAIConfig::new(), - } - } - - /// Check if the required compute services are running. - /// This has several steps: - /// - /// - If Ollama models are used, hardcoded models are checked locally, and for - /// external models, the workflow is tested with a simple task with timeout. - /// - If OpenAI models are used, the API key is checked and the models are tested - /// - /// If both type of models are used, both services are checked. - /// In the end, bad models are filtered out and we simply check if we are left if any valid models at all. - /// If not, an error is returned. 
- pub async fn check_services(&mut self) -> Result<()> { - log::info!("Checking configured services."); - - // TODO: can refactor (provider, model) logic here - let unique_providers = self.model_config.get_providers(); - - let mut good_models = Vec::new(); - - // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them) - if unique_providers.contains(&ModelProvider::Ollama) { - let ollama_models = self - .model_config - .get_models_for_provider(ModelProvider::Ollama); - - // ensure that the models are pulled / pull them if not - let good_ollama_models = self - .ollama_config - .check(ollama_models, CHECK_TIMEOUT_DURATION, CHECK_TPS) - .await?; - good_models.extend( - good_ollama_models - .into_iter() - .map(|m| (ModelProvider::Ollama, m)), - ); - } - - // if OpenAI is a provider, check that the API key is set - if unique_providers.contains(&ModelProvider::OpenAI) { - let openai_models = self - .model_config - .get_models_for_provider(ModelProvider::OpenAI); - - let good_openai_models = self.openai_config.check(openai_models).await?; - good_models.extend( - good_openai_models - .into_iter() - .map(|m| (ModelProvider::OpenAI, m)), - ); - } - - // update good models - if good_models.is_empty() { - Err(eyre!("No good models found, please check logs for errors.")) - } else { - self.model_config.models = good_models; - Ok(()) } } - // ensure that listen address is free - pub fn check_address_in_use(&self) -> Result<()> { + /// Asserts that the configured listen address is free. + pub fn assert_address_not_in_use(&self) -> Result<()> { if address_in_use(&self.p2p_listen_addr) { return Err(eyre!( "Listen address {} is already in use.", diff --git a/compute/src/config/models.rs b/compute/src/config/models.rs deleted file mode 100644 index 943a8dc..0000000 --- a/compute/src/config/models.rs +++ /dev/null @@ -1,189 +0,0 @@ -use crate::utils::split_comma_separated; -use eyre::{eyre, Result}; -use ollama_workflows::{Model, ModelProvider}; -use rand::seq::IteratorRandom; // provides Vec<_>.choose - -#[derive(Debug, Clone)] -pub struct ModelConfig { - pub(crate) models: Vec<(ModelProvider, Model)>, -} - -impl std::fmt::Display for ModelConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let models_str = self - .models - .iter() - .map(|(provider, model)| format!("{:?}:{}", provider, model)) - .collect::>() - .join(","); - write!(f, "{}", models_str) - } -} - -impl ModelConfig { - /// Parses Ollama-Workflows compatible models from a comma-separated values string. - pub fn new_from_csv(input: Option) -> Self { - let models_str = split_comma_separated(input); - - let models = models_str - .into_iter() - .filter_map(|s| match Model::try_from(s) { - Ok(model) => Some((model.clone().into(), model)), - Err(e) => { - log::warn!("Error parsing model: {}", e); - None - } - }) - .collect::>(); - - Self { models } - } - - /// Returns the models that belong to a given providers from the config. - pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec { - self.models - .iter() - .filter_map(|(p, m)| { - if *p == provider { - Some(m.clone()) - } else { - None - } - }) - .collect() - } - - /// Given a raw model name or provider (as a string), returns the first matching model & provider. - /// - /// If this is a model and is supported by this node, it is returned directly. - /// If this is a provider, the first matching model in the node config is returned. - /// - /// If there are no matching models with this logic, an error is returned. 
- pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> { - if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) { - // this is a valid provider, return the first matching model in the config - self.models - .iter() - .find(|(p, _)| *p == provider) - .ok_or(eyre!( - "Provider {} is not supported by this node.", - provider - )) - .cloned() - } else if let Ok(model) = Model::try_from(model_or_provider.clone()) { - // this is a valid model, return it if it is supported by the node - self.models - .iter() - .find(|(_, m)| *m == model) - .ok_or(eyre!("Model {} is not supported by this node.", model)) - .cloned() - } else { - // this is neither a valid provider or model for this node - Err(eyre!( - "Given string '{}' is neither a model nor provider.", - model_or_provider - )) - } - } - - /// From a list of model or provider names, return a random matching model & provider. - pub fn get_any_matching_model( - &self, - list_model_or_provider: Vec, - ) -> Result<(ModelProvider, Model)> { - // filter models w.r.t supported ones - let matching_models = list_model_or_provider - .into_iter() - .filter_map(|model_or_provider| { - let result = self.get_matching_model(model_or_provider); - match result { - Ok(result) => Some(result), - Err(e) => { - log::debug!("Ignoring model: {}", e); - None - } - } - }) - .collect::>(); - - // choose random model - matching_models - .into_iter() - .choose(&mut rand::thread_rng()) - .ok_or(eyre!("No matching models found.")) - } - - /// Returns the list of unique providers in the config. - pub fn get_providers(&self) -> Vec { - self.models - .iter() - .fold(Vec::new(), |mut unique, (provider, _)| { - if !unique.contains(provider) { - unique.push(provider.clone()); - } - unique - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_csv_parser() { - let cfg = - ModelConfig::new_from_csv(Some("idontexist,i dont either,i332287648762".to_string())); - assert_eq!(cfg.models.len(), 0); - - let cfg = ModelConfig::new_from_csv(Some( - "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(), - )); - assert_eq!(cfg.models.len(), 2); - } - - #[test] - fn test_model_matching() { - let cfg = ModelConfig::new_from_csv(Some("gpt-4o,llama3.1:latest".to_string())); - assert_eq!( - cfg.get_matching_model("openai".to_string()).unwrap().1, - Model::GPT4o, - "Should find existing model" - ); - - assert_eq!( - cfg.get_matching_model("llama3.1:latest".to_string()) - .unwrap() - .1, - Model::Llama3_1_8B, - "Should find existing model" - ); - - assert!( - cfg.get_matching_model("gpt-4o-mini".to_string()).is_err(), - "Should not find anything for unsupported model" - ); - - assert!( - cfg.get_matching_model("praise the model".to_string()) - .is_err(), - "Should not find anything for inexisting model" - ); - } - - #[test] - fn test_get_any_matching_model() { - let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string())); - let result = cfg.get_any_matching_model(vec![ - "i-dont-exist".to_string(), - "llama3.1:latest".to_string(), - "gpt-4o".to_string(), - "ollama".to_string(), - ]); - assert_eq!( - result.unwrap().1, - Model::Llama3_1_8B, - "Should find existing model" - ); - } -} diff --git a/compute/src/config/ollama.rs b/compute/src/config/ollama.rs deleted file mode 100644 index 53ce2a3..0000000 --- a/compute/src/config/ollama.rs +++ /dev/null @@ -1,307 +0,0 @@ -use eyre::{eyre, Context, Result}; -use ollama_workflows::{ - ollama_rs::{ - generation::{ - 
completion::request::GenerationRequest, - embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, - options::GenerationOptions, - }, - Ollama, - }, - Model, -}; -use std::time::Duration; - -const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1"; -const DEFAULT_OLLAMA_PORT: u16 = 11434; - -/// Some models such as small embedding models, are hardcoded into the node. -const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"]; - -/// Prompt to be used to see Ollama performance. -const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; - -/// Ollama-specific configurations. -#[derive(Debug, Clone)] -pub struct OllamaConfig { - /// Host, usually `http://127.0.0.1`. - pub(crate) host: String, - /// Port, usually `11434`. - pub(crate) port: u16, - /// List of hardcoded models that are internally used by Ollama workflows. - hardcoded_models: Vec, - /// Whether to automatically pull models from Ollama. - /// This is useful for CI/CD workflows. - auto_pull: bool, -} - -impl Default for OllamaConfig { - fn default() -> Self { - Self { - host: DEFAULT_OLLAMA_HOST.to_string(), - port: DEFAULT_OLLAMA_PORT, - hardcoded_models: HARDCODED_MODELS - .into_iter() - .map(|s| s.to_string()) - .collect(), - auto_pull: false, - } - } -} -impl OllamaConfig { - /// Looks at the environment variables for Ollama host and port. - /// - /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`. - pub fn new() -> Self { - let host = std::env::var("OLLAMA_HOST") - .map(|h| h.trim_matches('"').to_string()) - .unwrap_or(DEFAULT_OLLAMA_HOST.to_string()); - let port = std::env::var("OLLAMA_PORT") - .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent)) - .unwrap_or(DEFAULT_OLLAMA_PORT); - - // Ollama workflows may require specific models to be loaded regardless of the choices - let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect(); - - // auto-pull, its true by default - let auto_pull = std::env::var("OLLAMA_AUTO_PULL") - .map(|s| s == "true") - .unwrap_or(true); - - Self { - host, - port, - hardcoded_models, - auto_pull, - } - } - - /// Check if requested models exist in Ollama, and then tests them using a workflow. 
- pub async fn check( - &self, - external_models: Vec, - timeout: Duration, - min_tps: f64, - ) -> Result> { - log::info!( - "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)", - if self.auto_pull { "on" } else { "off" }, - timeout.as_secs() - ); - - let ollama = Ollama::new(&self.host, self.port); - - // fetch local models - let local_models = match ollama.list_local_models().await { - Ok(models) => models.into_iter().map(|m| m.name).collect::>(), - Err(e) => { - return { - log::error!("Could not fetch local models from Ollama, is it online?"); - Err(e.into()) - } - } - }; - log::info!("Found local Ollama models: {:#?}", local_models); - - // check hardcoded models & pull them if available - // these are not used directly by the user, but are needed for the workflows - log::debug!("Checking hardcoded models: {:#?}", self.hardcoded_models); - // only check if model is contained in local_models - // we dont check workflows for hardcoded models - for model in &self.hardcoded_models { - if !local_models.contains(model) { - self.try_pull(&ollama, model.to_owned()) - .await - .wrap_err("Could not pull model")?; - } - } - - // check external models & pull them if available - // and also run a test workflow for them - let mut good_models = Vec::new(); - for model in external_models { - if !local_models.contains(&model.to_string()) { - self.try_pull(&ollama, model.to_string()) - .await - .wrap_err("Could not pull model")?; - } - - if self - .test_performance(&ollama, &model, timeout, min_tps) - .await - { - good_models.push(model); - } - } - - log::info!( - "Ollama checks are finished, using models: {:#?}", - good_models - ); - Ok(good_models) - } - - /// Pulls a model if `auto_pull` exists, otherwise returns an error. - async fn try_pull(&self, ollama: &Ollama, model: String) -> Result<()> { - log::warn!("Model {} not found in Ollama", model); - if self.auto_pull { - // if auto-pull is enabled, pull the model - log::info!( - "Downloading missing model {} (this may take a while)", - model - ); - let status = ollama.pull_model(model, false).await?; - log::debug!("Pulled model with Ollama, final status: {:#?}", status); - Ok(()) - } else { - // otherwise, give error - log::error!("Please download missing model with: ollama pull {}", model); - log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically."); - Err(eyre!("Required model not pulled in Ollama.")) - } - } - - /// Runs a small workflow to test Ollama Workflows. - /// - /// This is to see if a given system can execute Ollama workflows for their chosen models, - /// e.g. if they have enough RAM/CPU and such. 
- pub async fn test_performance( - &self, - ollama: &Ollama, - model: &Model, - timeout: Duration, - min_tps: f64, - ) -> bool { - log::info!("Testing model {}", model); - - // first generate a dummy embedding to load the model into memory (warm-up) - let request = GenerateEmbeddingsRequest::new( - model.to_string(), - EmbeddingsInput::Single("embedme".into()), - ); - if let Err(err) = ollama.generate_embeddings(request).await { - log::error!("Failed to generate embedding for model {}: {}", model, err); - return false; - }; - - let mut generation_request = - GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string()); - - // FIXME: temporary workaround, can take num threads from outside - if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") { - generation_request = generation_request.options( - GenerationOptions::default().num_thread( - num_thread - .parse() - .expect("num threads should be a positive integer"), - ), - ); - } - - // then, run a sample generation with timeout and measure tps - tokio::select! { - _ = tokio::time::sleep(timeout) => { - log::warn!("Ignoring model {}: Workflow timed out", model); - }, - result = ollama.generate(generation_request) => { - match result { - Ok(response) => { - let tps = (response.eval_count.unwrap_or_default() as f64) - / (response.eval_duration.unwrap_or(1) as f64) - * 1_000_000_000f64; - - if tps >= min_tps { - log::info!("Model {} passed the test with tps: {}", model, tps); - return true; - } - - log::warn!( - "Ignoring model {}: tps too low ({:.3} < {:.3})", - model, - tps, - min_tps - ); - } - Err(e) => { - log::warn!("Ignoring model {}: Workflow failed with error {}", model, e); - } - } - } - }; - - false - } -} - -#[cfg(test)] -mod tests { - use ollama_workflows::ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; - use ollama_workflows::{Executor, Model, ProgramMemory, Workflow}; - - #[tokio::test] - #[ignore = "run this manually"] - async fn test_ollama_prompt() { - let model = Model::default().to_string(); - let ollama = Ollama::default(); - ollama.pull_model(model.clone(), false).await.unwrap(); - let prompt = "The sky appears blue during the day because of a process called scattering. \ - When sunlight enters the Earth's atmosphere, it collides with air molecules such as oxygen and nitrogen. \ - These collisions cause some of the light to be absorbed or reflected, which makes the colors we see appear more vivid and vibrant. \ - Blue is one of the brightest colors that is scattered the most by the atmosphere, making it visible to our eyes during the day. 
\ - What may be the question this answer?".to_string(); - - let response = ollama - .generate(GenerationRequest::new(model, prompt.clone())) - .await - .expect("Should generate response"); - println!("Prompt: {}\n\nResponse:{}", prompt, response.response); - } - - #[tokio::test] - #[ignore = "run this manually"] - async fn test_ollama_workflow() { - let workflow = r#"{ - "name": "Simple", - "description": "This is a simple workflow", - "config": { - "max_steps": 5, - "max_time": 100, - }, - "tasks":[ - { - "id": "A", - "name": "Random Poem", - "description": "Writes a poem about Kapadokya.", - "prompt": "Please write a poem about Kapadokya.", - "operator": "generation", - "outputs": [ - { - "type": "write", - "key": "final_result", - "value": "__result" - } - ] - }, - { - "id": "__end", - "name": "end", - "description": "End of the task", - "prompt": "End of the task", - "operator": "end", - } - ], - "steps":[ - { - "source":"A", - "target":"end" - } - ] - }"#; - let workflow: Workflow = serde_json::from_str(workflow).unwrap(); - let exe = Executor::new(Model::default()); - let mut memory = ProgramMemory::new(); - - let result = exe.execute(None, workflow, &mut memory).await; - println!("Result: {}", result.unwrap()); - } -} diff --git a/compute/src/config/openai.rs b/compute/src/config/openai.rs deleted file mode 100644 index 31a9feb..0000000 --- a/compute/src/config/openai.rs +++ /dev/null @@ -1,110 +0,0 @@ -use eyre::{eyre, Context, Result}; -use ollama_workflows::Model; -use serde::Deserialize; - -const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models"; - -/// [Model](https://platform.openai.com/docs/api-reference/models/object) API object. -#[derive(Debug, Clone, Deserialize)] -struct OpenAIModel { - /// The model identifier, which can be referenced in the API endpoints. - id: String, - /// The Unix timestamp (in seconds) when the model was created. - #[allow(unused)] - created: u64, - /// The object type, which is always "model". - #[allow(unused)] - object: String, - /// The organization that owns the model. - #[allow(unused)] - owned_by: String, -} - -#[derive(Debug, Clone, Deserialize)] -struct OpenAIModelsResponse { - data: Vec, - #[allow(unused)] - object: String, -} - -#[derive(Debug, Clone, Default)] -pub struct OpenAIConfig { - pub(crate) api_key: Option, -} - -impl OpenAIConfig { - /// Looks at the environment variables for OpenAI API key. - pub fn new() -> Self { - let api_key = std::env::var("OPENAI_API_KEY").ok(); - - Self { api_key } - } - - /// Check if requested models exist & are available in the OpenAI account. 
- pub async fn check(&self, models: Vec) -> Result> { - log::info!("Checking OpenAI requirements"); - - // check API key - let Some(api_key) = &self.api_key else { - return Err(eyre!("OpenAI API key not found")); - }; - - // fetch models - let client = reqwest::Client::new(); - let request = client - .get(OPENAI_MODELS_API) - .header("Authorization", format!("Bearer {}", api_key)) - .build() - .wrap_err("Failed to build request")?; - - let response = client - .execute(request) - .await - .wrap_err("Failed to send request")?; - - // parse response - if response.status().is_client_error() { - return Err(eyre!( - "Failed to fetch OpenAI models:\n{}", - response.text().await.unwrap_or_default() - )); - } - let openai_models = response.json::().await?; - - // check if models exist and select those that are available - let mut available_models = Vec::new(); - for requested_model in models { - if !openai_models - .data - .iter() - .any(|m| m.id == requested_model.to_string()) - { - log::warn!( - "Model {} not found in your OpenAI account, ignoring it.", - requested_model - ); - } else { - available_models.push(requested_model); - } - } - - log::info!( - "OpenAI checks are finished, using models: {:#?}", - available_models - ); - Ok(available_models) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "requires OpenAI API key"] - async fn test_openai_check() { - let config = OpenAIConfig::new(); - let res = config.check(vec![]).await; - println!("Result: {}", res.unwrap_err()); - } -} diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index f328865..2982f4b 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -5,8 +5,8 @@ use crate::{ }; use async_trait::async_trait; use dkn_p2p::libp2p::gossipsub::MessageAcceptance; +use dkn_workflows::{Model, ModelProvider}; use eyre::{Context, Result}; -use ollama_workflows::{Model, ModelProvider}; use serde::{Deserialize, Serialize}; pub struct PingpongHandler; diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index c15abbc..92fca8e 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use dkn_p2p::libp2p::gossipsub::MessageAcceptance; +use dkn_workflows::ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow}; use eyre::{eyre, Context, Result}; use libsecp256k1::PublicKey; -use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow}; use serde::Deserialize; use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload}; @@ -74,7 +74,11 @@ impl ComputeHandler for WorkflowHandler { // prepare workflow executor let executor = if model_provider == ModelProvider::Ollama { - Executor::new_at(model, &config.ollama_config.host, config.ollama_config.port) + Executor::new_at( + model, + &config.model_config.ollama.host, + config.model_config.ollama.port, + ) } else { Executor::new(model) }; diff --git a/compute/src/main.rs b/compute/src/main.rs index a2fd86d..b5b882f 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -46,7 +46,7 @@ async fn main() -> Result<()> { // create configurations & check required services let config = DriaComputeNodeConfig::new(); - config.check_address_in_use()?; + config.assert_address_not_in_use()?; let service_check_token = token.clone(); let mut config_clone = config.clone(); let service_check_handle = tokio::spawn(async move { @@ -54,7 +54,7 @@ async fn main() -> 
Result<()> { _ = service_check_token.cancelled() => { log::info!("Service check cancelled."); } - result = config_clone.check_services() => { + result = config_clone.model_config.check_services() => { if let Err(err) = result { log::error!("Error checking services: {:?}", err); panic!("Service check failed.") diff --git a/compute/src/utils/available_nodes.rs b/compute/src/utils/available_nodes.rs index dada7de..50ede8a 100644 --- a/compute/src/utils/available_nodes.rs +++ b/compute/src/utils/available_nodes.rs @@ -2,7 +2,7 @@ use dkn_p2p::libp2p::{Multiaddr, PeerId}; use eyre::Result; use std::{env, fmt::Debug, str::FromStr}; -use crate::utils::split_comma_separated; +use dkn_workflows::split_csv_line; /// Static bootstrap nodes for the Kademlia DHT bootstrap step. const STATIC_BOOTSTRAP_NODES: [&str; 4] = [ @@ -48,7 +48,7 @@ impl AvailableNodes { /// - `DRIA_RELAY_NODES`: comma-separated list of relay nodes pub fn new_from_env() -> Self { // parse bootstrap nodes - let bootstrap_nodes = split_comma_separated(env::var("DKN_BOOTSTRAP_NODES").ok()); + let bootstrap_nodes = split_csv_line(&env::var("DKN_BOOTSTRAP_NODES").unwrap_or_default()); if bootstrap_nodes.is_empty() { log::debug!("No additional bootstrap nodes provided."); } else { @@ -56,7 +56,7 @@ impl AvailableNodes { } // parse relay nodes - let relay_nodes = split_comma_separated(env::var("DKN_RELAY_NODES").ok()); + let relay_nodes = split_csv_line(&env::var("DKN_RELAY_NODES").unwrap_or_default()); if relay_nodes.is_empty() { log::debug!("No additional relay nodes provided."); } else { diff --git a/compute/src/utils/mod.rs b/compute/src/utils/mod.rs index f24ffc5..8721f91 100644 --- a/compute/src/utils/mod.rs +++ b/compute/src/utils/mod.rs @@ -55,23 +55,3 @@ pub fn address_in_use(addr: &Multiaddr) -> bool { false }) } - -/// Utility to parse comma-separated string values, mostly read from the environment. 
-/// - Trims `"` from both ends at the start -/// - For each item, trims whitespace from both ends -pub fn split_comma_separated(input: Option) -> Vec { - match input { - Some(s) => s - .trim_matches('"') - .split(',') - .filter_map(|s| { - if s.is_empty() { - None - } else { - Some(s.trim().to_string()) - } - }) - .collect::>(), - None => vec![], - } -} diff --git a/workflows/.env.example b/workflows/.env.example new file mode 100644 index 0000000..eec9599 --- /dev/null +++ b/workflows/.env.example @@ -0,0 +1,13 @@ +## Open AI (if used, required) ## +OPENAI_API_KEY= + +## Ollama (if used, optional) ## +OLLAMA_HOST=http://localhost +OLLAMA_PORT=11434 +# if "true", automatically pull models from Ollama +# if "false", you have to download manually +OLLAMA_AUTO_PULL=true + +## Additional Services (optional) +SERPER_API_KEY= +JINA_API_KEY= diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index 1de7c5e..cbcc82e 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -24,3 +24,4 @@ ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } [dev-dependencies] env_logger.workspace = true +dotenvy.workspace = true diff --git a/workflows/src/config.rs b/workflows/src/config.rs index c67c8c3..6b44aae 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -1,4 +1,4 @@ -use crate::{split_comma_separated, OllamaConfig, OpenAIConfig}; +use crate::{split_csv_line, OllamaConfig, OpenAIConfig}; use eyre::{eyre, Result}; use ollama_workflows::{Model, ModelProvider}; use rand::seq::IteratorRandom; // provides Vec<_>.choose @@ -29,7 +29,7 @@ impl ModelConfig { } /// Parses Ollama-Workflows compatible models from a comma-separated values string. pub fn new_from_csv(input: &str) -> Self { - let models_str = split_comma_separated(input); + let models_str = split_csv_line(input); let models = models_str .into_iter() diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs index 1791330..4c258ea 100644 --- a/workflows/src/lib.rs +++ b/workflows/src/lib.rs @@ -1,8 +1,11 @@ mod utils; -pub use utils::*; +pub use utils::split_csv_line; mod providers; -use providers::*; +use providers::{OllamaConfig, OpenAIConfig}; mod config; pub use config::ModelConfig; + +pub use ollama_workflows; +pub use ollama_workflows::{Model, ModelProvider}; diff --git a/workflows/src/providers/ollama.rs b/workflows/src/providers/ollama.rs index 57d2ce5..41c806b 100644 --- a/workflows/src/providers/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -31,9 +31,9 @@ const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; #[derive(Debug, Clone)] pub struct OllamaConfig { /// Host, usually `http://127.0.0.1`. - host: String, + pub host: String, /// Port, usually `11434`. - port: u16, + pub port: u16, /// Whether to automatically pull models from Ollama. /// This is useful for CI/CD workflows. auto_pull: bool, diff --git a/workflows/src/utils.rs b/workflows/src/utils.rs index f88b293..030f0e2 100644 --- a/workflows/src/utils.rs +++ b/workflows/src/utils.rs @@ -1,8 +1,8 @@ -/// Utility to parse comma-separated string values, mostly read from the environment. +/// Utility to parse comma-separated string value line. 
/// -/// - Trims `"` from both ends at the start +/// - Trims `"` from both ends for the input /// - For each item, trims whitespace from both ends -pub fn split_comma_separated(input: &str) -> Vec { +pub fn split_csv_line(input: &str) -> Vec { input .trim_matches('"') .split(',') @@ -26,11 +26,11 @@ mod tests { // should ignore whitespaces and `"` at both ends, and ignore empty items let input = "\"a, b , c ,, \""; let expected = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - assert_eq!(split_comma_separated(input), expected); + assert_eq!(split_csv_line(input), expected); } #[test] fn test_empty() { - assert!(split_comma_separated(Default::default()).is_empty()); + assert!(split_csv_line(Default::default()).is_empty()); } } diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs index 027457d..d8e9c7b 100644 --- a/workflows/tests/models_test.rs +++ b/workflows/tests/models_test.rs @@ -1,30 +1,45 @@ use std::env; -use dkn_workflows::ModelConfig; +use dkn_workflows::{ModelConfig, ModelProvider}; use eyre::Result; use ollama_workflows::Model; +const LOG_LEVEL: &str = "none,dkn_workflows=debug"; + #[tokio::test] #[ignore = "requires Ollama"] -async fn test_ollama() -> Result<()> { - env::set_var("RUST_LOG", "none,dkn_workflows=debug"); +async fn test_ollama_check() -> Result<()> { + env::set_var("RUST_LOG", LOG_LEVEL); let _ = env_logger::try_init(); let models = vec![Model::Phi3_5Mini]; let mut model_config = ModelConfig::new(models); + model_config.check_services().await?; + + assert_eq!( + model_config.models[0], + (ModelProvider::Ollama, Model::Phi3_5Mini) + ); - model_config.check_services().await + Ok(()) } #[tokio::test] -async fn test_openai() -> Result<()> { - env::set_var("RUST_LOG", "debug"); +#[ignore = "requires OpenAI"] +async fn test_openai_check() -> Result<()> { + let _ = dotenvy::dotenv(); // read api key + env::set_var("RUST_LOG", LOG_LEVEL); let _ = env_logger::try_init(); let models = vec![Model::GPT4Turbo]; let mut model_config = ModelConfig::new(models); + model_config.check_services().await?; - model_config.check_services().await + assert_eq!( + model_config.models[0], + (ModelProvider::OpenAI, Model::GPT4Turbo) + ); + Ok(()) } #[tokio::test] From 16ae371b1692000fc36ba0287d95d50520d146dc Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 15:56:00 +0300 Subject: [PATCH 05/14] slight renamings --- Cargo.toml | 5 ++++- compute/Cargo.toml | 5 +++-- compute/src/{config/mod.rs => config.rs} | 11 ++++++----- compute/src/handlers/pingpong.rs | 2 +- compute/src/handlers/workflow.rs | 8 +++----- compute/src/main.rs | 2 +- compute/src/node.rs | 6 +++--- p2p/src/client.rs | 7 +++---- p2p/src/lib.rs | 5 ++--- p2p/src/{data_transform.rs => transform.rs} | 0 workflows/src/config.rs | 21 +++++++++++---------- workflows/src/lib.rs | 2 +- workflows/src/providers/openai.rs | 3 ++- workflows/tests/models_test.rs | 11 +++++------ 14 files changed, 45 insertions(+), 43 deletions(-) rename compute/src/{config/mod.rs => config.rs} (94%) rename p2p/src/{data_transform.rs => transform.rs} (100%) diff --git a/Cargo.toml b/Cargo.toml index 7dba9b1..dd0c36b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,12 @@ parking_lot = "0.12.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" async-trait = "0.1.81" -dotenvy = "0.15.7" reqwest = "0.12.5" + +dotenvy = "0.15.7" + rand = "0.8.5" + env_logger = "0.11.3" log = "0.4.21" eyre = "0.6.12" diff --git a/compute/Cargo.toml b/compute/Cargo.toml index db69e12..77d62ae 100644 --- 
a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -28,17 +28,18 @@ async-trait = "0.1.81" reqwest = "0.12.5" # utilities -dotenvy = "0.15.7" +dotenvy.workspace = true base64 = "0.22.0" hex = "0.4.3" hex-literal = "0.4.1" url = "2.5.0" urlencoding = "2.1.3" uuid = { version = "1.8.0", features = ["v4"] } -rand = "0.8.5" + port_check = "0.2.1" # logging & errors +rand.workspace = true env_logger.workspace = true log.workspace = true eyre.workspace = true diff --git a/compute/src/config/mod.rs b/compute/src/config.rs similarity index 94% rename from compute/src/config/mod.rs rename to compute/src/config.rs index e111d46..c5e54de 100644 --- a/compute/src/config/mod.rs +++ b/compute/src/config.rs @@ -1,6 +1,6 @@ use crate::utils::{address_in_use, crypto::to_address}; use dkn_p2p::libp2p::Multiaddr; -use dkn_workflows::ModelConfig; +use dkn_workflows::DriaWorkflowsConfig; use eyre::{eyre, Result}; use libsecp256k1::{PublicKey, SecretKey}; @@ -18,8 +18,8 @@ pub struct DriaComputeNodeConfig { pub admin_public_key: PublicKey, /// P2P listen address, e.g. `/ip4/0.0.0.0/tcp/4001`. pub p2p_listen_addr: Multiaddr, - /// Available LLM models & providers for the node. - pub model_config: ModelConfig, + /// Workflow configurations, e.g. models and providers. + pub workflows: DriaWorkflowsConfig, } /// The default P2P network listen address. @@ -79,7 +79,8 @@ impl DriaComputeNodeConfig { let address = to_address(&public_key); log::info!("Node Address: 0x{}", hex::encode(address)); - let model_config = ModelConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default()); + let model_config = + DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default()); #[cfg(not(test))] if model_config.models.is_empty() { log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."); @@ -98,7 +99,7 @@ impl DriaComputeNodeConfig { secret_key, public_key, address, - model_config, + workflows: model_config, p2p_listen_addr, } } diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 2982f4b..f4fa7d7 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -54,7 +54,7 @@ impl ComputeHandler for PingpongHandler { // respond let response_body = PingpongResponse { uuid: pingpong.uuid.clone(), - models: node.config.model_config.models.clone(), + models: node.config.workflows.models.clone(), timestamp: get_current_time_nanos(), }; diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index 92fca8e..4486cc6 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -66,9 +66,7 @@ impl ComputeHandler for WorkflowHandler { } // read model / provider from the task - let (model_provider, model) = config - .model_config - .get_any_matching_model(task.input.model)?; + let (model_provider, model) = config.workflows.get_any_matching_model(task.input.model)?; let model_name = model.to_string(); // get model name, we will pass it in payload log::info!("Using model {} for task {}", model_name, task.task_id); @@ -76,8 +74,8 @@ impl ComputeHandler for WorkflowHandler { let executor = if model_provider == ModelProvider::Ollama { Executor::new_at( model, - &config.model_config.ollama.host, - config.model_config.ollama.port, + &config.workflows.ollama.host, + config.workflows.ollama.port, ) } else { Executor::new(model) diff --git a/compute/src/main.rs b/compute/src/main.rs index b5b882f..4b424a4 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -54,7 
+54,7 @@ async fn main() -> Result<()> { _ = service_check_token.cancelled() => { log::info!("Service check cancelled."); } - result = config_clone.model_config.check_services() => { + result = config_clone.workflows.check_services() => { if let Err(err) = result { log::error!("Error checking services: {:?}", err); panic!("Service check failed.") diff --git a/compute/src/node.rs b/compute/src/node.rs index 177fde2..7d17e8d 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -1,4 +1,4 @@ -use dkn_p2p::{libp2p::gossipsub, P2PClient}; +use dkn_p2p::{libp2p::gossipsub, DriaP2P}; use eyre::{eyre, Result}; use std::time::Duration; use tokio_util::sync::CancellationToken; @@ -27,7 +27,7 @@ const RPC_PEER_ID_REFRESH_INTERVAL_SECS: u64 = 30; /// ``` pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, - pub p2p: P2PClient, + pub p2p: DriaP2P, pub available_nodes: AvailableNodes, pub available_nodes_last_refreshed: tokio::time::Instant, pub cancellation: CancellationToken, @@ -51,7 +51,7 @@ impl DriaComputeNode { ) .sort_dedup(); - let p2p = P2PClient::new( + let p2p = DriaP2P::new( keypair, config.p2p_listen_addr.clone(), &available_nodes.bootstrap_nodes, diff --git a/p2p/src/client.rs b/p2p/src/client.rs index f2e7861..fbb817e 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -1,3 +1,4 @@ +use super::*; use eyre::Result; use libp2p::futures::StreamExt; use libp2p::gossipsub::{ @@ -12,10 +13,8 @@ use libp2p_identity::Keypair; use std::time::Duration; use std::time::Instant; -use super::*; - /// P2P client, exposes a simple interface to handle P2P communication. -pub struct P2PClient { +pub struct DriaP2P { /// `Swarm` instance, everything is accesses through this one. swarm: Swarm, /// Peer count for All and Mesh peers. @@ -33,7 +32,7 @@ const IDLE_CONNECTION_TIMEOUT_SECS: u64 = 60; /// Number of seconds between refreshing the Kademlia DHT. const PEER_REFRESH_INTERVAL_SECS: u64 = 30; -impl P2PClient { +impl DriaP2P { /// Creates a new P2P client with the given keypair and listen address. pub fn new( keypair: Keypair, diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index a28267f..66aece0 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -2,14 +2,13 @@ mod behaviour; pub use behaviour::{DriaBehaviour, DriaBehaviourEvent}; mod client; -pub use client::P2PClient; +pub use client::DriaP2P; mod versioning; pub use versioning::*; -mod data_transform; +mod transform; // re-exports - pub use libp2p; pub use libp2p_identity; diff --git a/p2p/src/data_transform.rs b/p2p/src/transform.rs similarity index 100% rename from p2p/src/data_transform.rs rename to p2p/src/transform.rs diff --git a/workflows/src/config.rs b/workflows/src/config.rs index 6b44aae..d2bd29b 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -4,17 +4,18 @@ use ollama_workflows::{Model, ModelProvider}; use rand::seq::IteratorRandom; // provides Vec<_>.choose #[derive(Debug, Clone)] -pub struct ModelConfig { +pub struct DriaWorkflowsConfig { /// List of models with their providers. pub models: Vec<(ModelProvider, Model)>, - /// Even if Ollama is not used, we store the host & port here. - /// If Ollama is used, this config will be respected during its instantiations. + /// Ollama configurations, in case Ollama is used. + /// Otherwise, can be ignored. pub ollama: OllamaConfig, - /// OpenAI API key & its service check implementation. + /// OpenAI configurations, e.g. API key, in case OpenAI is used. + /// Otherwise, can be ignored. 
pub openai: OpenAIConfig, } -impl ModelConfig { +impl DriaWorkflowsConfig { pub fn new(models: Vec) -> Self { let models_and_providers = models .into_iter() @@ -193,7 +194,7 @@ impl ModelConfig { } } -impl std::fmt::Display for ModelConfig { +impl std::fmt::Display for DriaWorkflowsConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let models_str = self .models @@ -211,10 +212,10 @@ mod tests { #[test] fn test_csv_parser() { - let cfg = ModelConfig::new_from_csv("idontexist,i dont either,i332287648762"); + let cfg = DriaWorkflowsConfig::new_from_csv("idontexist,i dont either,i332287648762"); assert_eq!(cfg.models.len(), 0); - let cfg = ModelConfig::new_from_csv( + let cfg = DriaWorkflowsConfig::new_from_csv( "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl", ); assert_eq!(cfg.models.len(), 2); @@ -222,7 +223,7 @@ mod tests { #[test] fn test_model_matching() { - let cfg = ModelConfig::new_from_csv("gpt-4o,llama3.1:latest"); + let cfg = DriaWorkflowsConfig::new_from_csv("gpt-4o,llama3.1:latest"); assert_eq!( cfg.get_matching_model("openai".to_string()).unwrap().1, Model::GPT4o, @@ -251,7 +252,7 @@ mod tests { #[test] fn test_get_any_matching_model() { - let cfg = ModelConfig::new_from_csv("gpt-3.5-turbo,llama3.1:latest"); + let cfg = DriaWorkflowsConfig::new_from_csv("gpt-3.5-turbo,llama3.1:latest"); let result = cfg.get_any_matching_model(vec![ "i-dont-exist".to_string(), "llama3.1:latest".to_string(), diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs index 4c258ea..1088758 100644 --- a/workflows/src/lib.rs +++ b/workflows/src/lib.rs @@ -5,7 +5,7 @@ mod providers; use providers::{OllamaConfig, OpenAIConfig}; mod config; -pub use config::ModelConfig; +pub use config::DriaWorkflowsConfig; pub use ollama_workflows; pub use ollama_workflows::{Model, ModelProvider}; diff --git a/workflows/src/providers/openai.rs b/workflows/src/providers/openai.rs index b48c3e2..2252e36 100644 --- a/workflows/src/providers/openai.rs +++ b/workflows/src/providers/openai.rs @@ -28,10 +28,11 @@ struct OpenAIModelsResponse { object: String, } +/// OpenAI-specific configurations. #[derive(Debug, Clone, Default)] pub struct OpenAIConfig { /// API key, if available. 
- pub(crate) api_key: Option, + api_key: Option, } impl OpenAIConfig { diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs index d8e9c7b..beec9bd 100644 --- a/workflows/tests/models_test.rs +++ b/workflows/tests/models_test.rs @@ -1,8 +1,7 @@ -use std::env; - -use dkn_workflows::{ModelConfig, ModelProvider}; +use dkn_workflows::{DriaWorkflowsConfig, ModelProvider}; use eyre::Result; use ollama_workflows::Model; +use std::env; const LOG_LEVEL: &str = "none,dkn_workflows=debug"; @@ -13,7 +12,7 @@ async fn test_ollama_check() -> Result<()> { let _ = env_logger::try_init(); let models = vec![Model::Phi3_5Mini]; - let mut model_config = ModelConfig::new(models); + let mut model_config = DriaWorkflowsConfig::new(models); model_config.check_services().await?; assert_eq!( @@ -32,7 +31,7 @@ async fn test_openai_check() -> Result<()> { let _ = env_logger::try_init(); let models = vec![Model::GPT4Turbo]; - let mut model_config = ModelConfig::new(models); + let mut model_config = DriaWorkflowsConfig::new(models); model_config.check_services().await?; assert_eq!( @@ -44,7 +43,7 @@ async fn test_openai_check() -> Result<()> { #[tokio::test] async fn test_empty() -> Result<()> { - let mut model_config = ModelConfig::new(vec![]); + let mut model_config = DriaWorkflowsConfig::new(vec![]); let result = model_config.check_services().await; assert!(result.is_err()); From a417b7bf99cec356004621c1ff2e7401c6de6552 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 17:19:32 +0300 Subject: [PATCH 06/14] update workflows doc, todo p2p versioning --- Cargo.lock | 3 +-- Makefile | 6 +++--- compute/Cargo.toml | 7 ++----- p2p/Cargo.toml | 10 +++++++--- p2p/src/behaviour.rs | 9 ++++----- p2p/src/client.rs | 1 + p2p/src/lib.rs | 8 +++----- workflows/Cargo.toml | 1 - workflows/README.md | 19 ++++++++++++++++--- workflows/tests/models_test.rs | 3 +-- 10 files changed, 38 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a3efa6..c716b47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -980,7 +980,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.10" +version = "0.2.11" dependencies = [ "async-trait", "base64 0.22.1", @@ -1034,7 +1034,6 @@ dependencies = [ "eyre", "log", "ollama-workflows", - "parking_lot", "rand 0.8.5", "reqwest 0.12.8", "serde", diff --git a/Makefile b/Makefile index e46a10e..8d6f3a8 100644 --- a/Makefile +++ b/Makefile @@ -7,15 +7,15 @@ endif ############################################################################### .PHONY: launch # | Run with INFO logs in release mode launch: - RUST_LOG=none,dkn_compute=info,dkn_workflows=info cargo run --release + RUST_LOG=none,dkn_compute=info,dkn_workflows=info,dkn_p2p=info cargo run --release .PHONY: run # | Run with INFO logs run: - RUST_LOG=none,dkn_compute=info,dkn_workflows=info cargo run + RUST_LOG=none,dkn_compute=info,dkn_workflows=info,dkn_p2p=info cargo run .PHONY: debug # | Run with DEBUG logs with INFO log-level workflows debug: - RUST_LOG=warn,dkn_compute=debug,,dkn_workflows=debug,ollama_workflows=info cargo run + RUST_LOG=warn,dkn_compute=debug,dkn_workflows=debug,dkn_p2p=debug,ollama_workflows=info cargo run .PHONY: trace # | Run with TRACE logs trace: diff --git a/compute/Cargo.toml b/compute/Cargo.toml index 77d62ae..952c0cb 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -1,13 +1,10 @@ [package] name = "dkn-compute" -version = "0.2.10" +version = "0.2.11" edition.workspace = true license.workspace = true readme = "README.md" -authors = [ - "Erhan Tezcan ", - 
"Anil Altuner "] # profiling build for flamegraphs [profile.profiling] diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index 272ec60..61e750e 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -4,7 +4,10 @@ version = "0.1.0" edition.workspace = true license.workspace = true readme = "README.md" -authors = ["Erhan Tezcan "] +authors = [ + "Erhan Tezcan ", + "Anil Altuner gossipsub::Behaviour { /// This helps to avoid memory exhaustion during high load const MAX_SEND_QUEUE_SIZE: usize = 400; - // message id's are simply hashes of the message data + // message id's are simply hashes of the message data, via SipHash13 let message_id_fn = |message: &Message| { - // uses siphash by default let mut hasher = hash_map::DefaultHasher::new(); message.data.hash(&mut hasher); let digest = hasher.finish(); diff --git a/p2p/src/client.rs b/p2p/src/client.rs index fbb817e..7c76173 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -12,6 +12,7 @@ use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder}; use libp2p_identity::Keypair; use std::time::Duration; use std::time::Instant; +use versioning::{P2P_KADEMLIA_PREFIX, P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; /// P2P client, exposes a simple interface to handle P2P communication. pub struct DriaP2P { diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 66aece0..cce6980 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -1,14 +1,12 @@ +mod transform; +mod versioning; + mod behaviour; pub use behaviour::{DriaBehaviour, DriaBehaviourEvent}; mod client; pub use client::DriaP2P; -mod versioning; -pub use versioning::*; - -mod transform; - // re-exports pub use libp2p; pub use libp2p_identity; diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index cbcc82e..e43e5dd 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -10,7 +10,6 @@ authors = ["Erhan Tezcan "] [dependencies] tokio-util.workspace = true tokio.workspace = true -parking_lot.workspace = true serde.workspace = true serde_json.workspace = true async-trait.workspace = true diff --git a/workflows/README.md b/workflows/README.md index 06a7da4..53f965b 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -6,7 +6,11 @@ This crate handles the configurations of models to be used, and implements vario - **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models. - **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. While calculating TPS, there is also a timeout so that beyond that timeout the TPS is not even considered and the model becomes invalid. -## Environment Variables +## Installation + +TODO: !!! + +## Usage DKN Workflows make use of several environment variables, respecting the providers. 
@@ -14,6 +18,15 @@ DKN Workflows make use of several environment variables, respecting the provider - `OLLAMA_HOST` is used to connect to Ollama server - `OLLAMA_PORT` is used to connect to Ollama server - `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not +- `SERPER_API_KEY` is optional API key to use **Serper**, for better Workflow executions +- `JINA_API_KEY` is optional API key to use **Jina**, for better Workflow executions + +With the environment variables ready, you can simply create a new configuration and call `check_services` to ensure all models are correctly setup: + +```rs +use dkn_workflows::{DriaWorkflowsConfig, Model}; -SERPER_API_KEY= -JINA_API_KEY= +let models = vec![Model::Phi3_5Mini]; +let mut config = DriaWorkflowsConfig::new(models); +config.check_services().await?; +``` diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs index beec9bd..acfc786 100644 --- a/workflows/tests/models_test.rs +++ b/workflows/tests/models_test.rs @@ -1,6 +1,5 @@ -use dkn_workflows::{DriaWorkflowsConfig, ModelProvider}; +use dkn_workflows::{DriaWorkflowsConfig, Model, ModelProvider}; use eyre::Result; -use ollama_workflows::Model; use std::env; const LOG_LEVEL: &str = "none,dkn_workflows=debug"; From b35d7f4cea3b0862fe589ee4ec25d487e3900cc4 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 18:30:15 +0300 Subject: [PATCH 07/14] some p2p refactors on versioning --- Cargo.lock | 2 +- Cargo.toml | 11 +++++-- compute/Cargo.toml | 1 - compute/src/node.rs | 16 +++++++++-- compute/src/utils/available_nodes.rs | 3 +- compute/src/utils/crypto.rs | 6 ++-- p2p/Cargo.toml | 1 + p2p/README.md | 13 +++++++++ p2p/src/behaviour.rs | 28 ++++++++++++------ p2p/src/client.rs | 43 +++++++++++++++++++++------- p2p/src/lib.rs | 11 +++++-- p2p/src/versioning.rs | 35 ---------------------- p2p/tests/listen_test.rs | 40 ++++++++++++++++++++++++++ workflows/Cargo.toml | 6 ++-- workflows/README.md | 8 +++++- 15 files changed, 151 insertions(+), 73 deletions(-) create mode 100644 p2p/README.md delete mode 100644 p2p/src/versioning.rs create mode 100644 p2p/tests/listen_test.rs diff --git a/Cargo.lock b/Cargo.lock index c716b47..36864b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -996,7 +996,6 @@ dependencies = [ "libsecp256k1", "log", "openssl", - "parking_lot", "port_check", "rand 0.8.5", "reqwest 0.12.8", @@ -1022,6 +1021,7 @@ dependencies = [ "libp2p", "libp2p-identity", "log", + "tokio 1.40.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index dd0c36b..613de19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,18 +10,25 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [workspace.dependencies] +# async stuff tokio-util = { version = "0.7.10", features = ["rt"] } tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } -parking_lot = "0.12.2" +async-trait = "0.1.81" + +# serialize & deserialize serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -async-trait = "0.1.81" + +# http client reqwest = "0.12.5" +# env reading dotenvy = "0.15.7" +# randomization rand = "0.8.5" +# logging & errors env_logger = "0.11.3" log = "0.4.21" eyre = "0.6.12" diff --git a/compute/Cargo.toml b/compute/Cargo.toml index 952c0cb..bf36d26 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -18,7 +18,6 @@ profiling = [] [dependencies] tokio-util = { version = "0.7.10", features = ["rt"] } tokio = { version = "1", features = ["macros", 
"rt-multi-thread", "signal"] } -parking_lot = "0.12.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" async-trait = "0.1.81" diff --git a/compute/src/node.rs b/compute/src/node.rs index 7d17e8d..e89a961 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -1,4 +1,4 @@ -use dkn_p2p::{libp2p::gossipsub, DriaP2P}; +use dkn_p2p::{libp2p::gossipsub, DriaP2PClient}; use eyre::{eyre, Result}; use std::time::Duration; use tokio_util::sync::CancellationToken; @@ -27,7 +27,7 @@ const RPC_PEER_ID_REFRESH_INTERVAL_SECS: u64 = 30; /// ``` pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, - pub p2p: DriaP2P, + pub p2p: DriaP2PClient, pub available_nodes: AvailableNodes, pub available_nodes_last_refreshed: tokio::time::Instant, pub cancellation: CancellationToken, @@ -51,11 +51,21 @@ impl DriaComputeNode { ) .sort_dedup(); - let p2p = DriaP2P::new( + // we are using the major.minor version as the P2P version + // so that patch versions do not interfere with the protocol + const P2P_VERSION: &str = concat!( + env!("CARGO_PKG_VERSION_MAJOR"), + ".", + env!("CARGO_PKG_VERSION_MINOR") + ); + + // create p2p client + let p2p = DriaP2PClient::new( keypair, config.p2p_listen_addr.clone(), &available_nodes.bootstrap_nodes, &available_nodes.relay_nodes, + P2P_VERSION, )?; Ok(DriaComputeNode { diff --git a/compute/src/utils/available_nodes.rs b/compute/src/utils/available_nodes.rs index 50ede8a..45a6f73 100644 --- a/compute/src/utils/available_nodes.rs +++ b/compute/src/utils/available_nodes.rs @@ -1,9 +1,8 @@ use dkn_p2p::libp2p::{Multiaddr, PeerId}; +use dkn_workflows::split_csv_line; use eyre::Result; use std::{env, fmt::Debug, str::FromStr}; -use dkn_workflows::split_csv_line; - /// Static bootstrap nodes for the Kademlia DHT bootstrap step. const STATIC_BOOTSTRAP_NODES: [&str; 4] = [ "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4", diff --git a/compute/src/utils/crypto.rs b/compute/src/utils/crypto.rs index 38eb6ac..21a69f1 100644 --- a/compute/src/utils/crypto.rs +++ b/compute/src/utils/crypto.rs @@ -1,4 +1,4 @@ -use dkn_p2p::libp2p_identity::Keypair; +use dkn_p2p::libp2p_identity; use ecies::PublicKey; use eyre::{Context, Result}; use libsecp256k1::{Message, SecretKey}; @@ -55,12 +55,12 @@ pub fn encrypt_bytes(data: impl AsRef<[u8]>, public_key: &PublicKey) -> Result Keypair { +pub fn secret_to_keypair(secret_key: &SecretKey) -> libp2p_identity::Keypair { let bytes = secret_key.serialize(); let secret_key = dkn_p2p::libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes) .expect("Failed to create secret key"); - dkn_p2p::libp2p_identity::secp256k1::Keypair::from(secret_key).into() + libp2p_identity::secp256k1::Keypair::from(secret_key).into() } #[cfg(test)] diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index 61e750e..dd22a60 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -33,3 +33,4 @@ eyre.workspace = true [dev-dependencies] env_logger.workspace = true +tokio.workspace = true diff --git a/p2p/README.md b/p2p/README.md new file mode 100644 index 0000000..391df53 --- /dev/null +++ b/p2p/README.md @@ -0,0 +1,13 @@ +# DKN Peer-to-Peer Client + +## Installation + +Add the package via `git` within your Cargo dependencies: + +```toml +dkn-p2p = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } +``` + +## Usage + +TODO: !!! 
diff --git a/p2p/src/behaviour.rs b/p2p/src/behaviour.rs index bc45dde..ea8ba18 100644 --- a/p2p/src/behaviour.rs +++ b/p2p/src/behaviour.rs @@ -4,9 +4,10 @@ use std::time::Duration; use libp2p::identity::{Keypair, PeerId, PublicKey}; use libp2p::kad::store::MemoryStore; +use libp2p::StreamProtocol; use libp2p::{autonat, dcutr, gossipsub, identify, kad, relay, swarm::NetworkBehaviour}; -use crate::versioning::{P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; +// use crate::versioning::{P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; #[derive(NetworkBehaviour)] pub struct DriaBehaviour { @@ -19,29 +20,37 @@ pub struct DriaBehaviour { } impl DriaBehaviour { - pub fn new(key: &Keypair, relay_behavior: relay::client::Behaviour) -> Self { + pub fn new( + key: &Keypair, + relay_behavior: relay::client::Behaviour, + identity_protocol: String, + kademlia_protocol: StreamProtocol, + ) -> Self { let public_key = key.public(); let peer_id = public_key.to_peer_id(); Self { relay: relay_behavior, gossipsub: create_gossipsub_behavior(peer_id), - kademlia: create_kademlia_behavior(peer_id), + kademlia: create_kademlia_behavior(peer_id, kademlia_protocol), autonat: create_autonat_behavior(peer_id), dcutr: create_dcutr_behavior(peer_id), - identify: create_identify_behavior(public_key), + identify: create_identify_behavior(public_key, identity_protocol), } } } /// Configures the Kademlia DHT behavior for the node. #[inline] -fn create_kademlia_behavior(local_peer_id: PeerId) -> kad::Behaviour { +fn create_kademlia_behavior( + local_peer_id: PeerId, + protocol_name: StreamProtocol, +) -> kad::Behaviour { use kad::{Behaviour, Config}; const QUERY_TIMEOUT_SECS: u64 = 5 * 60; const RECORD_TTL_SECS: u64 = 30; - let mut cfg = Config::new(P2P_KADEMLIA_PROTOCOL); + let mut cfg = Config::new(protocol_name); cfg.set_query_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS)) .set_record_ttl(Some(Duration::from_secs(RECORD_TTL_SECS))); @@ -50,10 +59,13 @@ fn create_kademlia_behavior(local_peer_id: PeerId) -> kad::Behaviour identify::Behaviour { +fn create_identify_behavior( + local_public_key: PublicKey, + protocol_version: String, +) -> identify::Behaviour { use identify::{Behaviour, Config}; - let cfg = Config::new(P2P_PROTOCOL_STRING.to_string(), local_public_key); + let cfg = Config::new(protocol_version, local_public_key); Behaviour::new(cfg) } diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 7c76173..bae6926 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -8,14 +8,12 @@ use libp2p::kad::{GetClosestPeersError, GetClosestPeersOk, QueryResult}; use libp2p::{ autonat, gossipsub, identify, kad, multiaddr::Protocol, noise, swarm::SwarmEvent, tcp, yamux, }; -use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder}; +use libp2p::{Multiaddr, PeerId, StreamProtocol, Swarm, SwarmBuilder}; use libp2p_identity::Keypair; -use std::time::Duration; -use std::time::Instant; -use versioning::{P2P_KADEMLIA_PREFIX, P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; +use std::time::{Duration, Instant}; /// P2P client, exposes a simple interface to handle P2P communication. -pub struct DriaP2P { +pub struct DriaP2PClient { /// `Swarm` instance, everything is accesses through this one. swarm: Swarm, /// Peer count for All and Mesh peers. @@ -25,6 +23,10 @@ pub struct DriaP2P { peer_count: (usize, usize), /// Last time the peer count was refreshed. peer_last_refreshed: Instant, + /// Identity protocol string to be used for the Identity behaviour. 
+ identity_protocol: String, + /// Kademlia protocol, must match with other peers in the network. + kademlia_protocol: StreamProtocol, } /// Number of seconds before an idle connection is closed. @@ -33,14 +35,24 @@ const IDLE_CONNECTION_TIMEOUT_SECS: u64 = 60; /// Number of seconds between refreshing the Kademlia DHT. const PEER_REFRESH_INTERVAL_SECS: u64 = 30; -impl DriaP2P { +impl DriaP2PClient { /// Creates a new P2P client with the given keypair and listen address. + /// + /// Can provide a list of bootstrap and relay nodes to connect to as well at the start. + /// + /// The `version` is used to create the protocol strings for the client, and its very important that + /// they match with the clients existing within the network. pub fn new( keypair: Keypair, listen_addr: Multiaddr, bootstraps: &[Multiaddr], relays: &[Multiaddr], + version: &str, ) -> Result { + let identity_protocol = format!("{}{}", P2P_IDENTITY_PREFIX, version); + let kademlia_protocol = + StreamProtocol::try_from_owned(format!("{}{}", P2P_KADEMLIA_PREFIX, version))?; + // this is our peerId let node_peerid = keypair.public().to_peer_id(); log::info!("Compute node peer address: {}", node_peerid); @@ -54,7 +66,14 @@ impl DriaP2P { )? .with_quic() .with_relay_client(noise::Config::new, yamux::Config::default)? - .with_behaviour(|key, relay_behavior| Ok(DriaBehaviour::new(key, relay_behavior)))? + .with_behaviour(|key, relay_behavior| { + Ok(DriaBehaviour::new( + key, + relay_behavior, + identity_protocol.clone(), + kademlia_protocol.clone(), + )) + })? .with_swarm_config(|c| { c.with_idle_connection_timeout(Duration::from_secs(IDLE_CONNECTION_TIMEOUT_SECS)) }) @@ -107,6 +126,8 @@ impl DriaP2P { swarm, peer_count: (0, 0), peer_last_refreshed: Instant::now(), + identity_protocol, + kademlia_protocol, }) } @@ -231,12 +252,12 @@ impl DriaP2P { /// - For Kademlia, we check the kademlia protocol and then add the address to the Kademlia routing table. fn handle_identify_event(&mut self, peer_id: PeerId, info: identify::Info) { // check identify protocol string - if info.protocol_version != P2P_PROTOCOL_STRING { + if info.protocol_version != self.identity_protocol { log::warn!( "Identify: Peer {} has different Identify protocol: (them {}, you {})", peer_id, info.protocol_version, - P2P_PROTOCOL_STRING + self.identity_protocol ); return; } @@ -248,7 +269,7 @@ impl DriaP2P { .find(|p| p.to_string().starts_with(P2P_KADEMLIA_PREFIX)) { // if it matches our protocol, add it to the Kademlia routing table - if *kad_protocol == P2P_KADEMLIA_PROTOCOL { + if *kad_protocol == self.kademlia_protocol { // filter listen addresses let addrs = info.listen_addrs.into_iter().filter(|listen_addr| { if let Some(Protocol::Ip4(ipv4_addr)) = listen_addr.iter().next() { @@ -279,7 +300,7 @@ impl DriaP2P { "Identify: Peer {} has different Kademlia version: (them {}, you {})", peer_id, kad_protocol, - P2P_KADEMLIA_PROTOCOL + self.kademlia_protocol ); } } diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index cce6980..97eacec 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -1,11 +1,16 @@ mod transform; -mod versioning; mod behaviour; -pub use behaviour::{DriaBehaviour, DriaBehaviourEvent}; +use behaviour::{DriaBehaviour, DriaBehaviourEvent}; mod client; -pub use client::DriaP2P; +pub use client::DriaP2PClient; + +/// Prefix for Kademlia protocol, must start with `/`! +pub(crate) const P2P_KADEMLIA_PREFIX: &str = "/dria/kad/"; + +/// Prefix for Identity protocol string. 
+pub(crate) const P2P_IDENTITY_PREFIX: &str = "dria/"; // re-exports pub use libp2p; diff --git a/p2p/src/versioning.rs b/p2p/src/versioning.rs deleted file mode 100644 index 11b2a4d..0000000 --- a/p2p/src/versioning.rs +++ /dev/null @@ -1,35 +0,0 @@ -use libp2p::StreamProtocol; - -/// Kademlia protocol prefix, as a macro so that it can be used in literal-expecting constants. -macro_rules! P2P_KADEMLIA_PREFIX { - () => { - "/dria/kad/" - }; -} -pub const P2P_KADEMLIA_PREFIX: &str = P2P_KADEMLIA_PREFIX!(); - -/// Identity protocol name prefix, as a macro so that it can be used in literal-expecting constants. -macro_rules! P2P_IDENTITY_PREFIX { - () => { - "dria/" - }; -} - -/// Kademlia protocol version, in the form of `/dria/kad/`, **notice the `/` at the start**. -/// -/// It is important that this protocol matches EXACTLY among the nodes, otherwise there is a protocol-level logic -/// that will prevent peers from finding eachother within the DHT. -pub const P2P_KADEMLIA_PROTOCOL: StreamProtocol = StreamProtocol::new(concat!( - P2P_KADEMLIA_PREFIX!(), - env!("CARGO_PKG_VERSION_MAJOR"), - ".", - env!("CARGO_PKG_VERSION_MINOR") -)); - -/// Protocol string, checked by Identify protocol handlers. -pub const P2P_PROTOCOL_STRING: &str = concat!( - P2P_IDENTITY_PREFIX!(), - env!("CARGO_PKG_VERSION_MAJOR"), - ".", - env!("CARGO_PKG_VERSION_MINOR") -); diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs new file mode 100644 index 0000000..e286429 --- /dev/null +++ b/p2p/tests/listen_test.rs @@ -0,0 +1,40 @@ +use dkn_p2p::DriaP2PClient; +use eyre::Result; +use libp2p::Multiaddr; +use libp2p_identity::Keypair; +use std::{env, str::FromStr}; + +const LOG_LEVEL: &str = "none,dkn_p2p=debug"; + +#[tokio::test] +#[ignore = "run manually with logs"] +async fn test_listen_topic_once() -> Result<()> { + // topic to be listened to + const TOPIC: &str = "pong"; + + env::set_var("RUST_LOG", LOG_LEVEL); + let _ = env_logger::try_init(); + + // setup client + let keypair = Keypair::generate_secp256k1(); + let addr = Multiaddr::from_str("/ip4/0.0.0.0/tcp/4001")?; + let bootstraps = vec![Multiaddr::from_str( + "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4", + )?]; + let relays = vec![Multiaddr::from_str( + "/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa", + )?]; + let mut client = DriaP2PClient::new(keypair, addr, &bootstraps, &relays, "0.2")?; + + // subscribe to the given topic + client.subscribe(TOPIC)?; + + // wait for a single gossipsub message on this topic + let message = client.process_events().await; + log::info!("Received {} message: {:?}", TOPIC, message); + + // unsubscribe gracefully + client.unsubscribe(TOPIC)?; + + Ok(()) +} diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index e43e5dd..34a6179 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -8,6 +8,9 @@ authors = ["Erhan Tezcan "] [dependencies] +# ollama-rs is re-exported from ollama-workflows as well +ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } + tokio-util.workspace = true tokio.workspace = true serde.workspace = true @@ -18,9 +21,6 @@ rand.workspace = true log.workspace = true eyre.workspace = true -# ollama-rs is re-exported from ollama-workflows -ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" } - [dev-dependencies] env_logger.workspace = true dotenvy.workspace = true diff --git a/workflows/README.md b/workflows/README.md index 53f965b..46eea30 100644 --- 
a/workflows/README.md +++ b/workflows/README.md @@ -8,7 +8,13 @@ This crate handles the configurations of models to be used, and implements vario ## Installation -TODO: !!! +Add the package via `git` within your Cargo dependencies: + +```toml +dkn-workflows = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } +``` + +Note that the underlying Ollama Workflows crate is re-exported by this crate. ## Usage From 84895b829c5e7d7a29e292c34ce5d540236c18ea Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 7 Oct 2024 21:25:02 +0300 Subject: [PATCH 08/14] rm profiling feature --- Cargo.toml | 7 ++++ Makefile | 5 +-- compute/.env.example | 37 ++++++++++++++++++++ compute/Cargo.toml | 9 ----- compute/src/config.rs | 9 ++--- compute/src/main.rs | 56 ++++++++++++++----------------- workflows/src/providers/ollama.rs | 5 +-- 7 files changed, 81 insertions(+), 47 deletions(-) create mode 100644 compute/.env.example diff --git a/Cargo.toml b/Cargo.toml index 613de19..e94c65a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,15 @@ [workspace] resolver = "2" members = ["compute", "p2p", "workflows"] +# compute node is the default member, until Oracle comes in +# then, a Launcher will be the default member default-members = ["compute"] +# profiling build for flamegraphs +[profile.profiling] +inherits = "release" +debug = true + [workspace.package] edition = "2021" license = "Apache-2.0" diff --git a/Makefile b/Makefile index 8d6f3a8..c41c9fb 100644 --- a/Makefile +++ b/Makefile @@ -27,11 +27,11 @@ build: .PHONY: profile-cpu # | Profile CPU usage with flamegraph profile-cpu: - cargo flamegraph --root --profile=profiling --features=profiling + DKN_EXIT_TIMEOUT=120 cargo flamegraph --root --profile=profiling .PHONY: profile-mem # | Profile memory usage with instruments profile-mem: - cargo instruments --profile=profiling --features=profiling -t Allocations + DKN_EXIT_TIMEOUT=120 cargo instruments --profile=profiling -t Allocations ############################################################################### .PHONY: test # | Run tests @@ -42,6 +42,7 @@ test: .PHONY: lint # | Run linter (clippy) lint: cargo clippy + cargo clippy .PHONY: format # | Run formatter (cargo fmt) format: diff --git a/compute/.env.example b/compute/.env.example new file mode 100644 index 0000000..1238c5e --- /dev/null +++ b/compute/.env.example @@ -0,0 +1,37 @@ +## DRIA (required) ## +# Secret key of your compute node, 32 byte in hexadecimal. +# e.g.: DKN_WALLET_SECRET_KEY=0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80 +DKN_WALLET_SECRET_KEY= +# Public key of Dria Admin node, 33-byte (compressed) in hexadecimal. +# You don't need to change this, simply copy and paste it. +DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658 +# model1,model2,model3,... (comma separated, case-insensitive) +# example: phi3:3.8b,gpt-4o-mini +DKN_MODELS= + +## DRIA (optional) ## +# P2P address, you don't need to change this unless this port is already in use. 
+DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001 +# Comma-separated static relay nodes +DKN_RELAY_NODES= +# Comma-separated static bootstrap nodes +DKN_BOOTSTRAP_NODES= + +# PROFILING ONLY: set to a number of seconds to wait before exiting +# DKN_EXIT_TIMEOUT= + +## Open AI (if used, required) ## +OPENAI_API_KEY= + +## Ollama (if used, optional) ## +# do not change this, it is used by Docker +OLLAMA_HOST=http://host.docker.internal +# you can change the port if you would like +OLLAMA_PORT=11434 +# if "true", automatically pull models from Ollama +# if "false", you have to download manually +OLLAMA_AUTO_PULL=true + +## Additional Services (optional) +SERPER_API_KEY= +JINA_API_KEY= diff --git a/compute/Cargo.toml b/compute/Cargo.toml index bf36d26..fb226d7 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -6,15 +6,6 @@ license.workspace = true readme = "README.md" authors = ["Erhan Tezcan "] -# profiling build for flamegraphs -[profile.profiling] -inherits = "release" -debug = true - -[features] -# used by flamegraphs & instruments -profiling = [] - [dependencies] tokio-util = { version = "0.7.10", features = ["rt"] } tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } diff --git a/compute/src/config.rs b/compute/src/config.rs index c5e54de..ce8cf72 100644 --- a/compute/src/config.rs +++ b/compute/src/config.rs @@ -79,14 +79,14 @@ impl DriaComputeNodeConfig { let address = to_address(&public_key); log::info!("Node Address: 0x{}", hex::encode(address)); - let model_config = + let workflows = DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default()); #[cfg(not(test))] - if model_config.models.is_empty() { + if workflows.models.is_empty() { log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."); panic!("No models provided."); } - log::info!("Models: {:?}", model_config.models); + log::info!("Models: {:?}", workflows.models); let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR") .map(|addr| addr.trim_matches('"').to_string()) @@ -99,12 +99,13 @@ impl DriaComputeNodeConfig { secret_key, public_key, address, - workflows: model_config, + workflows, p2p_listen_addr, } } /// Asserts that the configured listen address is free. + /// Throws an error if the address is already in use. 
pub fn assert_address_not_in_use(&self) -> Result<()> { if address_in_use(&self.p2p_listen_addr) { return Err(eyre!( diff --git a/compute/src/main.rs b/compute/src/main.rs index 4b424a4..f61e4dd 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -1,16 +1,18 @@ +use std::env; + use dkn_compute::*; -use eyre::Result; +use eyre::{Context, Result}; use tokio_util::sync::CancellationToken; #[tokio::main] async fn main() -> Result<()> { - if let Err(e) = dotenvy::dotenv() { - log::warn!("Could not load .env file: {}", e); - } - + let dotenv_result = dotenvy::dotenv(); env_logger::builder() .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) .init(); + if let Err(e) = dotenv_result { + log::warn!("Could not load .env file: {}", e); + } log::info!( r#" @@ -26,49 +28,45 @@ async fn main() -> Result<()> { let token = CancellationToken::new(); let cancellation_token = token.clone(); - // add cancellation check tokio::spawn(async move { - // FIXME: weird feature-gating here bugs with IDE, fix this later - #[cfg(feature = "profiling")] - { - const PROFILE_DURATION_SECS: u64 = 120; - tokio::time::sleep(tokio::time::Duration::from_secs(PROFILE_DURATION_SECS)).await; + if let Some(timeout_str) = env::var("DKN_EXIT_TIMEOUT").ok() { + // add cancellation check + let duration_secs = timeout_str.parse().unwrap_or(120); + tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await; cancellation_token.cancel(); + } else { + if let Err(err) = wait_for_termination(cancellation_token.clone()).await { + log::error!("Error waiting for termination: {:?}", err); + log::error!("Cancelling due to unexpected error."); + cancellation_token.cancel(); + }; } - - #[cfg(not(feature = "profiling"))] - if let Err(err) = wait_for_termination(cancellation_token.clone()).await { - log::error!("Error waiting for termination: {:?}", err); - log::error!("Cancelling due to unexpected error."); - cancellation_token.cancel(); - }; }); - // create configurations & check required services - let config = DriaComputeNodeConfig::new(); + // create configurations & check required services & address in use + let mut config = DriaComputeNodeConfig::new(); config.assert_address_not_in_use()?; let service_check_token = token.clone(); - let mut config_clone = config.clone(); let service_check_handle = tokio::spawn(async move { tokio::select! { _ = service_check_token.cancelled() => { log::info!("Service check cancelled."); + config } - result = config_clone.workflows.check_services() => { + result = config.workflows.check_services() => { if let Err(err) = result { log::error!("Error checking services: {:?}", err); panic!("Service check failed.") } + config } } }); + let config = service_check_handle + .await + .wrap_err("error during service checks")?; - // wait for service check to complete - if let Err(err) = service_check_handle.await { - log::error!("Service check handle error: {}", err); - panic!("Could not exit service check thread handle."); - }; - + log::warn!("Using models: {:#?}", config.workflows.models); if !token.is_cancelled() { // launch the node let node_token = token.clone(); @@ -97,11 +95,9 @@ async fn main() -> Result<()> { Ok(()) } -// FIXME: remove this `unused` once we have a better way to handle this /// Waits for various termination signals, and cancels the given token when the signal is received. /// /// Handles Unix and Windows [target families](https://doc.rust-lang.org/reference/conditional-compilation.html#target_family). 
-#[allow(unused)] async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> { #[cfg(unix)] { diff --git a/workflows/src/providers/ollama.rs b/workflows/src/providers/ollama.rs index 41c806b..d507130 100644 --- a/workflows/src/providers/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -82,9 +82,10 @@ impl OllamaConfig { /// Check if requested models exist in Ollama, and then tests them using a workflow. pub async fn check(&self, external_models: Vec) -> Result> { log::info!( - "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)", + "Checking Ollama requirements (auto-pull {}, timeout: {}s, min tps: {})", if self.auto_pull { "on" } else { "off" }, - self.timeout.as_secs() + self.timeout.as_secs(), + self.min_tps ); let ollama = Ollama::new(&self.host, self.port); From 1edf491510ae1f9276a27d1f22344b5d7c0801cd Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 10:52:28 +0300 Subject: [PATCH 09/14] try dev build --- .env.example | 5 ++- .github/workflows/build_dev_container.yml | 14 ++++----- compute/Cross.toml => Cross.toml | 0 compute/.env.example | 37 ----------------------- p2p/src/behaviour.rs | 6 ++-- p2p/src/client.rs | 6 ++++ workflows/.env.example | 13 -------- 7 files changed, 19 insertions(+), 62 deletions(-) rename compute/Cross.toml => Cross.toml (100%) delete mode 100644 compute/.env.example delete mode 100644 workflows/.env.example diff --git a/.env.example b/.env.example index 2b1d1ae..61b3481 100644 --- a/.env.example +++ b/.env.example @@ -16,7 +16,10 @@ DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001 DKN_RELAY_NODES= # Comma-separated static bootstrap nodes DKN_BOOTSTRAP_NODES= - +# Set to a number of seconds to wait before exiting, only use in profiling build! +# Otherwise, leave this empty. +DKN_EXIT_TIMEOUT= + ## Open AI (if used, required) ## OPENAI_API_KEY= diff --git a/.github/workflows/build_dev_container.yml b/.github/workflows/build_dev_container.yml index b0290d0..372f026 100644 --- a/.github/workflows/build_dev_container.yml +++ b/.github/workflows/build_dev_container.yml @@ -1,13 +1,13 @@ name: Create Dev Image on: push: - branches: ["master"] - paths: - - "src/**" - - "Cargo.lock" - - "Cargo.toml" - - "Dockerfile" - - "compose.yml" + branches: ["master", "erhant/subcrates"] + # paths: + # - "src/**" + # - "Cargo.lock" + # - "Cargo.toml" + # - "Dockerfile" + # - "compose.yml" jobs: build-and-push: diff --git a/compute/Cross.toml b/Cross.toml similarity index 100% rename from compute/Cross.toml rename to Cross.toml diff --git a/compute/.env.example b/compute/.env.example deleted file mode 100644 index 1238c5e..0000000 --- a/compute/.env.example +++ /dev/null @@ -1,37 +0,0 @@ -## DRIA (required) ## -# Secret key of your compute node, 32 byte in hexadecimal. -# e.g.: DKN_WALLET_SECRET_KEY=0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80 -DKN_WALLET_SECRET_KEY= -# Public key of Dria Admin node, 33-byte (compressed) in hexadecimal. -# You don't need to change this, simply copy and paste it. -DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658 -# model1,model2,model3,... (comma separated, case-insensitive) -# example: phi3:3.8b,gpt-4o-mini -DKN_MODELS= - -## DRIA (optional) ## -# P2P address, you don't need to change this unless this port is already in use. 
-DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001 -# Comma-separated static relay nodes -DKN_RELAY_NODES= -# Comma-separated static bootstrap nodes -DKN_BOOTSTRAP_NODES= - -# PROFILING ONLY: set to a number of seconds to wait before exiting -# DKN_EXIT_TIMEOUT= - -## Open AI (if used, required) ## -OPENAI_API_KEY= - -## Ollama (if used, optional) ## -# do not change this, it is used by Docker -OLLAMA_HOST=http://host.docker.internal -# you can change the port if you would like -OLLAMA_PORT=11434 -# if "true", automatically pull models from Ollama -# if "false", you have to download manually -OLLAMA_AUTO_PULL=true - -## Additional Services (optional) -SERPER_API_KEY= -JINA_API_KEY= diff --git a/p2p/src/behaviour.rs b/p2p/src/behaviour.rs index ea8ba18..bf64399 100644 --- a/p2p/src/behaviour.rs +++ b/p2p/src/behaviour.rs @@ -5,11 +5,9 @@ use std::time::Duration; use libp2p::identity::{Keypair, PeerId, PublicKey}; use libp2p::kad::store::MemoryStore; use libp2p::StreamProtocol; -use libp2p::{autonat, dcutr, gossipsub, identify, kad, relay, swarm::NetworkBehaviour}; +use libp2p::{autonat, dcutr, gossipsub, identify, kad, relay}; -// use crate::versioning::{P2P_KADEMLIA_PROTOCOL, P2P_PROTOCOL_STRING}; - -#[derive(NetworkBehaviour)] +#[derive(libp2p::swarm::NetworkBehaviour)] pub struct DriaBehaviour { pub(crate) relay: relay::client::Behaviour, pub(crate) gossipsub: gossipsub::Behaviour, diff --git a/p2p/src/client.rs b/p2p/src/client.rs index bae6926..f388f12 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -24,12 +24,18 @@ pub struct DriaP2PClient { /// Last time the peer count was refreshed. peer_last_refreshed: Instant, /// Identity protocol string to be used for the Identity behaviour. + /// + /// This is usually `dria/{version}`. identity_protocol: String, /// Kademlia protocol, must match with other peers in the network. + /// + /// This is usually `/dria/kad/{version}`, notice the `/` at the start + /// which is mandatory for a `StreamProtocol`. kademlia_protocol: StreamProtocol, } /// Number of seconds before an idle connection is closed. +/// TODO: default is 0, is 60 a good value? const IDLE_CONNECTION_TIMEOUT_SECS: u64 = 60; /// Number of seconds between refreshing the Kademlia DHT. 
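To make the comments above concrete: the version string is the crate's `major.minor`, derived at compile time in `compute/src/node.rs`, and the two protocol identifiers are built from it roughly as below. This is a sketch mirroring `DriaP2PClient::new` from the previous patch, not the node's exact code; the `0.2` value is illustrative.

```rs
// Sketch only — mirrors the protocol-string construction in `DriaP2PClient::new`.
use dkn_p2p::libp2p::StreamProtocol;

// In the compute node this is derived at compile time from the crate version, e.g.
// concat!(env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR")).
const VERSION: &str = "0.2";

fn protocol_strings(version: &str) -> eyre::Result<(String, StreamProtocol)> {
    // Identify protocol string, e.g. "dria/0.2"
    let identity = format!("dria/{}", version);
    // Kademlia protocol, e.g. "/dria/kad/0.2"; the leading '/' is required by StreamProtocol
    let kademlia = StreamProtocol::try_from_owned(format!("/dria/kad/{}", version))?;
    Ok((identity, kademlia))
}

fn main() -> eyre::Result<()> {
    let (identity, kademlia) = protocol_strings(VERSION)?;
    println!("identify: {}, kademlia: {}", identity, kademlia);
    Ok(())
}
```

Peers whose Identify or Kademlia protocol string differs are skipped by `handle_identify_event`, which is why only `major.minor` goes into the string: patch releases stay compatible with the rest of the network.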
diff --git a/workflows/.env.example b/workflows/.env.example deleted file mode 100644 index eec9599..0000000 --- a/workflows/.env.example +++ /dev/null @@ -1,13 +0,0 @@ -## Open AI (if used, required) ## -OPENAI_API_KEY= - -## Ollama (if used, optional) ## -OLLAMA_HOST=http://localhost -OLLAMA_PORT=11434 -# if "true", automatically pull models from Ollama -# if "false", you have to download manually -OLLAMA_AUTO_PULL=true - -## Additional Services (optional) -SERPER_API_KEY= -JINA_API_KEY= From f4812e8679201bb71e7524c4b114575cfacb0f55 Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 10:54:04 +0300 Subject: [PATCH 10/14] tiny edit --- .github/workflows/build_dev_container.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_dev_container.yml b/.github/workflows/build_dev_container.yml index 372f026..b5cf3ea 100644 --- a/.github/workflows/build_dev_container.yml +++ b/.github/workflows/build_dev_container.yml @@ -44,7 +44,7 @@ jobs: - name: Set Image Tag id: itag - run: echo "itag=${{ steps.branch.outputs.branch }}-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT + run: echo "itag=testing-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT - name: Build and push uses: docker/build-push-action@v6 From a4ba7ccfe89e9191be9b1f0937b547628169c716 Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 10:55:17 +0300 Subject: [PATCH 11/14] dockerfile --- compute/Dockerfile => Dockerfile | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename compute/Dockerfile => Dockerfile (100%) diff --git a/compute/Dockerfile b/Dockerfile similarity index 100% rename from compute/Dockerfile rename to Dockerfile From 540a226019b4428dd11e75cbee3bd4f435a28d9e Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 14:44:38 +0300 Subject: [PATCH 12/14] fix versioning --- .github/workflows/build_dev_container.yml | 22 +++++++++++++++------- Cargo.lock | 4 ++-- Cargo.toml | 9 ++++++--- compute/compose.yml => compose.yml | 0 compute/Cargo.toml | 2 +- p2p/Cargo.toml | 2 +- workflows/Cargo.toml | 2 +- 7 files changed, 26 insertions(+), 15 deletions(-) rename compute/compose.yml => compose.yml (100%) diff --git a/.github/workflows/build_dev_container.yml b/.github/workflows/build_dev_container.yml index b5cf3ea..63d5c88 100644 --- a/.github/workflows/build_dev_container.yml +++ b/.github/workflows/build_dev_container.yml @@ -1,13 +1,21 @@ name: Create Dev Image on: push: - branches: ["master", "erhant/subcrates"] - # paths: - # - "src/**" - # - "Cargo.lock" - # - "Cargo.toml" - # - "Dockerfile" - # - "compose.yml" + branches: ["master"] + paths: + # Source files in each member + - "compute/src/**" + - "p2p/src/**" + - "workflows/src/**" + # Cargo in each member + - "compute/Cargo.toml" + - "p2p/Cargo.toml" + - "workflows/Cargo.toml" + # root-level changes + - "Cargo.lock" + - "Cross.toml" + - "Dockerfile" + - "compose.yml" jobs: build-and-push: diff --git a/Cargo.lock b/Cargo.lock index 36864b3..402170e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1014,7 +1014,7 @@ dependencies = [ [[package]] name = "dkn-p2p" -version = "0.1.0" +version = "0.2.11" dependencies = [ "env_logger 0.11.5", "eyre", @@ -1026,7 +1026,7 @@ dependencies = [ [[package]] name = "dkn-workflows" -version = "0.1.0" +version = "0.2.11" dependencies = [ "async-trait", "dotenvy", diff --git a/Cargo.toml b/Cargo.toml index e94c65a..c6bb77a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,14 +5,17 @@ members = ["compute", 
"p2p", "workflows"] # then, a Launcher will be the default member default-members = ["compute"] +[workspace.package] +edition = "2021" +version = "0.2.11" +license = "Apache-2.0" +readme = "README.md" + # profiling build for flamegraphs [profile.profiling] inherits = "release" debug = true -[workspace.package] -edition = "2021" -license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/compute/compose.yml b/compose.yml similarity index 100% rename from compute/compose.yml rename to compose.yml diff --git a/compute/Cargo.toml b/compute/Cargo.toml index fb226d7..c9e9a7e 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dkn-compute" -version = "0.2.11" +version.workspace = true edition.workspace = true license.workspace = true readme = "README.md" diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index dd22a60..f988166 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dkn-p2p" -version = "0.1.0" +version.workspace = true edition.workspace = true license.workspace = true readme = "README.md" diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index 34a6179..e068f37 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dkn-workflows" -version = "0.1.0" +version.workspace = true edition.workspace = true license.workspace = true readme = "README.md" From f4d2312952a296d824aa86a08947bf6ca5496e81 Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 14:53:49 +0300 Subject: [PATCH 13/14] small workflow edits --- .github/workflows/build_dev_container.yml | 2 +- .github/workflows/tests.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_dev_container.yml b/.github/workflows/build_dev_container.yml index 63d5c88..60a1df8 100644 --- a/.github/workflows/build_dev_container.yml +++ b/.github/workflows/build_dev_container.yml @@ -52,7 +52,7 @@ jobs: - name: Set Image Tag id: itag - run: echo "itag=testing-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT + run: echo "itag=${{ steps.sha.outputs.branch }}-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT - name: Build and push uses: docker/build-push-action@v6 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d83e82f..1f093c1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,9 +2,9 @@ name: tests on: # push: + # branches: + # - master workflow_dispatch: - branches: - - master jobs: test: From 29308308f973329ba32f5d8411634bf4d1446d3f Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 8 Oct 2024 15:11:35 +0300 Subject: [PATCH 14/14] fix test workflow command and `make` commands --- .github/workflows/tests.yml | 8 ++++---- Makefile | 5 ++--- README.md | 2 +- p2p/README.md | 30 +++++++++++++++++++++++++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f093c1..a6ad599 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,9 +1,9 @@ name: tests on: - # push: - # branches: - # - master + push: + branches: + - master workflow_dispatch: jobs: @@ -18,4 +18,4 @@ jobs: uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests - run: cargo test + run: cargo test --workspace diff --git a/Makefile b/Makefile index c41c9fb..328c2e4 100644 --- a/Makefile +++ b/Makefile @@ -36,13 +36,12 @@ 
profile-mem: ############################################################################### .PHONY: test # | Run tests test: - cargo test + cargo test --workspace ############################################################################### .PHONY: lint # | Run linter (clippy) lint: - cargo clippy - cargo clippy + cargo clippy --workspace .PHONY: format # | Run formatter (cargo fmt) format: diff --git a/README.md b/README.md index 9cda781..5df0d7f 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Compute nodes can technically do any arbitrary task, from computing the square r - **Ping/Pong**: Dria Admin Node broadcasts **ping** messages at a set interval, it is a required duty of the compute node to respond with a **pong** to these so that they can be included in the list of available nodes for task assignment. These tasks will respect the type of model provided within the pong message, e.g. if a task requires `gpt-4o` and you are running `phi3`, you won't be selected for that task. -- **Workflows**: Each task is given in the form of a workflow, based on [Ollama Workflows](https://github.com/andthattoo/ollama-workflows) (see repository for more information). In simple terms, each workflow defines the agentic behavior of an LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching. +- **Workflows**: Each task is given in the form of a workflow, based on [Ollama Workflows](https://github.com/andthattoo/ollama-workflows). In simple terms, each workflow defines the agentic behavior of an LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching. ## Node Running diff --git a/p2p/README.md b/p2p/README.md index 391df53..4ff7908 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -1,5 +1,7 @@ # DKN Peer-to-Peer Client +Dria Knowledge Network is a peer-to-peer network, built over libp2p. This crate is a wrapper client to easily interact with DKN. + ## Installation Add the package via `git` within your Cargo dependencies: @@ -10,4 +12,30 @@ dkn-p2p = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } ## Usage -TODO: !!! +You can create the client as follows: + +```rs +use dkn_p2p::DriaP2PClient; + +// your wallet, or something random maybe +let keypair = Keypair::generate_secp256k1(); + +// your listen address +let addr = Multiaddr::from_str("/ip4/0.0.0.0/tcp/4001")?; + +// static bootstrap & relay addresses +let bootstraps = vec![Multiaddr::from_str( + "some-multiaddrs-here" +)?]; +let relays = vec![Multiaddr::from_str( + "some-multiaddrs-here" +)?]; + +// protocol version number, usually derived as `{major}.{minor}` +let version = "0.2"; + +// create the client! +let mut client = DriaP2PClient::new(keypair, addr, &bootstraps, &relays, "0.2")?; +``` + +Then, you can use its underlying functions, such as `subscribe`, `process_events` and `unsubscribe`. In particular, `process_events` handles all p2p events and returns a GossipSub message when it is received.
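
A short sketch of that subscribe → process_events → unsubscribe flow, modeled on the `p2p/tests/listen_test.rs` integration test added earlier in this series (the topic name and logging here are illustrative, and `client` is assumed to be created as shown above):

```rs
use dkn_p2p::DriaP2PClient;

/// Sketch: listen for a single gossipsub message on a topic, then unsubscribe.
async fn listen_once(mut client: DriaP2PClient, topic: &str) -> eyre::Result<()> {
    client.subscribe(topic)?;

    // resolves when a gossipsub message arrives on a subscribed topic
    let message = client.process_events().await;
    log::info!("Received {} message: {:?}", topic, message);

    client.unsubscribe(topic)?;
    Ok(())
}
```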