Skip to content

Commit

Permalink
Merge pull request #157 from firstbatchxyz/erhant/batch-publish
Browse files Browse the repository at this point in the history
feat: publish-in-task
  • Loading branch information
erhant authored Dec 6, 2024
2 parents 1c6fded + 819e4d6 commit 650bdf5
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b
# example: phi3:3.8b,gpt-4o-mini
DKN_MODELS=


## DRIA (optional) ##
# P2P address, you don't need to change this unless this port is already in use.
DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
# Comma-separated static relay nodes
DKN_RELAY_NODES=
# Comma-separated static bootstrap nodes
DKN_BOOTSTRAP_NODES=
# Batch size for workflows (defaults to 5, must not exceed 8); you do not need to edit this.
DKN_BATCH_SIZE=

## DRIA (profiling only, do not uncomment) ##
# Set to a number of seconds to wait before exiting, only use in profiling build!
Expand Down
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default-members = ["compute"]

[workspace.package]
edition = "2021"
version = "0.2.26"
version = "0.2.27"
license = "Apache-2.0"
readme = "README.md"

Expand Down
14 changes: 14 additions & 0 deletions compute/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use libsecp256k1::{PublicKey, SecretKey};

use std::{env, str::FromStr};

// Default batch size, used when `DKN_BATCH_SIZE` is unset or cannot be parsed as a `usize`.
const DEFAULT_WORKFLOW_BATCH_SIZE: usize = 5;

#[derive(Debug, Clone)]
pub struct DriaComputeNodeConfig {
/// Wallet secret/private key.
Expand All @@ -25,6 +28,11 @@ pub struct DriaComputeNodeConfig {
pub workflows: DriaWorkflowsConfig,
/// Network type of the node.
pub network_type: DriaNetworkType,
/// Batch size for batchable workflows.
///
/// A higher value will help execute more tasks concurrently,
/// at the risk of hitting rate-limits.
pub batch_size: usize,
}

/// The default P2P network listen address.
Expand Down Expand Up @@ -103,6 +111,11 @@ impl DriaComputeNodeConfig {
.map(|s| DriaNetworkType::from(s.as_str()))
.unwrap_or_default();

// parse batch size
let batch_size = env::var("DKN_BATCH_SIZE")
.map(|s| s.parse::<usize>().unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE))
.unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE);

Self {
admin_public_key,
secret_key,
Expand All @@ -111,6 +124,7 @@ impl DriaComputeNodeConfig {
workflows,
p2p_listen_addr,
network_type,
batch_size,
}
}

Expand Down
8 changes: 2 additions & 6 deletions compute/src/handlers/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ impl WorkflowHandler {

// convert payload to message
let payload_str = serde_json::json!(payload).to_string();
log::debug!(
"Publishing result for task {}\n{}",
task.task_id,
payload_str
);
log::info!("Publishing result for task {}", task.task_id);
DriaMessage::new(payload_str, Self::RESPONSE_TOPIC)
}
Err(err) => {
Expand Down Expand Up @@ -161,7 +157,7 @@ impl WorkflowHandler {

// try publishing the result
if let Err(publish_err) = node.publish(message).await {
let err_msg = format!("could not publish result: {:?}", publish_err);
let err_msg = format!("Could not publish task result: {:?}", publish_err);
log::error!("{}", err_msg);

let payload = serde_json::json!({
Expand Down
15 changes: 12 additions & 3 deletions compute/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use dkn_workflows::DriaWorkflowsConfig;
use eyre::Result;
use std::env;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use workers::workflow::WorkflowsWorker;

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -86,6 +87,7 @@ async fn main() -> Result<()> {
log::warn!("Using models: {:#?}", config.workflows.models);

// create the node
let batch_size = config.batch_size;
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;

// spawn p2p client first
Expand All @@ -94,14 +96,21 @@ async fn main() -> Result<()> {

// spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter)
if let Some(mut worker_batch) = worker_batch {
log::info!("Spawning workflows batch worker thread.");
task_tracker.spawn(async move { worker_batch.run_batch().await });
assert!(
batch_size <= WorkflowsWorker::MAX_BATCH_SIZE,
"batch size too large"
);
log::info!(
"Spawning workflows batch worker thread. (batch size {})",
batch_size
);
task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await });
}

// spawn single worker thread if we are using such models (e.g. Ollama)
if let Some(mut worker_single) = worker_single {
log::info!("Spawning workflows single worker thread.");
task_tracker.spawn(async move { worker_single.run().await });
task_tracker.spawn(async move { worker_single.run_series().await });
}

// spawn compute node thread
Expand Down
4 changes: 1 addition & 3 deletions compute/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ impl DriaComputeNode {
let (p2p_client, p2p_commander, message_rx) = DriaP2PClient::new(
keypair,
config.p2p_listen_addr.clone(),
available_nodes.bootstrap_nodes.clone().into_iter(),
available_nodes.relay_nodes.clone().into_iter(),
available_nodes.rpc_nodes.clone().into_iter(),
&available_nodes,
protocol,
)?;

Expand Down
56 changes: 29 additions & 27 deletions compute/src/workers/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ pub struct WorkflowsWorkerOutput {
///
/// It is expected to be spawned in another thread, with `run_batch` for batch processing and `run_series` for single processing.
pub struct WorkflowsWorker {
/// Workflow message channel receiver, the sender is most likely the compute node itself.
workflow_rx: mpsc::Receiver<WorkflowsWorkerInput>,
/// Publish message channel sender, the receiver is most likely the compute node itself.
publish_tx: mpsc::Sender<WorkflowsWorkerOutput>,
}

/// Buffer size for workflow tasks (per worker).
const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024;

impl WorkflowsWorker {
/// Batch size that defines how many tasks can be executed in parallel at once.
/// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically,
/// Batch size that defines how many tasks can be executed concurrently at once.
///
/// The `run_batch` function is designed to handle the batch size here specifically,
/// if there are more tasks than the batch size, the function will panic.
const BATCH_SIZE: usize = 8;
pub const MAX_BATCH_SIZE: usize = 8;

/// Creates a worker and returns the sender and receiver for the worker.
pub fn new(
Expand All @@ -65,24 +68,20 @@ impl WorkflowsWorker {
self.workflow_rx.close();
}

/// Launches the thread that can process tasks one by one.
/// Launches the thread that can process tasks one by one (in series).
/// This function will block until the channel is closed.
///
/// It is suitable for task streams that consume local resources, unlike API calls.
pub async fn run(&mut self) {
pub async fn run_series(&mut self) {
loop {
let task = self.workflow_rx.recv().await;

let result = if let Some(task) = task {
if let Some(task) = task {
log::info!("Processing single workflow for task {}", task.task_id);
WorkflowsWorker::execute(task).await
WorkflowsWorker::execute((task, self.publish_tx.clone())).await
} else {
return self.shutdown();
};

if let Err(e) = self.publish_tx.send(result).await {
log::error!("Error sending workflow result: {}", e);
}
}
}

Expand All @@ -91,13 +90,16 @@ impl WorkflowsWorker {
///
/// It is suitable for task streams that make use of API calls, unlike Ollama-like
/// tasks that consume local resources and would not make sense to run in parallel.
pub async fn run_batch(&mut self) {
///
/// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise this function will panic
/// once more than `MAX_BATCH_SIZE` tasks are received in a single batch.
pub async fn run_batch(&mut self, batch_size: usize) {
// TODO: need some better batch_size error handling here
loop {
// get tasks in batch from the channel
let mut task_buffer = Vec::new();
let num_tasks = self
.workflow_rx
.recv_many(&mut task_buffer, Self::BATCH_SIZE)
.recv_many(&mut task_buffer, batch_size)
.await;

if num_tasks == 0 {
Expand All @@ -106,8 +108,10 @@ impl WorkflowsWorker {

// process the batch
log::info!("Processing {} workflows in batch", num_tasks);
let mut batch = task_buffer.into_iter();
let results = match num_tasks {
let mut batch = task_buffer
.into_iter()
.map(|b| (b, self.publish_tx.clone()));
match num_tasks {
1 => {
let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await;
vec![r0]
Expand Down Expand Up @@ -186,23 +190,17 @@ impl WorkflowsWorker {
unreachable!(
"number of tasks cant be larger than batch size ({} > {})",
num_tasks,
Self::BATCH_SIZE
Self::MAX_BATCH_SIZE
);
}
};

// publish all results
log::info!("Publishing {} workflow results", results.len());
for result in results {
if let Err(e) = self.publish_tx.send(result).await {
log::error!("Error sending workflow result: {}", e);
}
}
}
}

/// A single task execution.
pub async fn execute(input: WorkflowsWorkerInput) -> WorkflowsWorkerOutput {
/// Executes a single task, and publishes the output.
pub async fn execute(
(input, publish_tx): (WorkflowsWorkerInput, mpsc::Sender<WorkflowsWorkerOutput>),
) {
let mut memory = ProgramMemory::new();

let started_at = std::time::Instant::now();
Expand All @@ -211,13 +209,17 @@ impl WorkflowsWorker {
.execute(input.entry.as_ref(), &input.workflow, &mut memory)
.await;

WorkflowsWorkerOutput {
let output = WorkflowsWorkerOutput {
result,
public_key: input.public_key,
task_id: input.task_id,
model_name: input.model_name,
batchable: input.batchable,
stats: input.stats.record_execution_time(started_at),
};

if let Err(e) = publish_tx.send(output).await {
log::error!("Error sending workflow result: {}", e);
}
}
}
4 changes: 1 addition & 3 deletions monitor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ async fn main() -> eyre::Result<()> {
let (client, commander, msg_rx) = DriaP2PClient::new(
keypair,
listen_addr,
nodes.bootstrap_nodes.into_iter(),
nodes.relay_nodes.into_iter(),
nodes.rpc_nodes.into_iter(),
&nodes,
DriaP2PProtocol::new_major_minor(network.protocol_name()),
)?;

Expand Down
Loading

0 comments on commit 650bdf5

Please sign in to comment.