diff --git a/roles/tests-integration/lib/sniffer.rs b/roles/tests-integration/lib/sniffer.rs index 6b680a4f5..945e6736e 100644 --- a/roles/tests-integration/lib/sniffer.rs +++ b/roles/tests-integration/lib/sniffer.rs @@ -219,31 +219,32 @@ impl Sniffer { ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); - for intercept_message in intercept_messages.iter() { - if intercept_message.direction == MessageDirection::ToUpstream - && intercept_message.expected_message_type == msg_type - { - let extension_type = 0; - let channel_msg = false; - let frame = StandardEitherFrame::>::Sv2( - Sv2Frame::from_message( - intercept_message.replacement_message.clone(), - intercept_message.replacement_message.message_type(), - extension_type, - channel_msg, - ) - .expect("Failed to create the frame"), - ); - downstream_messages - .add_message(msg_type, intercept_message.replacement_message.clone()); - let _ = send.send(frame).await; - } + let intercept_message = intercept_messages.iter().find(|im| { + im.direction == MessageDirection::ToUpstream && im.expected_message_type == msg_type + }); + if let Some(intercept_message) = intercept_message { + let intercept_frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.replacement_message.clone(), + intercept_message.replacement_message.message_type(), + 0, + false, + ) + .expect("Failed to create the frame"), + ); + downstream_messages.add_message( + intercept_message.replacement_message.message_type(), + intercept_message.replacement_message.clone(), + ); + send.send(intercept_frame) + .await + .map_err(|_| SnifferError::UpstreamClosed)?; + } else { + downstream_messages.add_message(msg_type, msg); + send.send(frame) + .await + .map_err(|_| SnifferError::UpstreamClosed)?; } - - downstream_messages.add_message(msg_type, msg); - if send.send(frame).await.is_err() { - return Err(SnifferError::UpstreamClosed); - }; } Err(SnifferError::DownstreamClosed) } @@ -256,30 +257,33 @@ impl Sniffer { ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); - for intercept_message in intercept_messages.iter() { - if intercept_message.direction == MessageDirection::ToDownstream - && intercept_message.expected_message_type == msg_type - { - let extension_type = 0; - let channel_msg = false; - let frame = StandardEitherFrame::>::Sv2( - Sv2Frame::from_message( - intercept_message.replacement_message.clone(), - intercept_message.replacement_message.message_type(), - extension_type, - channel_msg, - ) - .expect("Failed to create the frame"), - ); - upstream_messages - .add_message(msg_type, intercept_message.replacement_message.clone()); - let _ = send.send(frame).await; - } + let intercept_message = intercept_messages.iter().find(|im| { + im.direction == MessageDirection::ToDownstream + && im.expected_message_type == msg_type + }); + if let Some(intercept_message) = intercept_message { + let intercept_frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.replacement_message.clone(), + intercept_message.replacement_message.message_type(), + 0, + false, + ) + .expect("Failed to create the frame"), + ); + upstream_messages.add_message( + intercept_message.replacement_message.message_type(), + intercept_message.replacement_message.clone(), + ); + send.send(intercept_frame) + .await + .map_err(|_| SnifferError::DownstreamClosed)?; + } else { + upstream_messages.add_message(msg_type, msg); + send.send(frame) + .await + .map_err(|_| SnifferError::DownstreamClosed)?; } - if send.send(frame).await.is_err() { - return Err(SnifferError::DownstreamClosed); - }; - upstream_messages.add_message(msg_type, msg); } Err(SnifferError::UpstreamClosed) } diff --git a/roles/tests-integration/tests/sniffer_integration.rs b/roles/tests-integration/tests/sniffer_integration.rs index 93440f52a..d6be3d3eb 100644 --- a/roles/tests-integration/tests/sniffer_integration.rs +++ b/roles/tests-integration/tests/sniffer_integration.rs @@ -1,7 +1,10 @@ -use const_sv2::{MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, MESSAGE_TYPE_SET_NEW_PREV_HASH}; +use const_sv2::{ + MESSAGE_TYPE_SETUP_CONNECTION, MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + MESSAGE_TYPE_SET_NEW_PREV_HASH, +}; use integration_tests_sv2::*; use roles_logic_sv2::{ - common_messages_sv2::SetupConnectionError, + common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError}, parsers::{CommonMessages, PoolMessages}, }; use sniffer::{InterceptMessage, MessageDirection}; @@ -12,7 +15,7 @@ use std::convert::TryInto; // sniffer_b asserts that Pool is about to receive a SetupConnectionError // TP -> sniffer_a -> sniffer_b -> Pool #[tokio::test] -async fn test_sniffer_intercept() { +async fn test_sniffer_intercept_to_downstream() { let (_tp, tp_addr) = start_template_provider(None).await; let message_replacement = PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError { @@ -47,6 +50,57 @@ async fn test_sniffer_intercept() { ); } +#[tokio::test] +async fn test_sniffer_intercept_to_upstream() { + let (_tp, tp_addr) = start_template_provider(None).await; + let setup_connection = SetupConnection { + protocol: Protocol::TemplateDistributionProtocol, + min_version: 2, + max_version: 2, + flags: 0, + endpoint_host: "0.0.0.0".to_string().into_bytes().try_into().unwrap(), + endpoint_port: 8081, + vendor: "Bitmain".to_string().into_bytes().try_into().unwrap(), + hardware_version: "901".to_string().into_bytes().try_into().unwrap(), + firmware: "abcX".to_string().into_bytes().try_into().unwrap(), + device_id: "89567".to_string().into_bytes().try_into().unwrap(), + }; + let message_replacement = + PoolMessages::Common(CommonMessages::SetupConnection(setup_connection)); + let intercept = InterceptMessage::new( + MessageDirection::ToUpstream, + MESSAGE_TYPE_SETUP_CONNECTION, + message_replacement, + ); + + let (sniffer_a, sniffer_a_addr) = + start_sniffer("A".to_string(), tp_addr, false, Some(vec![intercept])).await; + + let (_sniffer_b, sniffer_b_addr) = + start_sniffer("B".to_string(), sniffer_a_addr, false, None).await; + + let _ = start_pool(Some(sniffer_b_addr)).await; + + assert_common_message!( + &sniffer_a.next_message_from_downstream(), + SetupConnection, + protocol, + Protocol::TemplateDistributionProtocol, + flags, + 0, + min_version, + 2, + max_version, + 2, + endpoint_host, + "0.0.0.0".to_string().into_bytes().try_into().unwrap(), + endpoint_port, + 8081, + vendor, + "Bitmain".to_string().into_bytes().try_into().unwrap() + ); +} + #[tokio::test] async fn test_sniffer_wait_for_message_type_with_remove() { let (_tp, tp_addr) = start_template_provider(None).await;