diff --git a/.env.example b/.env.example
index 1f34c69..61b3481 100644
--- a/.env.example
+++ b/.env.example
@@ -10,13 +10,16 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b
 DKN_MODELS=

 ## DRIA (optional) ##
-# P2P address, you don't need to change this unless you really want this port.
+# P2P address, you don't need to change this unless this port is already in use.
 DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
 # Comma-separated static relay nodes
 DKN_RELAY_NODES=
 # Comma-separated static bootstrap nodes
 DKN_BOOTSTRAP_NODES=
-
+# Set to a number of seconds to wait before exiting; only use this in profiling builds!
+# Otherwise, leave this empty.
+DKN_EXIT_TIMEOUT=
+
 ## Open AI (if used, required) ##
 OPENAI_API_KEY=
diff --git a/.github/workflows/build_dev_container.yml b/.github/workflows/build_dev_container.yml
index b0290d0..60a1df8 100644
--- a/.github/workflows/build_dev_container.yml
+++ b/.github/workflows/build_dev_container.yml
@@ -3,9 +3,17 @@ on:
   push:
     branches: ["master"]
     paths:
-      - "src/**"
+      # Source files in each member
+      - "compute/src/**"
+      - "p2p/src/**"
+      - "workflows/src/**"
+      # Cargo in each member
+      - "compute/Cargo.toml"
+      - "p2p/Cargo.toml"
+      - "workflows/Cargo.toml"
+      # root-level changes
       - "Cargo.lock"
-      - "Cargo.toml"
+      - "Cross.toml"
       - "Dockerfile"
       - "compose.yml"

@@ -44,7 +52,7 @@ jobs:
       - name: Set Image Tag
         id: itag
-        run: echo "itag=${{ steps.branch.outputs.branch }}-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT
+        run: echo "itag=${{ steps.sha.outputs.branch }}-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT

       - name: Build and push
         uses: docker/build-push-action@v6
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d83e82f..a6ad599 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,10 +1,10 @@
 name: tests
 on:
-  # push:
-  workflow_dispatch:
+  push:
     branches:
       - master
+  workflow_dispatch:

 jobs:
   test:
@@ -18,4 +18,4 @@ jobs:
         uses: actions-rust-lang/setup-rust-toolchain@v1

       - name: Run tests
-        run: cargo test
+        run: cargo test --workspace
diff --git a/Cargo.lock b/Cargo.lock
index 01ccd84..402170e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -980,10 +980,12 @@ dependencies = [

 [[package]]
 name = "dkn-compute"
-version = "0.2.10"
+version = "0.2.11"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
+ "dkn-p2p",
+ "dkn-workflows",
  "dotenvy",
  "ecies",
  "env_logger 0.11.5",
@@ -991,13 +993,9 @@ dependencies = [
  "fastbloom-rs",
  "hex",
  "hex-literal",
- "libp2p",
- "libp2p-identity",
  "libsecp256k1",
  "log",
- "ollama-workflows",
  "openssl",
- "parking_lot",
  "port_check",
  "rand 0.8.5",
  "reqwest 0.12.8",
@@ -1014,6 +1012,36 @@ dependencies = [
  "uuid",
 ]

+[[package]]
+name = "dkn-p2p"
+version = "0.2.11"
+dependencies = [
+ "env_logger 0.11.5",
+ "eyre",
+ "libp2p",
+ "libp2p-identity",
+ "log",
+ "tokio 1.40.0",
+]
+
+[[package]]
+name = "dkn-workflows"
+version = "0.2.11"
+dependencies = [
+ "async-trait",
+ "dotenvy",
+ "env_logger 0.11.5",
+ "eyre",
+ "log",
+ "ollama-workflows",
+ "rand 0.8.5",
+ "reqwest 0.12.8",
+ "serde",
+ "serde_json",
+ "tokio 1.40.0",
+ "tokio-util 0.7.12",
+]
+
 [[package]]
 name = "dotenv"
 version = "0.15.0"
diff --git a/Cargo.toml b/Cargo.toml
index 14bfc4f..c6bb77a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,78 +1,44 @@
-[package]
-name = "dkn-compute"
-version = "0.2.10"
+[workspace]
+resolver = "2"
+members = ["compute", "p2p", "workflows"]
+# compute node is the default member, until Oracle comes in
+# then, a Launcher will be the default member
+default-members = ["compute"]
+
+[workspace.package]
 edition = "2021"
+version = "0.2.11"
 license = "Apache-2.0"
 readme = "README.md"
-authors = ["Erhan Tezcan "]

 # profiling build for flamegraphs
 [profile.profiling]
 inherits = "release"
 debug = true

-[features]
-# used by flamegraphs & instruments
-profiling = []

-[dependencies]
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[workspace.dependencies]
+# async stuff
 tokio-util = { version = "0.7.10", features = ["rt"] }
 tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
-parking_lot = "0.12.2"
+async-trait = "0.1.81"
+
+# serialize & deserialize
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-async-trait = "0.1.81"
+
+# http client
 reqwest = "0.12.5"

-# utilities
+# env reading
 dotenvy = "0.15.7"
-base64 = "0.22.0"
-hex = "0.4.3"
-hex-literal = "0.4.1"
-url = "2.5.0"
-urlencoding = "2.1.3"
-uuid = { version = "1.8.0", features = ["v4"] }
+
+# randomization
 rand = "0.8.5"

 # logging & errors
 env_logger = "0.11.3"
 log = "0.4.21"
 eyre = "0.6.12"
-
-# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters
-ecies = { version = "0.2", default-features = false, features = ["pure"] }
-libsecp256k1 = "0.7.1"
-sha2 = "0.10.8"
-sha3 = "0.10.8"
-fastbloom-rs = "0.5.9"
-
-# workflows
-ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" }
-
-# peer-to-peer
-libp2p = { git = "https://github.com/anilaltuner/rust-libp2p.git", rev = "7ce9f9e", features = [
-    # libp2p = { version = "0.54.1", features = [
-    "dcutr",
-    "ping",
-    "relay",
-    "autonat",
-    "identify",
-    "tokio",
-    "gossipsub",
-    "mdns",
-    "noise",
-    "macros",
-    "tcp",
-    "yamux",
-    "quic",
-    "kad",
-] }
-libp2p-identity = { version = "0.2.9", features = ["secp256k1"] }
-tracing = { version = "0.1.40" }
-tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
-port_check = "0.2.1"
-
-# Vendor OpenSSL so that its easier to build cross-platform packages
-[dependencies.openssl]
-version = "*"
-features = ["vendored"]
diff --git a/Makefile b/Makefile
index 5ec333b..328c2e4 100644
--- a/Makefile
+++ b/Makefile
@@ -7,15 +7,15 @@ endif
###############################################################################
 .PHONY: launch # | Run with INFO logs in release mode
 launch:
-	RUST_LOG=none,dkn_compute=info cargo run --release
+	RUST_LOG=none,dkn_compute=info,dkn_workflows=info,dkn_p2p=info cargo run --release

 .PHONY: run # | Run with INFO logs
 run:
-	RUST_LOG=none,dkn_compute=info cargo run
+	RUST_LOG=none,dkn_compute=info,dkn_workflows=info,dkn_p2p=info cargo run

 .PHONY: debug # | Run with DEBUG logs with INFO log-level workflows
 debug:
-	RUST_LOG=warn,dkn_compute=debug,ollama_workflows=info cargo run
+	RUST_LOG=warn,dkn_compute=debug,dkn_workflows=debug,dkn_p2p=debug,ollama_workflows=info cargo run

 .PHONY: trace # | Run with TRACE logs
 trace:
@@ -27,21 +27,21 @@ build:

 .PHONY: profile-cpu # | Profile CPU usage with flamegraph
 profile-cpu:
-	cargo flamegraph --root --profile=profiling --features=profiling
+	DKN_EXIT_TIMEOUT=120 cargo flamegraph --root --profile=profiling

 .PHONY: profile-mem # | Profile memory usage with instruments
 profile-mem:
-	cargo instruments --profile=profiling --features=profiling -t Allocations
+	DKN_EXIT_TIMEOUT=120 cargo instruments --profile=profiling -t Allocations

###############################################################################
 .PHONY: test # | Run tests
 test:
-	cargo test
+	cargo test --workspace

###############################################################################
 .PHONY: lint # | Run linter (clippy)
 lint:
-	cargo clippy
+	cargo clippy --workspace

 .PHONY: format # | Run formatter (cargo fmt)
 format:
diff --git a/README.md b/README.md
index 9cda781..5df0d7f 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ Compute nodes can technically do any arbitrary task, from computing the square r

 - **Ping/Pong**: Dria Admin Node broadcasts **ping** messages at a set interval, it is a required duty of the compute node to respond with a **pong** to these so that they can be included in the list of available nodes for task assignment. These tasks will respect the type of model provided within the pong message, e.g. if a task requires `gpt-4o` and you are running `phi3`, you won't be selected for that task.

-- **Workflows**: Each task is given in the form of a workflow, based on [Ollama Workflows](https://github.com/andthattoo/ollama-workflows) (see repository for more information). In simple terms, each workflow defines the agentic behavior of an LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching.
+- **Workflows**: Each task is given in the form of a workflow, based on [Ollama Workflows](https://github.com/andthattoo/ollama-workflows). In simple terms, each workflow defines the agentic behavior of an LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching.

 ## Node Running
diff --git a/compute/Cargo.toml b/compute/Cargo.toml
new file mode 100644
index 0000000..c9e9a7e
--- /dev/null
+++ b/compute/Cargo.toml
@@ -0,0 +1,50 @@
+[package]
+name = "dkn-compute"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+readme = "README.md"
+authors = ["Erhan Tezcan "]
+
+[dependencies]
+tokio-util = { version = "0.7.10", features = ["rt"] }
+tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+async-trait = "0.1.81"
+reqwest = "0.12.5"
+
+# utilities
+dotenvy.workspace = true
+base64 = "0.22.0"
+hex = "0.4.3"
+hex-literal = "0.4.1"
+url = "2.5.0"
+urlencoding = "2.1.3"
+uuid = { version = "1.8.0", features = ["v4"] }
+
+port_check = "0.2.1"
+
+# logging & errors
+rand.workspace = true
+env_logger.workspace = true
+log.workspace = true
+eyre.workspace = true
+tracing = { version = "0.1.40" }
+tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
+
+# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters
+ecies = { version = "0.2", default-features = false, features = ["pure"] }
+libsecp256k1 = "0.7.1"
+sha2 = "0.10.8"
+sha3 = "0.10.8"
+fastbloom-rs = "0.5.9"
+
+# dria subcrates
+dkn-p2p = { path = "../p2p" }
+dkn-workflows = { path = "../workflows" }
+
+# Vendor OpenSSL so that it's easier to build cross-platform packages
+[dependencies.openssl]
+version = "*"
+features = ["vendored"]
diff --git a/src/config/mod.rs b/compute/src/config.rs
similarity index 58%
rename from src/config/mod.rs
rename to compute/src/config.rs
index 7cc3bd2..ce8cf72 100644
--- a/src/config/mod.rs
+++ b/compute/src/config.rs
@@ -1,23 +1,10 @@
-mod models;
-mod ollama;
-mod openai;
-
 use crate::utils::{address_in_use, crypto::to_address};
+use dkn_p2p::libp2p::Multiaddr;
+use dkn_workflows::DriaWorkflowsConfig;
 use eyre::{eyre, Result};
-use libp2p::Multiaddr;
 use
 libsecp256k1::{PublicKey, SecretKey};
-use models::ModelConfig;
-use ollama::OllamaConfig;
-use ollama_workflows::ModelProvider;
-use openai::OpenAIConfig;
-
-use std::{env, str::FromStr, time::Duration};
-
-/// Timeout duration for checking model performance during a generation.
-const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80);
-/// Minimum tokens per second (TPS) for checking model performance during a generation.
-const CHECK_TPS: f64 = 15.0;
+use std::{env, str::FromStr};

 #[derive(Debug, Clone)]
 pub struct DriaComputeNodeConfig {
@@ -31,13 +18,8 @@ pub struct DriaComputeNodeConfig {
     pub admin_public_key: PublicKey,
     /// P2P listen address, e.g. `/ip4/0.0.0.0/tcp/4001`.
     pub p2p_listen_addr: Multiaddr,
-    /// Available LLM models & providers for the node.
-    pub model_config: ModelConfig,
-    /// Even if Ollama is not used, we store the host & port here.
-    /// If Ollama is used, this config will be respected during its instantiations.
-    pub ollama_config: OllamaConfig,
-    /// OpenAI API key & its service check implementation.
-    pub openai_config: OpenAIConfig,
+    /// Workflow configurations, e.g. models and providers.
+    pub workflows: DriaWorkflowsConfig,
 }

 /// The default P2P network listen address.
@@ -97,13 +79,14 @@ impl DriaComputeNodeConfig {
         let address = to_address(&public_key);
         log::info!("Node Address: 0x{}", hex::encode(address));

-        let model_config = ModelConfig::new_from_csv(env::var("DKN_MODELS").ok());
+        let workflows =
+            DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default());
         #[cfg(not(test))]
-        if model_config.models.is_empty() {
+        if workflows.models.is_empty() {
             log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS.");
             panic!("No models provided.");
         }
-        log::info!("Models: {:?}", model_config.models);
+        log::info!("Models: {:?}", workflows.models);

         let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR")
             .map(|addr| addr.trim_matches('"').to_string())
@@ -116,74 +99,14 @@
             secret_key,
             public_key,
             address,
-            model_config,
+            workflows,
             p2p_listen_addr,
-            ollama_config: OllamaConfig::new(),
-            openai_config: OpenAIConfig::new(),
-        }
-    }
-
-    /// Check if the required compute services are running.
-    /// This has several steps:
-    ///
-    /// - If Ollama models are used, hardcoded models are checked locally, and for
-    ///   external models, the workflow is tested with a simple task with timeout.
-    /// - If OpenAI models are used, the API key is checked and the models are tested
-    ///
-    /// If both type of models are used, both services are checked.
-    /// In the end, bad models are filtered out and we simply check if we are left if any valid models at all.
-    /// If not, an error is returned.
-    pub async fn check_services(&mut self) -> Result<()> {
-        log::info!("Checking configured services.");
-
-        // TODO: can refactor (provider, model) logic here
-        let unique_providers = self.model_config.get_providers();
-
-        let mut good_models = Vec::new();
-
-        // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
-        if unique_providers.contains(&ModelProvider::Ollama) {
-            let ollama_models = self
-                .model_config
-                .get_models_for_provider(ModelProvider::Ollama);
-
-            // ensure that the models are pulled / pull them if not
-            let good_ollama_models = self
-                .ollama_config
-                .check(ollama_models, CHECK_TIMEOUT_DURATION, CHECK_TPS)
-                .await?;
-            good_models.extend(
-                good_ollama_models
-                    .into_iter()
-                    .map(|m| (ModelProvider::Ollama, m)),
-            );
-        }
-
-        // if OpenAI is a provider, check that the API key is set
-        if unique_providers.contains(&ModelProvider::OpenAI) {
-            let openai_models = self
-                .model_config
-                .get_models_for_provider(ModelProvider::OpenAI);
-
-            let good_openai_models = self.openai_config.check(openai_models).await?;
-            good_models.extend(
-                good_openai_models
-                    .into_iter()
-                    .map(|m| (ModelProvider::OpenAI, m)),
-            );
-        }
-
-        // update good models
-        if good_models.is_empty() {
-            Err(eyre!("No good models found, please check logs for errors."))
-        } else {
-            self.model_config.models = good_models;
-            Ok(())
         }
     }

-    // ensure that listen address is free
-    pub fn check_address_in_use(&self) -> Result<()> {
+    /// Asserts that the configured listen address is free.
+    /// Returns an error if the address is already in use.
+    pub fn assert_address_not_in_use(&self) -> Result<()> {
         if address_in_use(&self.p2p_listen_addr) {
             return Err(eyre!(
                 "Listen address {} is already in use.",
diff --git a/src/handlers/mod.rs b/compute/src/handlers/mod.rs
similarity index 95%
rename from src/handlers/mod.rs
rename to compute/src/handlers/mod.rs
index edb018e..00ccf51 100644
--- a/src/handlers/mod.rs
+++ b/compute/src/handlers/mod.rs
@@ -1,7 +1,7 @@
 use crate::{utils::DKNMessage, DriaComputeNode};
 use async_trait::async_trait;
+use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
 use eyre::Result;
-use libp2p::gossipsub::MessageAcceptance;

 mod pingpong;
 pub use pingpong::PingpongHandler;
diff --git a/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs
similarity index 92%
rename from src/handlers/pingpong.rs
rename to compute/src/handlers/pingpong.rs
index 53ab332..f4fa7d7 100644
--- a/src/handlers/pingpong.rs
+++ b/compute/src/handlers/pingpong.rs
@@ -4,9 +4,9 @@ use crate::{
     DriaComputeNode,
 };
 use async_trait::async_trait;
+use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
+use dkn_workflows::{Model, ModelProvider};
 use eyre::{Context, Result};
-use libp2p::gossipsub::MessageAcceptance;
-use ollama_workflows::{Model, ModelProvider};
 use serde::{Deserialize, Serialize};

 pub struct PingpongHandler;
@@ -54,7 +54,7 @@ impl ComputeHandler for PingpongHandler {
         // respond
         let response_body = PingpongResponse {
             uuid: pingpong.uuid.clone(),
-            models: node.config.model_config.models.clone(),
+            models: node.config.workflows.models.clone(),
             timestamp: get_current_time_nanos(),
         };
diff --git a/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs
similarity index 93%
rename from src/handlers/workflow.rs
rename to compute/src/handlers/workflow.rs
index 41f2789..4486cc6 100644
--- a/src/handlers/workflow.rs
+++ b/compute/src/handlers/workflow.rs
@@ -1,8 +1,8 @@
 use async_trait::async_trait;
+use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
+use dkn_workflows::ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
 use eyre::{eyre, Context, Result};
-use libp2p::gossipsub::MessageAcceptance;
 use libsecp256k1::PublicKey;
-use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
 use serde::Deserialize;

 use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload};
@@ -66,15 +66,17 @@
         }

         // read model / provider from the task
-        let (model_provider, model) = config
-            .model_config
-            .get_any_matching_model(task.input.model)?;
+        let (model_provider, model) = config.workflows.get_any_matching_model(task.input.model)?;
         let model_name = model.to_string(); // get model name, we will pass it in payload
         log::info!("Using model {} for task {}", model_name, task.task_id);

         // prepare workflow executor
         let executor = if model_provider == ModelProvider::Ollama {
-            Executor::new_at(model, &config.ollama_config.host, config.ollama_config.port)
+            Executor::new_at(
+                model,
+                &config.workflows.ollama.host,
+                config.workflows.ollama.port,
+            )
         } else {
             Executor::new(model)
         };
diff --git a/src/lib.rs b/compute/src/lib.rs
similarity index 95%
rename from src/lib.rs
rename to compute/src/lib.rs
index adc2b65..d924a68 100644
--- a/src/lib.rs
+++ b/compute/src/lib.rs
@@ -4,7 +4,6 @@
 pub(crate) mod config;
 pub(crate) mod handlers;
 pub(crate) mod node;
-pub(crate) mod p2p;
 pub(crate) mod payloads;
 pub(crate) mod utils;
diff --git a/src/main.rs b/compute/src/main.rs
similarity index 79%
rename from src/main.rs
rename to compute/src/main.rs
index a2fd86d..f61e4dd 100644
--- a/src/main.rs
+++ b/compute/src/main.rs
@@ -1,16 +1,18 @@
+use std::env;
+
 use dkn_compute::*;
-use eyre::Result;
+use eyre::{Context, Result};
 use tokio_util::sync::CancellationToken;

 #[tokio::main]
 async fn main() -> Result<()> {
-    if let Err(e) = dotenvy::dotenv() {
-        log::warn!("Could not load .env file: {}", e);
-    }
-
+    let dotenv_result = dotenvy::dotenv();
     env_logger::builder()
         .format_timestamp(Some(env_logger::TimestampPrecision::Millis))
         .init();
+    if let Err(e) = dotenv_result {
+        log::warn!("Could not load .env file: {}", e);
+    }

     log::info!(
         r#"
@@ -26,49 +28,45 @@
     let token = CancellationToken::new();
     let cancellation_token = token.clone();

-    // add cancellation check
     tokio::spawn(async move {
-        // FIXME: weird feature-gating here bugs with IDE, fix this later
-        #[cfg(feature = "profiling")]
-        {
-            const PROFILE_DURATION_SECS: u64 = 120;
-            tokio::time::sleep(tokio::time::Duration::from_secs(PROFILE_DURATION_SECS)).await;
+        if let Ok(timeout_str) = env::var("DKN_EXIT_TIMEOUT") {
+            // add cancellation check
+            let duration_secs = timeout_str.parse().unwrap_or(120);
+            tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
             cancellation_token.cancel();
+        } else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
+            log::error!("Error waiting for termination: {:?}", err);
+            log::error!("Cancelling due to unexpected error.");
+            cancellation_token.cancel();
         }
-
-        #[cfg(not(feature = "profiling"))]
-        if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
-            log::error!("Error waiting for termination: {:?}", err);
-            log::error!("Cancelling due to unexpected error.");
-            cancellation_token.cancel();
-        };
     });

-    // create configurations & check required services
-    let config = DriaComputeNodeConfig::new();
-    config.check_address_in_use()?;
+    // create configurations & check required services & address in use
+    let mut config = DriaComputeNodeConfig::new();
+    config.assert_address_not_in_use()?;
     let service_check_token = token.clone();
-    let mut config_clone = config.clone();
     let service_check_handle = tokio::spawn(async move {
         tokio::select! {
             _ = service_check_token.cancelled() => {
                 log::info!("Service check cancelled.");
+                config
             }
-            result = config_clone.check_services() => {
+            result = config.workflows.check_services() => {
                 if let Err(err) = result {
                     log::error!("Error checking services: {:?}", err);
                     panic!("Service check failed.")
                 }
+                config
             }
         }
     });
+    let config = service_check_handle
+        .await
+        .wrap_err("error during service checks")?;

-    // wait for service check to complete
-    if let Err(err) = service_check_handle.await {
-        log::error!("Service check handle error: {}", err);
-        panic!("Could not exit service check thread handle.");
-    };
-
+    log::warn!("Using models: {:#?}", config.workflows.models);
     if !token.is_cancelled() {
         // launch the node
         let node_token = token.clone();
@@ -97,11 +95,9 @@
     Ok(())
 }

-// FIXME: remove this `unused` once we have a better way to handle this
 /// Waits for various termination signals, and cancels the given token when the signal is received.
 ///
 /// Handles Unix and Windows [target families](https://doc.rust-lang.org/reference/conditional-compilation.html#target_family).
-#[allow(unused)]
 async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> {
     #[cfg(unix)]
     {
diff --git a/src/node.rs b/compute/src/node.rs
similarity index 93%
rename from src/node.rs
rename to compute/src/node.rs
index e824f1c..e89a961 100644
--- a/src/node.rs
+++ b/compute/src/node.rs
@@ -1,12 +1,11 @@
+use dkn_p2p::{libp2p::gossipsub, DriaP2PClient};
 use eyre::{eyre, Result};
-use libp2p::gossipsub;
 use std::time::Duration;
 use tokio_util::sync::CancellationToken;

 use crate::{
     config::*,
     handlers::*,
-    p2p::P2PClient,
     utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage},
 };

@@ -28,7 +27,7 @@ const RPC_PEER_ID_REFRESH_INTERVAL_SECS: u64 = 30;
 /// ```
 pub struct DriaComputeNode {
     pub config: DriaComputeNodeConfig,
-    pub p2p: P2PClient,
+    pub p2p: DriaP2PClient,
     pub available_nodes: AvailableNodes,
     pub available_nodes_last_refreshed: tokio::time::Instant,
     pub cancellation: CancellationToken,
@@ -52,7 +51,22 @@
         )
         .sort_dedup();

-        let p2p = P2PClient::new(keypair, config.p2p_listen_addr.clone(), &available_nodes)?;
+        // we are using the major.minor version as the P2P version
+        // so that patch versions do not interfere with the protocol
+        const P2P_VERSION: &str = concat!(
+            env!("CARGO_PKG_VERSION_MAJOR"),
+            ".",
+            env!("CARGO_PKG_VERSION_MINOR")
+        );
+
+        // create p2p client
+        let p2p = DriaP2PClient::new(
+            keypair,
+            config.p2p_listen_addr.clone(),
+            &available_nodes.bootstrap_nodes,
+            &available_nodes.relay_nodes,
+            P2P_VERSION,
+        )?;

         Ok(DriaComputeNode {
             p2p,
@@ -97,7 +111,12 @@
     /// Returns the list of connected peers.
     #[inline(always)]
-    pub fn peers(&self) -> Vec<(&libp2p_identity::PeerId, Vec<&gossipsub::TopicHash>)> {
+    pub fn peers(
+        &self,
+    ) -> Vec<(
+        &dkn_p2p::libp2p_identity::PeerId,
+        Vec<&gossipsub::TopicHash>,
+    )> {
         self.p2p.peers()
     }
diff --git a/src/payloads/error.rs b/compute/src/payloads/error.rs
similarity index 100%
rename from src/payloads/error.rs
rename to compute/src/payloads/error.rs
diff --git a/src/payloads/mod.rs b/compute/src/payloads/mod.rs
similarity index 100%
rename from src/payloads/mod.rs
rename to compute/src/payloads/mod.rs
diff --git a/src/payloads/request.rs b/compute/src/payloads/request.rs
similarity index 100%
rename from src/payloads/request.rs
rename to compute/src/payloads/request.rs
diff --git a/src/payloads/response.rs b/compute/src/payloads/response.rs
similarity index 100%
rename from src/payloads/response.rs
rename to compute/src/payloads/response.rs
diff --git a/src/utils/available_nodes.rs b/compute/src/utils/available_nodes.rs
similarity index 94%
rename from src/utils/available_nodes.rs
rename to compute/src/utils/available_nodes.rs
index 2215be1..45a6f73 100644
--- a/src/utils/available_nodes.rs
+++ b/compute/src/utils/available_nodes.rs
@@ -1,9 +1,8 @@
+use dkn_p2p::libp2p::{Multiaddr, PeerId};
+use dkn_workflows::split_csv_line;
 use eyre::Result;
-use libp2p::{Multiaddr, PeerId};
 use std::{env, fmt::Debug, str::FromStr};

-use crate::utils::split_comma_separated;
-
 /// Static bootstrap nodes for the Kademlia DHT bootstrap step.
 const STATIC_BOOTSTRAP_NODES: [&str; 4] = [
     "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4",
@@ -48,7 +47,7 @@ impl AvailableNodes {
     /// - `DRIA_RELAY_NODES`: comma-separated list of relay nodes
     pub fn new_from_env() -> Self {
         // parse bootstrap nodes
-        let bootstrap_nodes = split_comma_separated(env::var("DKN_BOOTSTRAP_NODES").ok());
+        let bootstrap_nodes = split_csv_line(&env::var("DKN_BOOTSTRAP_NODES").unwrap_or_default());
         if bootstrap_nodes.is_empty() {
             log::debug!("No additional bootstrap nodes provided.");
         } else {
@@ -56,7 +55,7 @@
         // parse relay nodes
-        let relay_nodes = split_comma_separated(env::var("DKN_RELAY_NODES").ok());
+        let relay_nodes = split_csv_line(&env::var("DKN_RELAY_NODES").unwrap_or_default());
         if relay_nodes.is_empty() {
             log::debug!("No additional relay nodes provided.");
         } else {
diff --git a/src/utils/crypto.rs b/compute/src/utils/crypto.rs
similarity index 96%
rename from src/utils/crypto.rs
rename to compute/src/utils/crypto.rs
index 86bc376..21a69f1 100644
--- a/src/utils/crypto.rs
+++ b/compute/src/utils/crypto.rs
@@ -1,6 +1,6 @@
+use dkn_p2p::libp2p_identity;
 use ecies::PublicKey;
 use eyre::{Context, Result};
-use libp2p_identity::Keypair;
 use libsecp256k1::{Message, SecretKey};
 use sha2::{Digest, Sha256};
 use sha3::Keccak256;
@@ -55,10 +55,10 @@ pub fn encrypt_bytes(data: impl AsRef<[u8]>, public_key: &PublicKey) -> Result<Vec<u8>>
-pub fn secret_to_keypair(secret_key: &SecretKey) -> Keypair {
+pub fn secret_to_keypair(secret_key: &SecretKey) -> libp2p_identity::Keypair {
     let bytes = secret_key.serialize();

-    let secret_key = libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes)
+    let secret_key = dkn_p2p::libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes)
         .expect("Failed to create secret key");
     libp2p_identity::secp256k1::Keypair::from(secret_key).into()
 }
diff --git a/src/utils/filter.rs b/compute/src/utils/filter.rs
similarity index 100%
rename from src/utils/filter.rs
rename to compute/src/utils/filter.rs
diff --git a/src/utils/message.rs
 b/compute/src/utils/message.rs
similarity index 97%
rename from src/utils/message.rs
rename to compute/src/utils/message.rs
index c7613e2..f05f026 100644
--- a/src/utils/message.rs
+++ b/compute/src/utils/message.rs
@@ -123,10 +123,10 @@ impl fmt::Display for DKNMessage {
     }
 }

-impl TryFrom<libp2p::gossipsub::Message> for DKNMessage {
+impl TryFrom<dkn_p2p::libp2p::gossipsub::Message> for DKNMessage {
     type Error = serde_json::Error;

-    fn try_from(value: libp2p::gossipsub::Message) -> Result<Self, Self::Error> {
+    fn try_from(value: dkn_p2p::libp2p::gossipsub::Message) -> Result<Self, Self::Error> {
         serde_json::from_slice(&value.data)
     }
 }
diff --git a/src/utils/mod.rs b/compute/src/utils/mod.rs
similarity index 69%
rename from src/utils/mod.rs
rename to compute/src/utils/mod.rs
index 23606ae..8721f91 100644
--- a/src/utils/mod.rs
+++ b/compute/src/utils/mod.rs
@@ -7,7 +7,7 @@ pub use message::DKNMessage;
 mod available_nodes;
 pub use available_nodes::AvailableNodes;

-use libp2p::{multiaddr::Protocol, Multiaddr};
+use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr};
 use port_check::is_port_reachable;
 use std::{
     net::{Ipv4Addr, SocketAddrV4},
@@ -55,23 +55,3 @@ pub fn address_in_use(addr: &Multiaddr) -> bool {
             false
         })
 }
-
-/// Utility to parse comma-separated string values, mostly read from the environment.
-/// - Trims `"` from both ends at the start
-/// - For each item, trims whitespace from both ends
-pub fn split_comma_separated(input: Option<String>) -> Vec<String> {
-    match input {
-        Some(s) => s
-            .trim_matches('"')
-            .split(',')
-            .filter_map(|s| {
-                if s.is_empty() {
-                    None
-                } else {
-                    Some(s.trim().to_string())
-                }
-            })
-            .collect::<Vec<_>>(),
-        None => vec![],
-    }
-}
diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml
new file mode 100644
index 0000000..f988166
--- /dev/null
+++ b/p2p/Cargo.toml
@@ -0,0 +1,36 @@
+[package]
+name = "dkn-p2p"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+readme = "README.md"
+authors = [
+    "Erhan Tezcan",
+    "Anil Altuner",
+]
+
+[dependencies]
+env_logger.workspace = true
+eyre.workspace = true
+log.workspace = true
+tokio.workspace = true
+
+# peer-to-peer
+libp2p = { git = "https://github.com/anilaltuner/rust-libp2p.git", rev = "7ce9f9e", features = [
+    "dcutr",
+    "ping",
+    "relay",
+    "autonat",
+    "identify",
+    "tokio",
+    "gossipsub",
+    "mdns",
+    "noise",
+    "macros",
+    "tcp",
+    "yamux",
+    "quic",
+    "kad",
+] }
+libp2p-identity = { version = "0.2.9", features = ["secp256k1"] }
diff --git a/src/p2p/behaviour.rs b/p2p/src/behaviour.rs
rename from src/p2p/behaviour.rs
rename to p2p/src/behaviour.rs
--- a/src/p2p/behaviour.rs
+++ b/p2p/src/behaviour.rs
 impl DriaBehaviour {
-    pub fn new(key: &Keypair, relay_behavior: relay::client::Behaviour) -> Self {
+    pub fn new(
+        key: &Keypair,
+        relay_behavior: relay::client::Behaviour,
+        identity_protocol: String,
+        kademlia_protocol: StreamProtocol,
+    ) -> Self {
         let public_key = key.public();
         let peer_id = public_key.to_peer_id();
         Self {
             relay: relay_behavior,
             gossipsub: create_gossipsub_behavior(peer_id),
-            kademlia: create_kademlia_behavior(peer_id),
+            kademlia: create_kademlia_behavior(peer_id, kademlia_protocol),
             autonat: create_autonat_behavior(peer_id),
             dcutr: create_dcutr_behavior(peer_id),
-            identify: create_identify_behavior(public_key),
+            identify: create_identify_behavior(public_key, identity_protocol),
         }
     }
 }

 /// Configures the Kademlia DHT behavior for the node.
 #[inline]
-fn create_kademlia_behavior(local_peer_id: PeerId) -> kad::Behaviour<kad::store::MemoryStore> {
+fn create_kademlia_behavior(
+    local_peer_id: PeerId,
+    protocol_name: StreamProtocol,
+) -> kad::Behaviour<kad::store::MemoryStore> {
     use kad::{Behaviour, Config};

     const QUERY_TIMEOUT_SECS: u64 = 5 * 60;
     const RECORD_TTL_SECS: u64 = 30;

-    let mut cfg = Config::new(P2P_KADEMLIA_PROTOCOL);
+    let mut cfg = Config::new(protocol_name);
     cfg.set_query_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))
         .set_record_ttl(Some(Duration::from_secs(RECORD_TTL_SECS)));

@@ -50,10 +57,13 @@ fn create_kademlia_behavior(local_peer_id: PeerId) -> kad::Behaviour<kad::store::MemoryStore>
 #[inline]
-fn create_identify_behavior(local_public_key: PublicKey) -> identify::Behaviour {
+fn create_identify_behavior(
+    local_public_key: PublicKey,
+    protocol_version: String,
+) -> identify::Behaviour {
     use identify::{Behaviour, Config};

-    let cfg = Config::new(P2P_PROTOCOL_STRING.to_string(), local_public_key);
+    let cfg = Config::new(protocol_version, local_public_key);

     Behaviour::new(cfg)
 }
@@ -119,9 +129,8 @@ fn create_gossipsub_behavior(author: PeerId) -> gossipsub::Behaviour {
     /// This helps to avoid memory exhaustion during high load
     const MAX_SEND_QUEUE_SIZE: usize = 400;

-    // message id's are simply hashes of the message data
+    // message id's are simply hashes of the message data, via SipHash13
     let message_id_fn = |message: &Message| {
-        // uses siphash by default
        let mut hasher = hash_map::DefaultHasher::new();
        message.data.hash(&mut hasher);
        let digest = hasher.finish();
diff --git a/src/p2p/client.rs b/p2p/src/client.rs
similarity index 87%
rename from src/p2p/client.rs
rename to p2p/src/client.rs
index e389206..f388f12 100644
--- a/src/p2p/client.rs
+++ b/p2p/src/client.rs
@@ -1,3 +1,4 @@
+use super::*;
 use eyre::Result;
 use libp2p::futures::StreamExt;
 use libp2p::gossipsub::{
@@ -7,16 +8,12 @@
 use libp2p::kad::{GetClosestPeersError, GetClosestPeersOk, QueryResult};
 use libp2p::{
     autonat, gossipsub, identify, kad, multiaddr::Protocol, noise, swarm::SwarmEvent, tcp, yamux,
 };
-use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder};
+use libp2p::{Multiaddr, PeerId, StreamProtocol, Swarm, SwarmBuilder};
 use libp2p_identity::Keypair;
-use std::time::Duration;
-use std::time::Instant;
-
-use super::*;
-use crate::utils::AvailableNodes;
+use std::time::{Duration, Instant};

 /// P2P client, exposes a simple interface to handle P2P communication.
-pub struct P2PClient {
+pub struct DriaP2PClient {
     /// `Swarm` instance, everything is accessed through this one.
     swarm: Swarm<DriaBehaviour>,
     /// Peer count for All and Mesh peers.
@@ -26,21 +23,42 @@ pub struct P2PClient {
     peer_count: (usize, usize),
     /// Last time the peer count was refreshed.
     peer_last_refreshed: Instant,
+    /// Identity protocol string to be used for the Identity behaviour.
+    ///
+    /// This is usually `dria/{version}`.
+    identity_protocol: String,
+    /// Kademlia protocol, must match with other peers in the network.
+    ///
+    /// This is usually `/dria/kad/{version}`, notice the `/` at the start
+    /// which is mandatory for a `StreamProtocol`.
+    kademlia_protocol: StreamProtocol,
 }

 /// Number of seconds before an idle connection is closed.
+/// TODO: default is 0, is 60 a good value?
 const IDLE_CONNECTION_TIMEOUT_SECS: u64 = 60;

 /// Number of seconds between refreshing the Kademlia DHT.
 const PEER_REFRESH_INTERVAL_SECS: u64 = 30;

-impl P2PClient {
+impl DriaP2PClient {
     /// Creates a new P2P client with the given keypair and listen address.
+    ///
+    /// Can provide a list of bootstrap and relay nodes to connect to at startup as well.
+    ///
+    /// The `version` is used to create the protocol strings for the client, and it's very important that
+    /// they match those of the existing clients within the network.
     pub fn new(
         keypair: Keypair,
         listen_addr: Multiaddr,
-        available_nodes: &AvailableNodes,
+        bootstraps: &[Multiaddr],
+        relays: &[Multiaddr],
+        version: &str,
     ) -> Result<Self> {
+        let identity_protocol = format!("{}{}", P2P_IDENTITY_PREFIX, version);
+        let kademlia_protocol =
+            StreamProtocol::try_from_owned(format!("{}{}", P2P_KADEMLIA_PREFIX, version))?;
+
         // this is our peerId
         let node_peerid = keypair.public().to_peer_id();
         log::info!("Compute node peer address: {}", node_peerid);
@@ -54,7 +72,14 @@
         )?
         .with_quic()
         .with_relay_client(noise::Config::new, yamux::Config::default)?
-        .with_behaviour(|key, relay_behavior| Ok(DriaBehaviour::new(key, relay_behavior)))?
+        .with_behaviour(|key, relay_behavior| {
+            Ok(DriaBehaviour::new(
+                key,
+                relay_behavior,
+                identity_protocol.clone(),
+                kademlia_protocol.clone(),
+            ))
+        })?
        .with_swarm_config(|c| {
            c.with_idle_connection_timeout(Duration::from_secs(IDLE_CONNECTION_TIMEOUT_SECS))
        })
@@ -67,11 +92,8 @@
             .set_mode(Some(libp2p::kad::Mode::Server));

         // initiate bootstrap
-        log::info!(
-            "Initiating bootstrap: {:#?}",
-            available_nodes.bootstrap_nodes
-        );
-        for addr in &available_nodes.bootstrap_nodes {
+        log::info!("Initiating bootstrap: {:#?}", bootstraps);
+        for addr in bootstraps {
             if let Some(peer_id) = addr.iter().find_map(|p| match p {
                 Protocol::P2p(peer_id) => Some(peer_id),
                 _ => None,
@@ -101,11 +123,8 @@
         log::info!("Listening p2p network on: {}", listen_addr);
         swarm.listen_on(listen_addr)?;

-        log::info!(
-            "Listening to relay nodes: {:#?}",
-            available_nodes.relay_nodes
-        );
-        for addr in &available_nodes.relay_nodes {
+        log::info!("Listening to relay nodes: {:#?}", relays);
+        for addr in relays {
             swarm.listen_on(addr.clone().with(Protocol::P2pCircuit))?;
         }

@@ -113,6 +132,8 @@
             swarm,
             peer_count: (0, 0),
             peer_last_refreshed: Instant::now(),
+            identity_protocol,
+            kademlia_protocol,
         })
     }

@@ -237,12 +258,12 @@
     /// - For Kademlia, we check the kademlia protocol and then add the address to the Kademlia routing table.
     fn handle_identify_event(&mut self, peer_id: PeerId, info: identify::Info) {
         // check identify protocol string
-        if info.protocol_version != P2P_PROTOCOL_STRING {
+        if info.protocol_version != self.identity_protocol {
             log::warn!(
                 "Identify: Peer {} has different Identify protocol: (them {}, you {})",
                 peer_id,
                 info.protocol_version,
-                P2P_PROTOCOL_STRING
+                self.identity_protocol
             );
             return;
         }
@@ -254,7 +275,7 @@
             .find(|p| p.to_string().starts_with(P2P_KADEMLIA_PREFIX))
         {
             // if it matches our protocol, add it to the Kademlia routing table
-            if *kad_protocol == P2P_KADEMLIA_PROTOCOL {
+            if *kad_protocol == self.kademlia_protocol {
                 // filter listen addresses
                 let addrs = info.listen_addrs.into_iter().filter(|listen_addr| {
                     if let Some(Protocol::Ip4(ipv4_addr)) = listen_addr.iter().next() {
@@ -285,7 +306,7 @@
                     "Identify: Peer {} has different Kademlia version: (them {}, you {})",
                     peer_id,
                     kad_protocol,
-                    P2P_KADEMLIA_PROTOCOL
+                    self.kademlia_protocol
                 );
             }
         }
diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs
new file mode 100644
index 0000000..97eacec
--- /dev/null
+++ b/p2p/src/lib.rs
@@ -0,0 +1,17 @@
+mod transform;
+
+mod behaviour;
+use behaviour::{DriaBehaviour, DriaBehaviourEvent};
+
+mod client;
+pub use client::DriaP2PClient;
+
+/// Prefix for Kademlia protocol, must start with `/`!
+pub(crate) const P2P_KADEMLIA_PREFIX: &str = "/dria/kad/";
+
+/// Prefix for Identity protocol string.
+pub(crate) const P2P_IDENTITY_PREFIX: &str = "dria/";
+
+// re-exports
+pub use libp2p;
+pub use libp2p_identity;
diff --git a/src/p2p/data_transform.rs b/p2p/src/transform.rs
similarity index 100%
rename from src/p2p/data_transform.rs
rename to p2p/src/transform.rs
diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs
new file mode 100644
index 0000000..e286429
--- /dev/null
+++ b/p2p/tests/listen_test.rs
@@ -0,0 +1,40 @@
+use dkn_p2p::DriaP2PClient;
+use eyre::Result;
+use libp2p::Multiaddr;
+use libp2p_identity::Keypair;
+use std::{env, str::FromStr};
+
+const LOG_LEVEL: &str = "none,dkn_p2p=debug";
+
+#[tokio::test]
+#[ignore = "run manually with logs"]
+async fn test_listen_topic_once() -> Result<()> {
+    // topic to be listened to
+    const TOPIC: &str = "pong";
+
+    env::set_var("RUST_LOG", LOG_LEVEL);
+    let _ = env_logger::try_init();
+
+    // setup client
+    let keypair = Keypair::generate_secp256k1();
+    let addr = Multiaddr::from_str("/ip4/0.0.0.0/tcp/4001")?;
+    let bootstraps = vec![Multiaddr::from_str(
+        "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4",
+    )?];
+    let relays = vec![Multiaddr::from_str(
+        "/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa",
+    )?];
+    let mut client = DriaP2PClient::new(keypair, addr, &bootstraps, &relays, "0.2")?;
+
+    // subscribe to the given topic
+    client.subscribe(TOPIC)?;
+
+    // wait for a single gossipsub message on this topic
+    let message = client.process_events().await;
+    log::info!("Received {} message: {:?}", TOPIC, message);
+
+    // unsubscribe gracefully
+    client.unsubscribe(TOPIC)?;
+
+    Ok(())
+}
diff --git a/src/p2p/mod.rs b/src/p2p/mod.rs
deleted file mode 100644
index 1e4c09c..0000000
--- a/src/p2p/mod.rs
+++ /dev/null
@@ -1,10 +0,0 @@
-mod behaviour;
-pub use behaviour::{DriaBehaviour, DriaBehaviourEvent};
-
-mod client;
-pub use client::P2PClient;
-
-mod versioning;
-pub use versioning::*;
-
-mod data_transform;
diff --git a/src/p2p/versioning.rs b/src/p2p/versioning.rs
deleted file mode 100644
index 11b2a4d..0000000
--- a/src/p2p/versioning.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-use libp2p::StreamProtocol;
-
-/// Kademlia protocol prefix, as a macro so that it can be used in literal-expecting constants.
-macro_rules! P2P_KADEMLIA_PREFIX {
-    () => {
-        "/dria/kad/"
-    };
-}
-pub const P2P_KADEMLIA_PREFIX: &str = P2P_KADEMLIA_PREFIX!();
-
-/// Identity protocol name prefix, as a macro so that it can be used in literal-expecting constants.
-macro_rules! P2P_IDENTITY_PREFIX {
-    () => {
-        "dria/"
-    };
-}
-
-/// Kademlia protocol version, in the form of `/dria/kad/<version>`, **notice the `/` at the start**.
-///
-/// It is important that this protocol matches EXACTLY among the nodes, otherwise there is a protocol-level logic
-/// that will prevent peers from finding eachother within the DHT.
-pub const P2P_KADEMLIA_PROTOCOL: StreamProtocol = StreamProtocol::new(concat!(
-    P2P_KADEMLIA_PREFIX!(),
-    env!("CARGO_PKG_VERSION_MAJOR"),
-    ".",
-    env!("CARGO_PKG_VERSION_MINOR")
-));
-
-/// Protocol string, checked by Identify protocol handlers.
-pub const P2P_PROTOCOL_STRING: &str = concat!(
-    P2P_IDENTITY_PREFIX!(),
-    env!("CARGO_PKG_VERSION_MAJOR"),
-    ".",
-    env!("CARGO_PKG_VERSION_MINOR")
-);
diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml
new file mode 100644
index 0000000..e068f37
--- /dev/null
+++ b/workflows/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "dkn-workflows"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+readme = "README.md"
+authors = ["Erhan Tezcan "]
+
+
+[dependencies]
+# ollama-rs is re-exported from ollama-workflows as well
+ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" }
+
+tokio-util.workspace = true
+tokio.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
+reqwest.workspace = true
+rand.workspace = true
+log.workspace = true
+eyre.workspace = true
+
+[dev-dependencies]
+env_logger.workspace = true
+dotenvy.workspace = true
diff --git a/workflows/README.md b/workflows/README.md
new file mode 100644
index 0000000..46eea30
--- /dev/null
+++ b/workflows/README.md
@@ -0,0 +1,38 @@
+# DKN Workflows
+
+We make use of Ollama Workflows in DKN; however, we also want to make sure that the chosen models are valid and performant enough (i.e. have enough TPS).
+This crate handles the configuration of the models to be used, and implements various service checks.
+
+- **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models.
+- **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. The TPS measurement is bounded by a timeout as well; a model that exceeds the timeout is considered invalid without its TPS even being considered.
+
+## Installation
+
+Add the package via `git` within your Cargo dependencies:
+
+```toml
+dkn-workflows = { git = "https://github.com/firstbatchxyz/dkn-compute-node" }
+```
+
+Note that the underlying Ollama Workflows crate is re-exported by this crate.
+
+## Usage
+
+DKN Workflows makes use of several environment variables for the respective providers.
+
+- `OPENAI_API_KEY` is used for OpenAI requests
+- `OLLAMA_HOST` is used to connect to the Ollama server
+- `OLLAMA_PORT` is used to connect to the Ollama server
+- `OLLAMA_AUTO_PULL` indicates whether missing models should be pulled automatically or not
+- `SERPER_API_KEY` is an optional API key for **Serper**, for better Workflow executions
+- `JINA_API_KEY` is an optional API key for **Jina**, for better Workflow executions
+
+With the environment variables ready, you can simply create a new configuration and call `check_services` to ensure all models are correctly set up:
+
+```rs
+use dkn_workflows::{DriaWorkflowsConfig, Model};
+
+let models = vec![Model::Phi3_5Mini];
+let mut config = DriaWorkflowsConfig::new(models);
+config.check_services().await?;
+```
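To see the intended flow end to end, here is a minimal sketch that mirrors the README snippet above; it is illustrative rather than part of the diff, and assumes a `tokio` runtime plus the same `DKN_MODELS` CSV variable that the compute node reads:

```rs
use dkn_workflows::DriaWorkflowsConfig;

#[tokio::main]
async fn main() -> eyre::Result<()> {
    // e.g. DKN_MODELS="gpt-4o,llama3.1:latest"; unrecognized names are skipped during parsing
    let csv = std::env::var("DKN_MODELS").unwrap_or_default();
    let mut config = DriaWorkflowsConfig::new_from_csv(&csv);

    // filters out models that fail their provider's checks (API keys, pulls, TPS)
    config.check_services().await?;

    // the `Display` impl prints the surviving models as comma-separated `Provider:model` pairs
    println!("using models: {}", config);
    Ok(())
}
```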
diff --git a/src/config/models.rs b/workflows/src/config.rs
similarity index 53%
rename from src/config/models.rs
rename to workflows/src/config.rs
index 943a8dc..d2bd29b 100644
--- a/src/config/models.rs
+++ b/workflows/src/config.rs
@@ -1,42 +1,43 @@
-use crate::utils::split_comma_separated;
+use crate::{split_csv_line, OllamaConfig, OpenAIConfig};
 use eyre::{eyre, Result};
 use ollama_workflows::{Model, ModelProvider};
 use rand::seq::IteratorRandom; // provides Vec<_>.choose

 #[derive(Debug, Clone)]
-pub struct ModelConfig {
-    pub(crate) models: Vec<(ModelProvider, Model)>,
+pub struct DriaWorkflowsConfig {
+    /// List of models with their providers.
+    pub models: Vec<(ModelProvider, Model)>,
+    /// Ollama configurations, in case Ollama is used.
+    /// Otherwise, can be ignored.
+    pub ollama: OllamaConfig,
+    /// OpenAI configurations, e.g. API key, in case OpenAI is used.
+    /// Otherwise, can be ignored.
+    pub openai: OpenAIConfig,
 }

-impl std::fmt::Display for ModelConfig {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let models_str = self
-            .models
-            .iter()
-            .map(|(provider, model)| format!("{:?}:{}", provider, model))
-            .collect::<Vec<_>>()
-            .join(",");
-        write!(f, "{}", models_str)
-    }
-}
+impl DriaWorkflowsConfig {
+    pub fn new(models: Vec<Model>) -> Self {
+        let models_and_providers = models
+            .into_iter()
+            .map(|model| (model.clone().into(), model))
+            .collect::<Vec<_>>();

-impl ModelConfig {
+        Self {
+            models: models_and_providers,
+            openai: OpenAIConfig::new(),
+            ollama: OllamaConfig::new(),
+        }
+    }
     /// Parses Ollama-Workflows compatible models from a comma-separated values string.
-    pub fn new_from_csv(input: Option<String>) -> Self {
-        let models_str = split_comma_separated(input);
+    pub fn new_from_csv(input: &str) -> Self {
+        let models_str = split_csv_line(input);

         let models = models_str
             .into_iter()
-            .filter_map(|s| match Model::try_from(s) {
-                Ok(model) => Some((model.clone().into(), model)),
-                Err(e) => {
-                    log::warn!("Error parsing model: {}", e);
-                    None
-                }
-            })
-            .collect::<Vec<_>>();
+            .filter_map(|s| Model::try_from(s).ok())
+            .collect();

-        Self { models }
+        Self::new(models)
     }

     /// Returns the models that belong to a given providers from the config.
@@ -55,12 +56,27 @@
     /// Given a raw model name or provider (as a string), returns the first matching model & provider.
     ///
-    /// If this is a model and is supported by this node, it is returned directly.
-    /// If this is a provider, the first matching model in the node config is returned.
+    /// - If input is `*`, a random model is returned.
+    /// - If input is `!`, the first model is returned.
+    /// - If input is a model and is supported by this node, it is returned directly.
+    /// - If input is a provider, the first matching model in the node config is returned.
     ///
     /// If there are no matching models with this logic, an error is returned.
     pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> {
-        if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
+        if model_or_provider == "*" {
+            // return a random model
+            self.models
+                .iter()
+                .choose(&mut rand::thread_rng())
+                .ok_or_else(|| eyre!("No models to randomly pick for '*'."))
+                .cloned()
+        } else if model_or_provider == "!" {
+            // return the first model
+            self.models
+                .first()
+                .ok_or_else(|| eyre!("No models to choose first for '!'."))
+                .cloned()
+        } else if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
             // this is a valid provider, return the first matching model in the config
             self.models
                 .iter()
@@ -124,6 +140,70 @@
             unique
         })
     }
+
+    /// Check if the required compute services are running.
+    /// This has several steps:
+    ///
+    /// - If Ollama models are used, hardcoded models are checked locally, and for
+    ///   external models, the workflow is tested with a simple task with timeout.
+    /// - If OpenAI models are used, the API key is checked and the models are tested
+    ///
+    /// If both types of models are used, both services are checked.
+    /// In the end, bad models are filtered out and we simply check if we are left with any valid models at all.
+    /// If not, an error is returned.
+    pub async fn check_services(&mut self) -> Result<()> {
+        log::info!("Checking configured services.");
+
+        // TODO: can refactor (provider, model) logic here
+        let unique_providers = self.get_providers();
+
+        let mut good_models = Vec::new();
+
+        // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
+        if unique_providers.contains(&ModelProvider::Ollama) {
+            let ollama_models = self.get_models_for_provider(ModelProvider::Ollama);
+
+            // ensure that the models are pulled / pull them if not
+            let good_ollama_models = self.ollama.check(ollama_models).await?;
+            good_models.extend(
+                good_ollama_models
+                    .into_iter()
+                    .map(|m| (ModelProvider::Ollama, m)),
+            );
+        }
+
+        // if OpenAI is a provider, check that the API key is set
+        if unique_providers.contains(&ModelProvider::OpenAI) {
+            let openai_models = self.get_models_for_provider(ModelProvider::OpenAI);
+
+            let good_openai_models = self.openai.check(openai_models).await?;
+            good_models.extend(
+                good_openai_models
+                    .into_iter()
+                    .map(|m| (ModelProvider::OpenAI, m)),
+            );
+        }
+
+        // update good models
+        if good_models.is_empty() {
+            Err(eyre!("No good models found, please check logs for errors."))
+        } else {
+            self.models = good_models;
+            Ok(())
+        }
+    }
+}
+
+impl std::fmt::Display for DriaWorkflowsConfig {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let models_str = self
+            .models
+            .iter()
+            .map(|(provider, model)| format!("{:?}:{}", provider, model))
+            .collect::<Vec<_>>()
+            .join(",");
+        write!(f, "{}", models_str)
+    }
 }

 #[cfg(test)]
@@ -132,19 +212,18 @@ mod tests {

     #[test]
     fn test_csv_parser() {
-        let cfg =
-            ModelConfig::new_from_csv(Some("idontexist,i dont either,i332287648762".to_string()));
+        let cfg = DriaWorkflowsConfig::new_from_csv("idontexist,i dont either,i332287648762");
         assert_eq!(cfg.models.len(), 0);

-        let cfg = ModelConfig::new_from_csv(Some(
-            "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
-        ));
+        let cfg = DriaWorkflowsConfig::new_from_csv(
"gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl", + ); assert_eq!(cfg.models.len(), 2); } #[test] fn test_model_matching() { - let cfg = ModelConfig::new_from_csv(Some("gpt-4o,llama3.1:latest".to_string())); + let cfg = DriaWorkflowsConfig::new_from_csv("gpt-4o,llama3.1:latest"); assert_eq!( cfg.get_matching_model("openai".to_string()).unwrap().1, Model::GPT4o, @@ -173,7 +252,7 @@ mod tests { #[test] fn test_get_any_matching_model() { - let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string())); + let cfg = DriaWorkflowsConfig::new_from_csv("gpt-3.5-turbo,llama3.1:latest"); let result = cfg.get_any_matching_model(vec![ "i-dont-exist".to_string(), "llama3.1:latest".to_string(), diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs new file mode 100644 index 0000000..1088758 --- /dev/null +++ b/workflows/src/lib.rs @@ -0,0 +1,11 @@ +mod utils; +pub use utils::split_csv_line; + +mod providers; +use providers::{OllamaConfig, OpenAIConfig}; + +mod config; +pub use config::DriaWorkflowsConfig; + +pub use ollama_workflows; +pub use ollama_workflows::{Model, ModelProvider}; diff --git a/workflows/src/providers/mod.rs b/workflows/src/providers/mod.rs new file mode 100644 index 0000000..a6ea768 --- /dev/null +++ b/workflows/src/providers/mod.rs @@ -0,0 +1,5 @@ +mod ollama; +pub use ollama::OllamaConfig; + +mod openai; +pub use openai::OpenAIConfig; diff --git a/src/config/ollama.rs b/workflows/src/providers/ollama.rs similarity index 84% rename from src/config/ollama.rs rename to workflows/src/providers/ollama.rs index 53ce2a3..d507130 100644 --- a/src/config/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -10,14 +10,20 @@ use ollama_workflows::{ }, Model, }; +use std::env; use std::time::Duration; const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1"; const DEFAULT_OLLAMA_PORT: u16 = 11434; +/// Automatically pull missing models by default? +const DEFAULT_AUTO_PULL: bool = true; +/// Timeout duration for checking model performance during a generation. +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(80); +/// Minimum tokens per second (TPS) for checking model performance during a generation. +const DEFAULT_MIN_TPS: f64 = 15.0; /// Some models such as small embedding models, are hardcoded into the node. const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"]; - /// Prompt to be used to see Ollama performance. const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; @@ -25,14 +31,16 @@ const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; #[derive(Debug, Clone)] pub struct OllamaConfig { /// Host, usually `http://127.0.0.1`. - pub(crate) host: String, + pub host: String, /// Port, usually `11434`. - pub(crate) port: u16, - /// List of hardcoded models that are internally used by Ollama workflows. - hardcoded_models: Vec, + pub port: u16, /// Whether to automatically pull models from Ollama. /// This is useful for CI/CD workflows. auto_pull: bool, + /// Timeout duration for checking model performance during a generation. + timeout: Duration, + /// Minimum tokens per second (TPS) for checking model performance during a generation. 
+ min_tps: f64, } impl Default for OllamaConfig { @@ -40,11 +48,9 @@ impl Default for OllamaConfig { Self { host: DEFAULT_OLLAMA_HOST.to_string(), port: DEFAULT_OLLAMA_PORT, - hardcoded_models: HARDCODED_MODELS - .into_iter() - .map(|s| s.to_string()) - .collect(), - auto_pull: false, + auto_pull: DEFAULT_AUTO_PULL, + timeout: DEFAULT_TIMEOUT, + min_tps: DEFAULT_MIN_TPS, } } } @@ -53,40 +59,33 @@ impl OllamaConfig { /// /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`. pub fn new() -> Self { - let host = std::env::var("OLLAMA_HOST") + let host = env::var("OLLAMA_HOST") .map(|h| h.trim_matches('"').to_string()) .unwrap_or(DEFAULT_OLLAMA_HOST.to_string()); - let port = std::env::var("OLLAMA_PORT") + let port = env::var("OLLAMA_PORT") .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent)) .unwrap_or(DEFAULT_OLLAMA_PORT); - // Ollama workflows may require specific models to be loaded regardless of the choices - let hardcoded_models = HARDCODED_MODELS.iter().map(|s| s.to_string()).collect(); - // auto-pull, its true by default - let auto_pull = std::env::var("OLLAMA_AUTO_PULL") + let auto_pull = env::var("OLLAMA_AUTO_PULL") .map(|s| s == "true") .unwrap_or(true); Self { host, port, - hardcoded_models, auto_pull, + ..Default::default() } } /// Check if requested models exist in Ollama, and then tests them using a workflow. - pub async fn check( - &self, - external_models: Vec, - timeout: Duration, - min_tps: f64, - ) -> Result> { + pub async fn check(&self, external_models: Vec) -> Result> { log::info!( - "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)", + "Checking Ollama requirements (auto-pull {}, timeout: {}s, min tps: {})", if self.auto_pull { "on" } else { "off" }, - timeout.as_secs() + self.timeout.as_secs(), + self.min_tps ); let ollama = Ollama::new(&self.host, self.port); @@ -105,11 +104,10 @@ impl OllamaConfig { // check hardcoded models & pull them if available // these are not used directly by the user, but are needed for the workflows - log::debug!("Checking hardcoded models: {:#?}", self.hardcoded_models); - // only check if model is contained in local_models - // we dont check workflows for hardcoded models - for model in &self.hardcoded_models { - if !local_models.contains(model) { + // we only check if model is contained in local_models, we dont check workflows for these + for model in HARDCODED_MODELS { + // `contains` doesnt work for &str so we equality check instead + if !&local_models.iter().any(|s| s == model) { self.try_pull(&ollama, model.to_owned()) .await .wrap_err("Could not pull model")?; @@ -126,10 +124,7 @@ impl OllamaConfig { .wrap_err("Could not pull model")?; } - if self - .test_performance(&ollama, &model, timeout, min_tps) - .await - { + if self.test_performance(&ollama, &model).await { good_models.push(model); } } @@ -165,13 +160,7 @@ impl OllamaConfig { /// /// This is to see if a given system can execute Ollama workflows for their chosen models, /// e.g. if they have enough RAM/CPU and such. - pub async fn test_performance( - &self, - ollama: &Ollama, - model: &Model, - timeout: Duration, - min_tps: f64, - ) -> bool { + pub async fn test_performance(&self, ollama: &Ollama, model: &Model) -> bool { log::info!("Testing model {}", model); // first generate a dummy embedding to load the model into memory (warm-up) @@ -200,7 +189,7 @@ impl OllamaConfig { // then, run a sample generation with timeout and measure tps tokio::select! 
-            _ = tokio::time::sleep(timeout) => {
+            _ = tokio::time::sleep(self.timeout) => {
                 log::warn!("Ignoring model {}: Workflow timed out", model);
             },
             result = ollama.generate(generation_request) => {
@@ -210,7 +199,7 @@
                         / (response.eval_duration.unwrap_or(1) as f64)
                         * 1_000_000_000f64;

-                    if tps >= min_tps {
+                    if tps >= self.min_tps {
                         log::info!("Model {} passed the test with tps: {}", model, tps);
                         return true;
                     }
@@ -219,7 +208,7 @@
                         "Ignoring model {}: tps too low ({:.3} < {:.3})",
                         model,
                         tps,
-                        min_tps
+                        self.min_tps
                     );
                 }
                 Err(e) => {
@@ -239,7 +228,7 @@ mod tests {
     use ollama_workflows::{Executor, Model, ProgramMemory, Workflow};

     #[tokio::test]
-    #[ignore = "run this manually"]
+    #[ignore = "requires Ollama"]
     async fn test_ollama_prompt() {
         let model = Model::default().to_string();
         let ollama = Ollama::default();
@@ -258,7 +247,7 @@
     }

     #[tokio::test]
-    #[ignore = "run this manually"]
+    #[ignore = "requires Ollama"]
     async fn test_ollama_workflow() {
         let workflow = r#"{
             "name": "Simple",
diff --git a/src/config/openai.rs b/workflows/src/providers/openai.rs
similarity index 87%
rename from src/config/openai.rs
rename to workflows/src/providers/openai.rs
index a29819e..2252e36 100644
--- a/src/config/openai.rs
+++ b/workflows/src/providers/openai.rs
@@ -1,11 +1,8 @@
-#![allow(unused)]
-
 use eyre::{eyre, Context, Result};
 use ollama_workflows::Model;
+use reqwest::Client;
 use serde::Deserialize;

-const OPENAI_API_KEY: &str = "OPENAI_API_KEY";
-
 const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models";

 /// [Model](https://platform.openai.com/docs/api-reference/models/object) API object.
@@ -14,30 +11,36 @@ struct OpenAIModel {
     /// The model identifier, which can be referenced in the API endpoints.
     id: String,
     /// The Unix timestamp (in seconds) when the model was created.
+    #[allow(unused)]
     created: u64,
     /// The object type, which is always "model".
+    #[allow(unused)]
     object: String,
     /// The organization that owns the model.
+    #[allow(unused)]
     owned_by: String,
 }

 #[derive(Debug, Clone, Deserialize)]
 struct OpenAIModelsResponse {
     data: Vec<OpenAIModel>,
+    #[allow(unused)]
     object: String,
 }

+/// OpenAI-specific configurations.
 #[derive(Debug, Clone, Default)]
 pub struct OpenAIConfig {
-    pub(crate) api_key: Option<String>,
+    /// API key, if available.
+    api_key: Option<String>,
 }

 impl OpenAIConfig {
     /// Looks at the environment variables for OpenAI API key.
     pub fn new() -> Self {
-        let api_key = std::env::var(OPENAI_API_KEY).ok();
-
-        Self { api_key }
+        Self {
+            api_key: std::env::var("OPENAI_API_KEY").ok(),
+        }
     }

     /// Check if requested models exist & are available in the OpenAI account.
@@ -50,17 +53,17 @@
         };

         // fetch models
-        let client = reqwest::Client::new();
+        let client = Client::new();
         let request = client
             .get(OPENAI_MODELS_API)
             .header("Authorization", format!("Bearer {}", api_key))
             .build()
-            .wrap_err("Failed to build request")?;
+            .wrap_err("failed to build request")?;

         let response = client
             .execute(request)
             .await
-            .wrap_err("Failed to send request")?;
+            .wrap_err("failed to send request")?;

         // parse response
         if response.status().is_client_error() {
diff --git a/workflows/src/utils.rs b/workflows/src/utils.rs
new file mode 100644
index 0000000..030f0e2
--- /dev/null
+++ b/workflows/src/utils.rs
@@ -0,0 +1,36 @@
+/// Utility to parse a comma-separated string into a list of values.
+///
+/// - Trims `"` from both ends of the input
+/// - For each item, trims whitespace from both ends
+pub fn split_csv_line(input: &str) -> Vec<String> {
+    input
+        .trim_matches('"')
+        .split(',')
+        .filter_map(|s| {
+            let s = s.trim().to_string();
+            if s.is_empty() {
+                None
+            } else {
+                Some(s)
+            }
+        })
+        .collect::<Vec<String>>()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_example() {
+        // should ignore whitespaces and `"` at both ends, and ignore empty items
+        let input = "\"a, b , c ,, \"";
+        let expected = vec!["a".to_string(), "b".to_string(), "c".to_string()];
+        assert_eq!(split_csv_line(input), expected);
+    }
+
+    #[test]
+    fn test_empty() {
+        assert!(split_csv_line(Default::default()).is_empty());
+    }
+}
diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs
new file mode 100644
index 0000000..acfc786
--- /dev/null
+++ b/workflows/tests/models_test.rs
@@ -0,0 +1,51 @@
+use dkn_workflows::{DriaWorkflowsConfig, Model, ModelProvider};
+use eyre::Result;
+use std::env;
+
+const LOG_LEVEL: &str = "none,dkn_workflows=debug";
+
+#[tokio::test]
+#[ignore = "requires Ollama"]
+async fn test_ollama_check() -> Result<()> {
+    env::set_var("RUST_LOG", LOG_LEVEL);
+    let _ = env_logger::try_init();
+
+    let models = vec![Model::Phi3_5Mini];
+    let mut model_config = DriaWorkflowsConfig::new(models);
+    model_config.check_services().await?;
+
+    assert_eq!(
+        model_config.models[0],
+        (ModelProvider::Ollama, Model::Phi3_5Mini)
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+#[ignore = "requires OpenAI"]
+async fn test_openai_check() -> Result<()> {
+    let _ = dotenvy::dotenv(); // read api key
+    env::set_var("RUST_LOG", LOG_LEVEL);
+    let _ = env_logger::try_init();
+
+    let models = vec![Model::GPT4Turbo];
+    let mut model_config = DriaWorkflowsConfig::new(models);
+    model_config.check_services().await?;
+
+    assert_eq!(
+        model_config.models[0],
+        (ModelProvider::OpenAI, Model::GPT4Turbo)
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_empty() -> Result<()> {
+    let mut model_config = DriaWorkflowsConfig::new(vec![]);
+
+    let result = model_config.check_services().await;
+    assert!(result.is_err());
+
+    Ok(())
+}
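Finally, the new model-selection rules in `get_matching_model` can be exercised as below; this is a small sketch against the public API rather than code from the diff, where the model list mirrors the tests above and `claude-3` stands in for any hypothetical name that is not configured:

```rs
use dkn_workflows::{DriaWorkflowsConfig, Model, ModelProvider};

fn main() -> eyre::Result<()> {
    let cfg = DriaWorkflowsConfig::new_from_csv("gpt-4o,llama3.1:latest");

    // "!" deterministically picks the first configured model
    let (_, first) = cfg.get_matching_model("!".to_string())?;
    assert_eq!(first, Model::GPT4o);

    // "*" picks a random model from the configuration
    let (_, random) = cfg.get_matching_model("*".to_string())?;

    // a provider name picks that provider's first configured model
    let (provider, _) = cfg.get_matching_model("openai".to_string())?;
    assert_eq!(provider, ModelProvider::OpenAI);

    // anything else must be a configured model name, otherwise an error is returned
    assert!(cfg.get_matching_model("claude-3".to_string()).is_err());

    println!("first: {}, random: {}", first, random);
    Ok(())
}
```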