From 6296924d0de1fc9cc701ce14a8fc5c26dc1f2fde Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Fri, 13 Oct 2023 17:10:27 -0700 Subject: [PATCH] only allow using a collector credential with a token hash present when the leader is +TokenHash --- .../TaskForm/CollectorCredentialSelect.tsx | 58 ++++++++++--------- src/entity/task/new_task.rs | 26 ++++++++- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/app/src/tasks/TaskForm/CollectorCredentialSelect.tsx b/app/src/tasks/TaskForm/CollectorCredentialSelect.tsx index 9c37a261..e73b489d 100644 --- a/app/src/tasks/TaskForm/CollectorCredentialSelect.tsx +++ b/app/src/tasks/TaskForm/CollectorCredentialSelect.tsx @@ -1,23 +1,37 @@ -import { Await, useLoaderData } from "react-router-dom"; import FormControl from "react-bootstrap/FormControl"; import FormSelect from "react-bootstrap/FormSelect"; -import React, { Suspense } from "react"; -import { CollectorCredential } from "../../ApiClient"; +import React from "react"; +import { Aggregator, CollectorCredential } from "../../ApiClient"; import { Props, TaskFormGroup } from "."; import { ShortHelpAndLabel } from "./HelpText"; +import { useLoaderPromise } from "../../util"; export default function CollectorCredentialSelect(props: Props) { - const { collectorCredentials } = useLoaderData() as { - collectorCredentials: Promise; - }; - const { setFieldValue } = props; + const collectorCredentials = useLoaderPromise( + "collectorCredentials", + [], + ); + const aggregators = useLoaderPromise("aggregators", []); + const leader = React.useMemo( + () => + aggregators.find(({ id }) => id === props.values.leader_aggregator_id), + [props.values.leader_aggregator_id, aggregators], + ); + const enabledCredentials = React.useMemo( + () => + leader && leader.features.includes("TokenHash") + ? collectorCredentials.filter( + (collectorCredential) => !!collectorCredential.token_hash, + ) + : collectorCredentials, + [collectorCredentials, leader], + ); React.useEffect(() => { - collectorCredentials.then((configs) => { - if (configs.length === 1) - setFieldValue("collector_credential_id", configs[0].id); - }); - }, [collectorCredentials, setFieldValue]); + if (enabledCredentials.length === 1) { + props.setFieldValue("collector_credential_id", enabledCredentials[0].id); + } + }, [enabledCredentials, props.setFieldValue]); return ( @@ -30,21 +44,11 @@ export default function CollectorCredentialSelect(props: Props) { id="collector-credential-id" name="collector_credential_id" > - - ...}> - - {(collectorCredentials: CollectorCredential[]) => - collectorCredentials.map((collectorCredential) => ( - - )) - } - - + {enabledCredentials.map((collectorCredential) => ( + + ))} {props.errors.collector_credential_id} diff --git a/src/entity/task/new_task.rs b/src/entity/task/new_task.rs index 57ac14fc..b1bfb5f9 100644 --- a/src/entity/task/new_task.rs +++ b/src/entity/task/new_task.rs @@ -105,15 +105,30 @@ impl NewTask { async fn validate_collector_credential( &self, account: &Account, + leader: Option<&Aggregator>, db: &impl ConnectionTrait, errors: &mut ValidationErrors, ) -> Option { match self.load_collector_credential(account, db).await { - Some(collector_credential) => Some(collector_credential), None => { errors.add("collector_credential_id", ValidationError::new("required")); None } + + Some(collector_credential) => { + let leader_needs_token_hash = + leader.map_or(false, |leader| leader.features.token_hash_enabled()); + + if leader_needs_token_hash && collector_credential.token_hash.is_none() { + errors.add( + "collector_credential_id", + ValidationError::new("missing-token-hash"), + ); + None + } else { + Some(collector_credential) + } + } } } @@ -255,10 +270,15 @@ impl NewTask { ) -> Result { let mut errors = Validate::validate(self).err().unwrap_or_default(); self.validate_min_lte_max(&mut errors); + let aggregators = self.validate_aggregators(&account, db, &mut errors).await; let collector_credential = self - .validate_collector_credential(&account, db, &mut errors) + .validate_collector_credential( + &account, + aggregators.as_ref().map(|(leader, ..)| leader), + db, + &mut errors, + ) .await; - let aggregators = self.validate_aggregators(&account, db, &mut errors).await; let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() { self.validate_query_type_is_supported(leader, helper, &mut errors);