Skip to content

Commit

Permalink
add /api/tasks/:task_id/collector_auth_tokens (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr authored Jun 27, 2023
1 parent 90da966 commit b367a1f
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 38 deletions.
29 changes: 26 additions & 3 deletions src/aggregator_api_mock.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::clients::aggregator_client::api_types::{
HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, TaskCreate, TaskId,
TaskIds, TaskMetrics, TaskResponse,
HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, QueryType, Role,
TaskCreate, TaskId, TaskIds, TaskMetrics, TaskResponse, VdafInstance,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use fastrand::alphanumeric;
Expand All @@ -11,14 +11,15 @@ use std::iter::repeat_with;
use trillium::{Conn, Handler, Status};
use trillium_api::{api, Json};
use trillium_logger::{dev_formatter, logger};
use trillium_router::router;
use trillium_router::{router, RouterConnExt};
use uuid::Uuid;

pub fn aggregator_api() -> impl Handler {
(
logger().with_formatter(("[aggregator mock] ", dev_formatter)),
router()
.post("/tasks", api(post_task))
.get("/tasks/:task_id", api(get_task))
.get("/task_ids", api(task_ids))
.delete("/tasks/:task_id", Status::Ok)
.get("/tasks/:task_id/metrics", api(get_task_metrics)),
Expand All @@ -32,6 +33,28 @@ async fn get_task_metrics(_: &mut Conn, (): ()) -> Json<TaskMetrics> {
})
}

async fn get_task(conn: &mut Conn, (): ()) -> Json<TaskResponse> {
let task_id = conn.param("task_id").unwrap();
Json(TaskResponse {
task_id: task_id.parse().unwrap(),
peer_aggregator_endpoint: "https://_".parse().unwrap(),
query_type: QueryType::TimeInterval,
vdaf: VdafInstance::Prio3Count,
role: Role::Leader,
vdaf_verify_keys: vec![repeat_with(alphanumeric).take(10).collect()],
max_batch_query_count: 100,
task_expiration: None,
report_expiry_age: None,
min_batch_size: 1000,
time_precision: JanusDuration::from_seconds(60),
tolerable_clock_skew: JanusDuration::from_seconds(60),
collector_hpke_config: random_hpke_config(),
aggregator_auth_token: Some(repeat_with(fastrand::alphanumeric).take(32).collect()),
collector_auth_token: Some(repeat_with(fastrand::alphanumeric).take(32).collect()),
aggregator_hpke_configs: repeat_with(random_hpke_config).take(5).collect(),
})
}

async fn post_task(_: &mut Conn, Json(task_create): Json<TaskCreate>) -> Json<TaskResponse> {
Json(task_response(task_create))
}
Expand Down
4 changes: 4 additions & 0 deletions src/clients/aggregator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ impl AggregatorClient {
}
}

pub async fn get_task(&self, task_id: &str) -> Result<TaskResponse, ClientError> {
self.get(&format!("/tasks/{task_id}")).await
}

pub async fn get_task_metrics(&self, task_id: &str) -> Result<TaskMetrics, ClientError> {
self.get(&format!("/tasks/{task_id}/metrics")).await
}
Expand Down
5 changes: 0 additions & 5 deletions src/entity/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ impl Model {
self.deleted_at.is_some()
}

pub fn is_first_party(&self) -> bool {
// probably temporary
matches!(self.dap_url.domain(), Some(domain) if domain.ends_with("divviup.org"))
}

pub fn client(&self, http_client: trillium_client::Client) -> AggregatorClient {
AggregatorClient::new(http_client, self.clone())
}
Expand Down
32 changes: 14 additions & 18 deletions src/entity/task.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
clients::aggregator_client::{api_types::TaskResponse, TaskMetrics},
entity::{account, membership},
entity::{
account, membership, AccountColumn, Accounts, Aggregator, AggregatorColumn, Aggregators,
},
};
use sea_orm::{entity::prelude::*, ActiveValue::Set, IntoActiveModel};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -64,21 +66,15 @@ impl Model {
.ok_or(DbErr::Custom("expected leader aggregator".into()))?
}

pub async fn helper_aggregator(
&self,
db: &impl ConnectionTrait,
) -> Result<super::Aggregator, DbErr> {
super::Aggregators::find_by_id(self.leader_aggregator_id)
pub async fn helper_aggregator(&self, db: &impl ConnectionTrait) -> Result<Aggregator, DbErr> {
Aggregators::find_by_id(self.helper_aggregator_id)
.one(db)
.await
.transpose()
.ok_or(DbErr::Custom("expected helper aggregator".into()))?
}

pub async fn aggregators(
&self,
db: &impl ConnectionTrait,
) -> Result<[super::Aggregator; 2], DbErr> {
pub async fn aggregators(&self, db: &impl ConnectionTrait) -> Result<[Aggregator; 2], DbErr> {
let (leader, helper) =
futures_lite::future::try_zip(self.leader_aggregator(db), self.helper_aggregator(db))
.await?;
Expand All @@ -88,35 +84,35 @@ impl Model {
pub async fn first_party_aggregator(
&self,
db: &impl ConnectionTrait,
) -> Result<Option<super::Aggregator>, DbErr> {
) -> Result<Option<Aggregator>, DbErr> {
Ok(self
.aggregators(db)
.await?
.into_iter()
.find(|agg| agg.is_first_party()))
.find(|agg| agg.is_first_party))
}
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::Accounts",
belongs_to = "Accounts",
from = "Column::AccountId",
to = "super::account::Column::Id"
to = "AccountColumn::Id"
)]
Account,

#[sea_orm(
belongs_to = "super::Aggregators",
belongs_to = "Aggregators",
from = "Column::HelperAggregatorId",
to = "super::AggregatorColumn::Id"
to = "AggregatorColumn::Id"
)]
HelperAggregator,

#[sea_orm(
belongs_to = "super::Aggregators",
belongs_to = "Aggregators",
from = "Column::LeaderAggregatorId",
to = "super::AggregatorColumn::Id"
to = "AggregatorColumn::Id"
)]
LeaderAggregator,
}
Expand Down
2 changes: 1 addition & 1 deletion src/entity/task/new_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl NewTask {
errors.add("helper_aggregator_id", ValidationError::new("same"));
}

if !leader.is_first_party() && !helper.is_first_party() {
if !leader.is_first_party && !helper.is_first_party {
errors.add(
"leader_aggregator_id",
ValidationError::new("no-first-party"),
Expand Down
4 changes: 4 additions & 0 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ fn api_routes(config: &ApiConfig) -> impl Handler {
.post("/accounts", api(accounts::create))
.delete("/memberships/:membership_id", api(memberships::delete))
.get("/tasks/:task_id", api(tasks::show))
.get(
"/tasks/:task_id/collector_auth_tokens",
api(tasks::collector_auth_tokens::index),
)
.patch("/tasks/:task_id", api(tasks::update))
.patch("/aggregators/:aggregator_id", api(aggregators::update))
.get("/aggregators/:aggregator_id", api(aggregators::show))
Expand Down
13 changes: 13 additions & 0 deletions src/routes/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,16 @@ pub async fn update(
let task = update.build(task)?.update(&db).await?;
Ok(Json(task))
}

pub mod collector_auth_tokens {
use super::*;
pub async fn index(
_: &mut Conn,
(task, db, State(client)): (Task, Db, State<Client>),
) -> Result<impl Handler, Error> {
let leader = task.leader_aggregator(&db).await?;
let client = leader.client(client);
let task_response = client.get_task(&task.id).await?;
Ok(Json([task_response.collector_auth_token]))
}
}
3 changes: 3 additions & 0 deletions tests/aggregators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ mod shared_create {

assert_response!(conn, 201);
let aggregator: Aggregator = conn.response_json().await;
assert!(aggregator.is_first_party);
assert_eq!(
aggregator.dap_url,
Url::parse(&new_aggregator.dap_url.unwrap()).unwrap()
Expand All @@ -802,11 +803,13 @@ mod shared_create {
assert_eq!(aggregator.role.as_ref(), new_aggregator.role.unwrap());
assert!(aggregator.account_id.is_none());
assert!(aggregator.is_first_party);

let aggregator_from_db = Aggregators::find_by_id(aggregator.id)
.one(app.db())
.await?
.unwrap();

assert!(aggregator_from_db.is_first_party);
assert_same_json_representation(&aggregator, &aggregator_from_db);

Ok(())
Expand Down
6 changes: 3 additions & 3 deletions tests/harness/client_logs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use serde_json::Value;
use serde::Deserialize;
use std::{
fmt::{Display, Formatter, Result},
sync::{Arc, RwLock},
Expand All @@ -17,8 +17,8 @@ pub struct LoggedConn {
}

impl LoggedConn {
pub fn response_json(&self) -> Value {
serde_json::from_str(self.response_body.as_ref().unwrap()).unwrap()
pub fn response_json<'a: 'de, 'de, T: Deserialize<'de>>(&'a self) -> T {
serde_json::from_str(self.response_body.as_ref().unwrap()).expect("deserialization error")
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/harness/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use sea_orm::{
};
pub use serde_json::{json, Value};
pub use test_harness::test;
pub use time::OffsetDateTime;
pub use trillium::{Conn, KnownHeaderName, Method, Status};
pub use trillium_testing::prelude::*;
pub use url::Url;
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn create_account(app: DivviupApi, client_logs: ClientLogs) -> TestResult
create_user_request.url,
app.config().auth_url.join("/api/v2/users").unwrap()
);
let user_id = create_user_request.response_json()["user_id"]
let user_id = create_user_request.response_json::<Value>()["user_id"]
.as_str()
.unwrap()
.to_string();
Expand Down Expand Up @@ -50,7 +50,7 @@ async fn reset_password(app: DivviupApi, client_logs: ClientLogs) -> TestResult
.join("/api/v2/tickets/password-change")
.unwrap()
);
let action_url = reset_request.response_json()["ticket"]
let action_url = reset_request.response_json::<Value>()["ticket"]
.as_str()
.unwrap()
.parse()
Expand Down
94 changes: 88 additions & 6 deletions tests/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod harness;
use divviup_api::clients::aggregator_client::*;
use harness::*;

mod index {
Expand Down Expand Up @@ -208,6 +209,7 @@ mod create {

mod show {
use super::{assert_eq, test, *};
use time::Duration;

#[test(harness = set_up)]
async fn as_member(app: DivviupApi) -> TestResult {
Expand All @@ -224,23 +226,40 @@ mod show {
Ok(())
}

#[test(harness = set_up)]
async fn metrics_caching(app: DivviupApi) -> TestResult {
#[test(harness = with_client_logs)]
async fn metrics_caching(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
let (user, account, ..) = fixtures::member(&app).await;
let task = fixtures::task(&app, &account).await;
let mut task = task.into_active_model();
task.updated_at =
Set(time::OffsetDateTime::now_utc() - std::time::Duration::from_secs(10 * 60));
task.updated_at = ActiveValue::Set(OffsetDateTime::now_utc() - Duration::minutes(10));
let task = task.update(app.db()).await?;

let first_party_aggregator = task.first_party_aggregator(app.db()).await?.unwrap();

let mut conn = get(format!("/api/tasks/{}", task.id))
.with_api_headers()
.with_state(user.clone())
.run_async(&app)
.await;
assert_ok!(conn);

let aggregator_api_request = client_logs.last();
assert_eq!(
aggregator_api_request.url,
first_party_aggregator
.api_url
.join(&format!("/tasks/{}/metrics", task.id))
.unwrap()
);
let metrics: TaskMetrics = aggregator_api_request.response_json();

let response_task: Task = conn.response_json().await;
assert!(response_task.report_count != task.report_count);
assert!(response_task.aggregate_collection_count != task.aggregate_collection_count);

assert_eq!(response_task.report_count, metrics.reports as i32);
assert_eq!(
response_task.aggregate_collection_count,
metrics.report_aggregations as i32
);
assert!(response_task.updated_at > task.updated_at);

let mut conn = get(format!("/api/tasks/{}", task.id))
Expand Down Expand Up @@ -428,3 +447,66 @@ mod update {
Ok(())
}
}

mod collector_auth_tokens {
use super::{assert_eq, test, *};

#[test(harness = with_client_logs)]
async fn as_member(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
let (user, account, ..) = fixtures::member(&app).await;
let task = fixtures::task(&app, &account).await;
let mut conn = get(format!("/api/tasks/{}/collector_auth_tokens", task.id))
.with_api_headers()
.with_state(user)
.run_async(&app)
.await;

let auth_token = client_logs
.last()
.response_json::<TaskResponse>()
.collector_auth_token
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}

#[test(harness = with_client_logs)]
async fn as_rando(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
let user = fixtures::user();
let account = fixtures::account(&app).await;
let task = fixtures::task(&app, &account).await;
let mut conn = get(format!("/api/tasks/{}/collector_auth_tokens", task.id))
.with_api_headers()
.with_state(user)
.run_async(&app)
.await;
assert!(client_logs.logs().is_empty());
assert_not_found!(conn);
Ok(())
}

#[test(harness = with_client_logs)]
async fn as_admin(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
let (admin, ..) = fixtures::admin(&app).await;
let account = fixtures::account(&app).await;
let task = fixtures::task(&app, &account).await;
let mut conn = get(format!("/api/tasks/{}/collector_auth_tokens", task.id))
.with_api_headers()
.with_state(admin)
.run_async(&app)
.await;
let auth_token = client_logs
.last()
.response_json::<TaskResponse>()
.collector_auth_token
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}
}

0 comments on commit b367a1f

Please sign in to comment.