diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 6853d445d..c6606823d 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -621,8 +621,8 @@ impl Collector { job.aggregation_parameter.get_encoded()?, BatchSelector::::new(B::batch_identifier_for_collection( &job.query, - &collect_response, - )?), + partial_batch_selector.batch_identifier(), + )), ) .get_encoded()?, )?; diff --git a/messages/src/batch_mode.rs b/messages/src/batch_mode.rs index 0d7ee901f..f3ee0dd51 100644 --- a/messages/src/batch_mode.rs +++ b/messages/src/batch_mode.rs @@ -1,4 +1,4 @@ -use crate::{BatchId, CollectionJobResp, Error, Interval, Query}; +use crate::{BatchId, Interval, Query}; use anyhow::anyhow; use num_enum::TryFromPrimitive; use prio::codec::{CodecError, Decode, Encode}; @@ -55,8 +55,8 @@ pub trait BatchMode: Clone + Debug + PartialEq + Eq + Send + Sync + 'static { /// Retrieves the batch identifier associated with an ongoing collection. fn batch_identifier_for_collection( query: &Query, - collection_job_resp: &CollectionJobResp, - ) -> Result; + partial_batch_identifier: &Self::PartialBatchIdentifier, + ) -> Self::BatchIdentifier; } /// Represents the `time-interval` DAP batch mode. @@ -76,9 +76,9 @@ impl BatchMode for TimeInterval { fn batch_identifier_for_collection( query: &Query, - _: &CollectionJobResp, - ) -> Result { - Ok(*query.batch_interval()) + _: &Self::PartialBatchIdentifier, + ) -> Self::BatchIdentifier { + *query.batch_interval() } } @@ -101,17 +101,9 @@ impl BatchMode for LeaderSelected { fn batch_identifier_for_collection( _: &Query, - collection_job_resp: &CollectionJobResp, - ) -> Result { - match collection_job_resp { - CollectionJobResp::Processing => Err(Error::InvalidParameter( - "collection job resp in Processing state", - )), - CollectionJobResp::Finished { - partial_batch_selector, - .. - } => Ok(*partial_batch_selector.batch_identifier()), - } + partial_batch_identifier: &Self::PartialBatchIdentifier, + ) -> Self::BatchIdentifier { + *partial_batch_identifier } } diff --git a/messages/src/tests/collection.rs b/messages/src/tests/collection.rs index 25840b587..ab5f6018b 100644 --- a/messages/src/tests/collection.rs +++ b/messages/src/tests/collection.rs @@ -150,8 +150,15 @@ fn roundtrip_collection_job_resp() { start: Time::from_seconds_since_epoch(54321), duration: Duration::from_seconds(12345), }; + // TimeInterval. roundtrip_encoding(&[ + ( + CollectionJobResp::::Processing, + concat!( + "00", // status + ), + ), ( CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_time_interval(), @@ -276,6 +283,12 @@ fn roundtrip_collection_job_resp() { // LeaderSelected. roundtrip_encoding(&[ + ( + CollectionJobResp::::Processing, + concat!( + "00", // status + ), + ), ( CollectionJobResp::Finished { partial_batch_selector: PartialBatchSelector::new_leader_selected(BatchId::from( diff --git a/tools/src/bin/dap_decode.rs b/tools/src/bin/dap_decode.rs index 350009cac..2f7555fce 100644 --- a/tools/src/bin/dap_decode.rs +++ b/tools/src/bin/dap_decode.rs @@ -81,7 +81,7 @@ fn decode_dap_message(message_file: &str, media_type: &MediaType) -> Result { + MediaType::CollectionJobReq => { if let Ok(decoded) = CollectionJobReq::::get_decoded(&message_buf) { let message: CollectionJobReq = decoded; Box::new(message) @@ -91,7 +91,7 @@ fn decode_dap_message(message_file: &str, media_type: &MediaType) -> Result { + MediaType::CollectionJobResp => { if let Ok(decoded) = CollectionJobResp::::get_decoded(&message_buf) { let message: CollectionJobResp = decoded; Box::new(message) @@ -125,10 +125,10 @@ enum MediaType { AggregateShareReq, #[value(name = "aggregate-share")] AggregateShare, - #[value(name = "collect-req")] - CollectionReq, - #[value(name = "collection")] - Collection, + #[value(name = "collect-job-req")] + CollectionJobReq, + #[value(name = "collection-job-resp")] + CollectionJobResp, } #[derive(Debug, Parser)]