Skip to content

Commit

Permalink
refactor!: move one eval lengths from `ProofPlan::first_round_evaluat…
Browse files Browse the repository at this point in the history
…e` output to `FirstRoundBuilder` (#405)


Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
We should have obvious correspondence between how one eval lengths are
produced and consumed just like how we handle intermediate MLEs.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implement
`NestedLoopJoinExec` and speed up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- move one eval lengths from `ProofPlan::first_round_builder` output to
`FirstRoundBuilder`
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Yes.
  • Loading branch information
iajoiner authored Dec 4, 2024
2 parents 3eb74b3 + 4f1fc32 commit 3ce6724
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 114 deletions.
14 changes: 14 additions & 0 deletions crates/proof-of-sql/src/sql/proof/first_round_builder.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use alloc::vec::Vec;
/// Track the result created by a query
pub struct FirstRoundBuilder {
/// The number of challenges used in the proof.
/// Specifically, these are the challenges that the verifier sends to
/// the prover after the prover sends the result, but before the prover
/// send commitments to the intermediate witness columns.
num_post_result_challenges: usize,
/// The extra one evaluation lengths used in the proof.
one_evaluation_lengths: Vec<usize>,
}

impl Default for FirstRoundBuilder {
Expand All @@ -17,9 +20,20 @@ impl FirstRoundBuilder {
pub fn new() -> Self {
Self {
num_post_result_challenges: 0,
one_evaluation_lengths: Vec::new(),
}
}

/// Get the one evaluation lengths used in the proof.
pub(crate) fn one_evaluation_lengths(&self) -> &[usize] {
&self.one_evaluation_lengths
}

/// Append the length to the list of one evaluation lengths.
pub(crate) fn produce_one_evaluation_length(&mut self, length: usize) {
self.one_evaluation_lengths.push(length);
}

/// The number of challenges used in the proof.
/// Specifically, these are the challenges that the verifier sends to
/// the prover after the prover sends the result, but before the prover
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof/proof_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub trait ProverEvaluate {
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>);
) -> Table<'a, S>;

/// Evaluate the query and modify `FinalRoundBuilder` to store an intermediate representation
/// of the query result and track all the components needed to form the query's proof.
Expand Down
8 changes: 4 additions & 4 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {

// Prover First Round: Evaluate the query && get the right number of post result challenges
let mut first_round_builder = FirstRoundBuilder::new();
let (query_result, one_evaluation_lengths) =
expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
let query_result = expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
let provable_result = query_result.into();
let one_evaluation_lengths = first_round_builder.one_evaluation_lengths();

let range_length = one_evaluation_lengths
.iter()
Expand All @@ -114,7 +114,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
&provable_result,
range_length,
min_row_num,
&one_evaluation_lengths,
one_evaluation_lengths,
);

// These are the challenges that will be consumed by the proof
Expand Down Expand Up @@ -191,7 +191,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {

let proof = Self {
bit_distributions: builder.bit_distributions().to_vec(),
one_evaluation_lengths,
one_evaluation_lengths: one_evaluation_lengths.to_vec(),
commitments,
sumcheck_proof,
pcs_proof_evaluations,
Expand Down
29 changes: 15 additions & 14 deletions crates/proof-of-sql/src/sql/proof/query_proof_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,13 @@ impl Default for TrivialTestProofPlan {
impl ProverEvaluate for TrivialTestProofPlan {
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let col = vec![self.column_fill_value; self.length];
(
table([borrowed_bigint("a1", col, alloc)]),
vec![self.length],
)
builder.produce_one_evaluation_length(self.length);
table([borrowed_bigint("a1", col, alloc)])
}

fn final_round_evaluate<'a, S: Scalar>(
Expand Down Expand Up @@ -226,11 +224,12 @@ impl Default for SquareTestProofPlan {
impl ProverEvaluate for SquareTestProofPlan {
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
(table([borrowed_bigint("a1", self.res, alloc)]), vec![2])
) -> Table<'a, S> {
builder.produce_one_evaluation_length(2);
table([borrowed_bigint("a1", self.res, alloc)])
}

fn final_round_evaluate<'a, S: Scalar>(
Expand Down Expand Up @@ -409,11 +408,12 @@ impl Default for DoubleSquareTestProofPlan {
impl ProverEvaluate for DoubleSquareTestProofPlan {
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
(table([borrowed_bigint("a1", self.res, alloc)]), vec![2])
) -> Table<'a, S> {
builder.produce_one_evaluation_length(2);
table([borrowed_bigint("a1", self.res, alloc)])
}

fn final_round_evaluate<'a, S: Scalar>(
Expand Down Expand Up @@ -625,9 +625,10 @@ impl ProverEvaluate for ChallengeTestProofPlan {
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
builder.request_post_result_challenges(2);
(table([borrowed_bigint("a1", [9, 25], alloc)]), vec![2])
builder.produce_one_evaluation_length(2);
table([borrowed_bigint("a1", [9, 25], alloc)])
}

fn final_round_evaluate<'a, S: Scalar>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,15 @@ pub(super) struct EmptyTestQueryExpr {
impl ProverEvaluate for EmptyTestQueryExpr {
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let zeros = vec![0_i64; self.length];
(
table_with_row_count(
(1..=self.columns).map(|i| borrowed_bigint(format!("a{i}"), zeros.clone(), alloc)),
self.length,
),
vec![self.length],
builder.produce_one_evaluation_length(self.length);
table_with_row_count(
(1..=self.columns).map(|i| borrowed_bigint(format!("a{i}"), zeros.clone(), alloc)),
self.length,
)
}

Expand Down
11 changes: 4 additions & 7 deletions crates/proof-of-sql/src/sql/proof_plans/empty_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
VerificationBuilder,
},
};
use alloc::{vec, vec::Vec};
use alloc::vec::Vec;
use bumpalo::Bump;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -74,13 +74,10 @@ impl ProverEvaluate for EmptyExec {
_builder: &mut FirstRoundBuilder,
_alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
// Create an empty table with one row
(
Table::<'a, S>::try_new_with_options(IndexMap::default(), TableOptions::new(Some(1)))
.unwrap(),
vec![],
)
Table::<'a, S>::try_new_with_options(IndexMap::default(), TableOptions::new(Some(1)))
.unwrap()
}

#[tracing::instrument(name = "EmptyExec::final_round_evaluate", level = "debug", skip_all)]
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl ProverEvaluate for FilterExec {
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let table = table_map
.get(&self.table.table_ref)
.expect("Table not found");
Expand Down Expand Up @@ -188,7 +188,8 @@ impl ProverEvaluate for FilterExec {
)
.expect("Failed to create table from iterator");
builder.request_post_result_challenges(2);
(res, vec![output_length])
builder.produce_one_evaluation_length(output_length);
res
}

#[tracing::instrument(name = "FilterExec::final_round_evaluate", level = "debug", skip_all)]
Expand Down
36 changes: 20 additions & 16 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,11 @@ fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_first_
),
];
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(expr.first_round_evaluate(
first_round_builder,
&alloc,
&table_map,
))
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
Expand Down Expand Up @@ -256,10 +257,11 @@ fn we_can_get_an_empty_result_from_a_basic_filter_using_first_round_evaluate() {
),
];
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(expr.first_round_evaluate(
first_round_builder,
&alloc,
&table_map,
))
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
Expand Down Expand Up @@ -292,10 +294,11 @@ fn we_can_get_no_columns_from_a_basic_filter_with_no_selected_columns_using_firs
let expr = filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause);
let fields = &[];
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(expr.first_round_evaluate(
first_round_builder,
&alloc,
&table_map,
))
.to_owned_table(fields)
.unwrap();
let expected = OwnedTable::try_new(IndexMap::default()).unwrap();
Expand Down Expand Up @@ -334,10 +337,11 @@ fn we_can_get_the_correct_result_from_a_basic_filter_using_first_round_evaluate(
),
];
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(expr.first_round_evaluate(
first_round_builder,
&alloc,
&table_map,
))
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl ProverEvaluate for DishonestFilterExec {
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let table = table_map
.get(&self.table.table_ref)
.expect("Table not found");
Expand Down Expand Up @@ -67,7 +67,8 @@ impl ProverEvaluate for DishonestFilterExec {
)
.expect("Failed to create table from iterator");
builder.request_post_result_challenges(2);
(res, vec![output_length])
builder.produce_one_evaluation_length(output_length);
res
}

#[tracing::instrument(
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl ProverEvaluate for GroupByExec {
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let table = table_map
.get(&self.table.table_ref)
.expect("Table not found");
Expand Down Expand Up @@ -257,7 +257,8 @@ impl ProverEvaluate for GroupByExec {
)
.expect("Failed to create table from column references");
builder.request_post_result_challenges(2);
(res, vec![count_column.len()])
builder.produce_one_evaluation_length(count_column.len());
res
}

#[tracing::instrument(name = "GroupByExec::final_round_evaluate", level = "debug", skip_all)]
Expand Down
25 changes: 11 additions & 14 deletions crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
proof_exprs::{AliasedDynProofExpr, ProofExpr, TableExpr},
},
};
use alloc::{vec, vec::Vec};
use alloc::vec::Vec;
use bumpalo::Bump;
use core::iter::repeat_with;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -106,23 +106,20 @@ impl ProverEvaluate for ProjectionExec {
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
) -> Table<'a, S> {
let table = table_map
.get(&self.table.table_ref)
.expect("Table not found");
(
Table::<'a, S>::try_from_iter_with_options(
self.aliased_results.iter().map(|aliased_expr| {
(
aliased_expr.alias,
aliased_expr.expr.result_evaluate(alloc, table),
)
}),
TableOptions::new(Some(table.num_rows())),
)
.expect("Failed to create table from iterator"),
vec![],
Table::<'a, S>::try_from_iter_with_options(
self.aliased_results.iter().map(|aliased_expr| {
(
aliased_expr.alias,
aliased_expr.expr.result_evaluate(alloc, table),
)
}),
TableOptions::new(Some(table.num_rows())),
)
.expect("Failed to create table from iterator")
}

#[tracing::instrument(
Expand Down
Loading

0 comments on commit 3ce6724

Please sign in to comment.