Skip to content

Commit

Permalink
Merge pull request #91 from firstbatchxyz/erhant/gossipsub-validation
Browse files Browse the repository at this point in the history
Added GossipSub message validation
  • Loading branch information
erhant authored Aug 15, 2024
2 parents 98931bb + d9bb968 commit 384e1b7
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 99 deletions.
23 changes: 19 additions & 4 deletions src/handlers/pingpong.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
errors::NodeResult, node::DriaComputeNode, p2p::P2PMessage, utils::get_current_time_nanos,
};
use libp2p::gossipsub::MessageAcceptance;
use ollama_workflows::{Model, ModelProvider};
use serde::{Deserialize, Serialize};

Expand All @@ -14,16 +15,25 @@ struct PingpongPayload {
struct PingpongResponse {
pub(crate) uuid: String,
pub(crate) models: Vec<(ModelProvider, Model)>,
pub(crate) timestamp: u128,
}

/// A ping-pong is a message sent by a node to indicate that it is alive.
/// Compute nodes listen to `pong` topic, and respond to `ping` topic.
pub trait HandlesPingpong {
fn handle_heartbeat(&mut self, message: P2PMessage, result_topic: &str) -> NodeResult<()>;
fn handle_heartbeat(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance>;
}

impl HandlesPingpong for DriaComputeNode {
fn handle_heartbeat(&mut self, message: P2PMessage, result_topic: &str) -> NodeResult<()> {
fn handle_heartbeat(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance> {
let pingpong = message.parse_payload::<PingpongPayload>(true)?;

// check deadline
Expand All @@ -35,21 +45,26 @@ impl HandlesPingpong for DriaComputeNode {
current_time,
pingpong.deadline
);
return Ok(());

// ignore message due to past deadline
return Ok(MessageAcceptance::Ignore);
}

// respond
let response_body = PingpongResponse {
uuid: pingpong.uuid.clone(),
models: self.config.model_config.models.clone(),
timestamp: get_current_time_nanos(),
};
let response = P2PMessage::new_signed(
serde_json::json!(response_body).to_string(),
result_topic,
&self.config.secret_key,
);
self.publish(response)?;
Ok(())

// accept message, someone else may be included in the filter
Ok(MessageAcceptance::Accept)
}
}

Expand Down
132 changes: 88 additions & 44 deletions src/handlers/workflow.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use async_trait::async_trait;
use libp2p::gossipsub::MessageAcceptance;
use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
use serde::Deserialize;

use crate::errors::NodeResult;
use crate::node::DriaComputeNode;
use crate::p2p::P2PMessage;
use crate::utils::get_current_time_nanos;
use crate::utils::payload::{TaskRequest, TaskRequestPayload};

#[derive(Debug, Deserialize)]
struct WorkflowPayload {
Expand All @@ -21,58 +24,99 @@ struct WorkflowPayload {

#[async_trait]
pub trait HandlesWorkflow {
async fn handle_workflow(&mut self, message: P2PMessage, result_topic: &str) -> NodeResult<()>;
async fn handle_workflow(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance>;
}

#[async_trait]
impl HandlesWorkflow for DriaComputeNode {
async fn handle_workflow(&mut self, message: P2PMessage, result_topic: &str) -> NodeResult<()> {
if let Some(task) =
self.parse_topiced_message_to_task_request::<WorkflowPayload>(message)?
{
// read model / provider from the task
let (model_provider, model) = self
.config
.model_config
.get_any_matching_model(task.input.model)?;
log::info!("Using model {} for task {}", model, task.task_id);
async fn handle_workflow(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance> {
let task = message.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)?;

// execute workflow with cancellation
let executor = if model_provider == ModelProvider::Ollama {
Executor::new_at(
model,
&self.config.ollama_config.host,
self.config.ollama_config.port,
)
} else {
Executor::new(model)
};
let mut memory = ProgramMemory::new();
let entry: Option<Entry> = task
.input
.prompt
.map(|prompt| Entry::try_value_or_str(&prompt));
let result: Option<String>;
tokio::select! {
_ = self.cancellation.cancelled() => {
log::info!("Received cancellation, quitting all tasks.");
return Ok(())
},
exec_result = executor.execute(entry.as_ref(), task.input.workflow, &mut memory) => {
if exec_result.is_empty() {
return Err(format!("Got empty string result for task {}", task.task_id).into());
} else {
result = Some(exec_result);
}
// check if deadline is past or not
let current_time = get_current_time_nanos();
if current_time >= task.deadline {
log::debug!(
"Task (id: {}) is past the deadline, ignoring. (local: {}, deadline: {})",
task.task_id,
current_time,
task.deadline
);

// ignore the message
return Ok(MessageAcceptance::Ignore);
}

// check task inclusion via the bloom filter
if !task.filter.contains(&self.config.address)? {
log::info!(
"Task {} does not include this node within the filter.",
task.task_id
);

// accept the message, someone else may be included in the filter
return Ok(MessageAcceptance::Accept);
}

// obtain public key from the payload
let task_public_key = hex::decode(&task.public_key)?;

let task = TaskRequest {
task_id: task.task_id,
input: task.input,
public_key: task_public_key,
};

// read model / provider from the task
let (model_provider, model) = self
.config
.model_config
.get_any_matching_model(task.input.model)?;
log::info!("Using model {} for task {}", model, task.task_id);

// execute workflow with cancellation
let executor = if model_provider == ModelProvider::Ollama {
Executor::new_at(
model,
&self.config.ollama_config.host,
self.config.ollama_config.port,
)
} else {
Executor::new(model)
};
let mut memory = ProgramMemory::new();
let entry: Option<Entry> = task
.input
.prompt
.map(|prompt| Entry::try_value_or_str(&prompt));
let result: Option<String>;
tokio::select! {
_ = self.cancellation.cancelled() => {
log::info!("Received cancellation, quitting all tasks.");
return Ok(MessageAcceptance::Accept)
},
exec_result = executor.execute(entry.as_ref(), task.input.workflow, &mut memory) => {
if exec_result.is_empty() {
return Err(format!("Got empty string result for task {}", task.task_id).into());
} else {
result = Some(exec_result);
}
}
let result =
result.ok_or::<String>(format!("No result for task {}", task.task_id).into())?;

// publish the result
self.send_result(result_topic, &task.public_key, &task.task_id, result)?;
}
let result =
result.ok_or::<String>(format!("No result for task {}", task.task_id).into())?;

// publish the result
self.send_result(result_topic, &task.public_key, &task.task_id, result)?;

Ok(())
// accept message, someone else may be included in the filter
Ok(MessageAcceptance::Accept)
}
}
75 changes: 26 additions & 49 deletions src/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::str::FromStr;

use libp2p::{gossipsub, Multiaddr};
use serde::Deserialize;
use tokio::signal::unix::{signal, SignalKind};
use tokio_util::sync::CancellationToken;

Expand All @@ -10,11 +9,7 @@ use crate::{
errors::NodeResult,
handlers::{HandlesPingpong, HandlesWorkflow},
p2p::{P2PClient, P2PMessage},
utils::{
crypto::secret_to_keypair,
get_current_time_nanos,
payload::{TaskRequest, TaskRequestPayload},
},
utils::crypto::secret_to_keypair,
};

pub struct DriaComputeNode {
Expand Down Expand Up @@ -133,25 +128,47 @@ impl DriaComputeNode {
}
};



// then handle the prepared message
if let Err(err) = match topic_str {
let handle_result = match topic_str {
WORKFLOW_LISTEN_TOPIC => {
self.handle_workflow(message, WORKFLOW_RESPONSE_TOPIC).await
}
PINGPONG_LISTEN_TOPIC => {
self.handle_heartbeat(message, PINGPONG_RESPONSE_TOPIC)
}
// TODO: can we do this in a nicer way?
// TODO: yes, cast to enum above and let type-casting do the work
_ => unreachable!() // unreachable because of the if condition
} {
log::error!("Error handling {} message: {}", topic_str, err);
};

// validate the message based on the result
match handle_result {
Ok(acceptance) => {
// TODO: !!! remove me
log::info!(
"Validating message with ID: {}\nFrom: {}\nAcceptance: {:?}",
message_id,
peer_id,
acceptance
);
self.p2p.validate_message(&message_id, &peer_id, acceptance)?;
},
Err(err) => log::error!("Error handling {} message: {}", topic_str, err)
}
} else if std::matches!(topic_str, PINGPONG_RESPONSE_TOPIC | WORKFLOW_RESPONSE_TOPIC) {
// since we are responding to these topics, we might receive messages from other compute nodes
// we can gracefully ignore them
log::trace!("Ignoring message for topic: {}", topic_str);

// accept this message for propagation
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Accept)?;
} else {
log::warn!("Received unexpected message from topic: {}", topic_str);

// reject this message as it's from a foreign topic
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Reject)?;
}

}
Expand Down Expand Up @@ -189,46 +206,6 @@ impl DriaComputeNode {
Ok(message)
}

/// Turns a raw `P2PMessage` into a `TaskRequest<T>` if this node should act on it.
///
/// The message payload is parsed as a `TaskRequestPayload<T>` (the `true` flag is
/// forwarded to `parse_payload`; its meaning is defined there).
///
/// # Returns
/// - `Ok(None)` when the local time is at or past the task deadline, or when the
///   task filter does not contain this node's configured address.
/// - `Ok(Some(request))` otherwise, carrying the task id, the input and the
///   hex-decoded public key.
///
/// # Errors
/// Propagates failures from payload parsing, from the filter lookup and from
/// hex-decoding the public key.
pub fn parse_topiced_message_to_task_request<T>(
    &self,
    message: P2PMessage,
) -> NodeResult<Option<TaskRequest<T>>>
where
    T: for<'a> Deserialize<'a>,
{
    let payload = message.parse_payload::<TaskRequestPayload<T>>(true)?;

    // guard: the deadline is compared against the local clock
    let now = get_current_time_nanos();
    if now >= payload.deadline {
        log::debug!(
            "Task (id: {}) is past the deadline, ignoring. (local: {}, deadline: {})",
            payload.task_id,
            now,
            payload.deadline
        );
        return Ok(None);
    }

    // guard: only proceed when the filter lists this node's address
    let is_included = payload.filter.contains(&self.config.address)?;
    if !is_included {
        log::info!(
            "Task {} does not include this node within the filter.",
            payload.task_id
        );
        return Ok(None);
    }

    // the payload carries the public key as a hex string; decode it to bytes
    let public_key = hex::decode(&payload.public_key)?;

    let request = TaskRequest {
        task_id: payload.task_id,
        input: payload.input,
        public_key,
    };
    Ok(Some(request))
}

/// Given a task with `id` and respective `public_key`, sign-then-encrypt the result.
pub fn send_result<R: AsRef<[u8]>>(
&mut self,
Expand Down
3 changes: 2 additions & 1 deletion src/p2p/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ fn create_gossipsub_behavior(id_keys: Keypair) -> gossipsub::Behaviour {
ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.max_transmit_size(262144) // 256 KB
.validation_mode(gossipsub::ValidationMode::Strict)
.validation_mode(gossipsub::ValidationMode::Strict) // TODO!!
.validate_messages()
.message_id_fn(message_id_fn)
.build()
.expect("Valid config"), // TODO: better error handling
Expand Down
29 changes: 28 additions & 1 deletion src/p2p/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use libp2p::futures::StreamExt;
use libp2p::gossipsub::{Message, MessageId, PublishError, SubscriptionError, TopicHash};
use libp2p::gossipsub::{
Message, MessageAcceptance, MessageId, PublishError, SubscriptionError, TopicHash,
};
use libp2p::kad::{GetClosestPeersError, GetClosestPeersOk, QueryResult};
use libp2p::{gossipsub, identify, kad, multiaddr::Protocol, noise, swarm::SwarmEvent, tcp, yamux};
use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder};
Expand Down Expand Up @@ -179,6 +181,31 @@ impl P2PClient {
Ok(message_id)
}

/// Reports the validation outcome of a GossipSub message to the behaviour.
///
/// The `acceptance` value decides what happens to the message:
/// - `Accept`: the message is accepted and propagated.
/// - `Reject`: the message is not propagated, and `propagation_source` is penalized.
/// - `Ignore`: the message is not propagated, and nobody is penalized.
///
/// A debug line is logged when the reported message was no longer in the cache;
/// that case is still returned as `Ok(())`.
///
/// # Errors
/// Returns the `PublishError` produced by `report_message_validation_result`.
///
/// See [`validate_messages`](https://docs.rs/libp2p-gossipsub/latest/libp2p_gossipsub/struct.Config.html#method.validate_messages)
/// and [`report_message_validation_result`](https://docs.rs/libp2p-gossipsub/latest/libp2p_gossipsub/struct.Behaviour.html#method.report_message_validation_result) for more details.
pub fn validate_message(
    &mut self,
    msg_id: &MessageId,
    propagation_source: &PeerId,
    acceptance: MessageAcceptance,
) -> Result<(), PublishError> {
    let gossipsub = &mut self.swarm.behaviour_mut().gossipsub;
    let found_in_cache =
        gossipsub.report_message_validation_result(msg_id, propagation_source, acceptance)?;

    if found_in_cache {
        return Ok(());
    }

    log::debug!("Validated message was not in cache.");
    Ok(())
}
/// Returns the list of connected peers within Gossipsub, with a list of subscribed topic hashes by each peer.
pub fn peers(&self) -> Vec<(&PeerId, Vec<&TopicHash>)> {
self.swarm
Expand Down
Loading

0 comments on commit 384e1b7

Please sign in to comment.