Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: custom predicates in frontend statement and operation types #97

Merged
merged 3 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
groth = "groth" # to avoid it dectecting it as 'growth'
BA = "BA"
Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead"
OT = "OT"
8 changes: 6 additions & 2 deletions src/backends/plonky2/mock_main/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::fmt;

use crate::middleware::{
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod,
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
OperationType, Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
};

mod operation;
Expand Down Expand Up @@ -261,7 +261,11 @@ impl MockMainPod {
.map(|mid_arg| Self::find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>>>()?;
Self::pad_operation_args(params, &mut args);
operations.push(Operation(op.code(), args));
let op_code = match op.code() {
OperationType::Native(code) => code,
_ => unimplemented!(),
};
operations.push(Operation(op_code, args));
}
Ok(operations)
}
Expand Down
4 changes: 2 additions & 2 deletions src/backends/plonky2/mock_main/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;
use std::fmt;

use super::Statement;
use crate::middleware::{self, NativeOperation};
use crate::middleware::{self, NativeOperation, OperationType};

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg {
Expand All @@ -29,7 +29,7 @@ impl Operation {
OperationArg::Index(i) => Some(statements[*i].clone().try_into()),
})
.collect::<Result<Vec<crate::middleware::Statement>>>()?;
middleware::Operation::op(self.0, &deref_args)
middleware::Operation::op(OperationType::Native(self.0), &deref_args)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::middleware::{
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F,
};

/// Argument to an statement template
/// Argument to a statement template
pub enum HashOrWildcardStr {
Hash(Hash), // represents a literal key
Wildcard(String),
Expand Down
88 changes: 51 additions & 37 deletions src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::middleware::{
hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver,
PodSigner, SELF,
};
use crate::middleware::{OperationType, Predicate};

mod custom;
mod operation;
Expand Down Expand Up @@ -254,7 +255,7 @@ impl MainPodBuilder {
for arg in args.iter_mut() {
match arg {
OperationArg::Statement(s) => {
if s.0 == NativePredicate::ValueOf {
if s.0 == Predicate::Native(NativePredicate::ValueOf) {
st_args.push(s.1[0].clone())
} else {
panic!("Invalid statement argument.");
Expand All @@ -266,7 +267,7 @@ impl MainPodBuilder {
let value_of_st = self.op(
public,
Operation(
NativeOperation::NewEntry,
OperationType::Native(NativeOperation::NewEntry),
vec![OperationArg::Entry(k.clone(), v.clone())],
),
);
Expand All @@ -291,36 +292,49 @@ impl MainPodBuilder {

pub fn op(&mut self, public: bool, mut op: Operation) -> Statement {
use NativeOperation::*;
let Operation(op_type, ref mut args) = op;
let Operation(op_type, ref mut args) = &mut op;
// TODO: argument type checking
let st = match op_type {
None => Statement(NativePredicate::None, vec![]),
NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)),
CopyStatement => todo!(),
EqualFromEntries => {
Statement(NativePredicate::Equal, self.op_args_entries(public, args))
}
NotEqualFromEntries => Statement(
NativePredicate::NotEqual,
self.op_args_entries(public, args),
),
GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)),
LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)),
TransitiveEqualFromStatements => todo!(),
GtToNotEqual => todo!(),
LtToNotEqual => todo!(),
ContainsFromEntries => Statement(
NativePredicate::Contains,
self.op_args_entries(public, args),
),
NotContainsFromEntries => Statement(
NativePredicate::NotContains,
self.op_args_entries(public, args),
),
RenameContainedBy => todo!(),
SumOf => todo!(),
ProductOf => todo!(),
MaxOf => todo!(),
OperationType::Native(o) => match o {
None => Statement(Predicate::Native(NativePredicate::None), vec![]),
NewEntry => Statement(
Predicate::Native(NativePredicate::ValueOf),
self.op_args_entries(public, args),
),
CopyStatement => todo!(),
EqualFromEntries => Statement(
Predicate::Native(NativePredicate::Equal),
self.op_args_entries(public, args),
),
NotEqualFromEntries => Statement(
Predicate::Native(NativePredicate::NotEqual),
self.op_args_entries(public, args),
),
GtFromEntries => Statement(
Predicate::Native(NativePredicate::Gt),
self.op_args_entries(public, args),
),
LtFromEntries => Statement(
Predicate::Native(NativePredicate::Lt),
self.op_args_entries(public, args),
),
TransitiveEqualFromStatements => todo!(),
GtToNotEqual => todo!(),
LtToNotEqual => todo!(),
ContainsFromEntries => Statement(
Predicate::Native(NativePredicate::Contains),
self.op_args_entries(public, args),
),
NotContainsFromEntries => Statement(
Predicate::Native(NativePredicate::NotContains),
self.op_args_entries(public, args),
),
RenameContainedBy => todo!(),
SumOf => todo!(),
ProductOf => todo!(),
MaxOf => todo!(),
},
_ => todo!(),
};
self.operations.push(op);
if public {
Expand Down Expand Up @@ -440,7 +454,7 @@ impl MainPodCompiler {

fn compile_op(&self, op: &Operation) -> middleware::Operation {
// TODO
let mop_code: middleware::NativeOperation = op.0.into();
let mop_code: OperationType = op.0.clone();
let mop_args =
op.1.iter()
.flat_map(|arg| self.compile_op_arg(arg).map(|s| s.try_into().unwrap()))
Expand Down Expand Up @@ -496,22 +510,22 @@ pub mod build_utils {
#[macro_export]
macro_rules! op {
(eq, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::EqualFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::EqualFromEntries),
crate::op_args!($($arg),*)) };
(ne, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::NotEqualFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotEqualFromEntries),
crate::op_args!($($arg),*)) };
(gt, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::GtFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::GtFromEntries),
crate::op_args!($($arg),*)) };
(lt, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::LtFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::LtFromEntries),
crate::op_args!($($arg),*)) };
(contains, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::ContainsFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::ContainsFromEntries),
crate::op_args!($($arg),*)) };
(not_contains, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::NativeOperation::NotContainsFromEntries,
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotContainsFromEntries),
crate::op_args!($($arg),*)) };
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/operation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt;

use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
use crate::middleware::{hash_str, NativeOperation, NativePredicate};
use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate};

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg {
Expand Down Expand Up @@ -55,7 +55,7 @@ impl From<(&SignedPod, &str)> for OperationArg {
// TODO: Actual value, TryFrom.
let value = pod.kvs().get(&hash_str(key)).unwrap().clone();
Self::Statement(Statement(
NativePredicate::ValueOf,
Predicate::Native(NativePredicate::ValueOf),
vec![
StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())),
StatementArg::Literal(Value::Raw(value)),
Expand All @@ -65,7 +65,7 @@ impl From<(&SignedPod, &str)> for OperationArg {
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Operation(pub NativeOperation, pub Vec<OperationArg>);
pub struct Operation(pub OperationType, pub Vec<OperationArg>);

impl fmt::Display for Operation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
80 changes: 46 additions & 34 deletions src/frontend/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use std::fmt;

use super::{AnchoredKey, Value};
use crate::middleware::{self, NativePredicate};
use crate::middleware::{self, NativePredicate, Predicate};

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StatementArg {
Expand All @@ -20,7 +20,7 @@ impl fmt::Display for StatementArg {
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Statement(pub NativePredicate, pub Vec<StatementArg>);
pub struct Statement(pub Predicate, pub Vec<StatementArg>);

impl TryFrom<Statement> for middleware::Statement {
type Error = anyhow::Error;
Expand All @@ -33,38 +33,50 @@ impl TryFrom<Statement> for middleware::Statement {
s.1.get(1).cloned(),
s.1.get(2).cloned(),
);
Ok(match (s.0, args) {
(NP::None, (None, None, None)) => MS::None,
(NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => {
MS::ValueOf(ak.into(), (&v).into())
}
(NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Equal(ak1.into(), ak2.into())
}
(NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotEqual(ak1.into(), ak2.into())
}
(NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Gt(ak1.into(), ak2.into())
}
(NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Lt(ak1.into(), ak2.into())
}
(NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Contains(ak1.into(), ak2.into())
}
(NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotContains(ak1.into(), ak2.into())
}
(NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::SumOf(ak1.into(), ak2.into(), ak3.into())
}
(NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::ProductOf(ak1.into(), ak2.into(), ak3.into())
}
(NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::MaxOf(ak1.into(), ak2.into(), ak3.into())
}
Ok(match &s.0 {
Predicate::Native(np) => match (np, args) {
(NP::None, (None, None, None)) => MS::None,
(NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => {
MS::ValueOf(ak.into(), (&v).into())
}
(NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Equal(ak1.into(), ak2.into())
}
(NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotEqual(ak1.into(), ak2.into())
}
(NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Gt(ak1.into(), ak2.into())
}
(NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Lt(ak1.into(), ak2.into())
}
(NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Contains(ak1.into(), ak2.into())
}
(NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotContains(ak1.into(), ak2.into())
}
(NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::SumOf(ak1.into(), ak2.into(), ak3.into())
}
(NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::ProductOf(ak1.into(), ak2.into(), ak3.into())
}
(NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::MaxOf(ak1.into(), ak2.into(), ak3.into())
}
_ => Err(anyhow!("Ill-formed statement: {}", s))?,
},
Predicate::Custom(cpr) => MS::Custom(
cpr.clone(),
s.1.iter()
.map(|arg| match arg {
StatementArg::Key(ak) => Ok(ak.clone().into()),
_ => Err(anyhow!("Invalid statement arg: {}", arg)),
})
.collect::<Result<Vec<_>>>()?,
),
_ => Err(anyhow!("Ill-formed statement: {}", s))?,
})
}
Expand Down
Loading
Loading