diff --git a/.github/workflows/typos.toml b/.github/workflows/typos.toml index cef8a2b..471d317 100644 --- a/.github/workflows/typos.toml +++ b/.github/workflows/typos.toml @@ -2,3 +2,4 @@ groth = "groth" # to avoid it dectecting it as 'growth' BA = "BA" Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead" +OT = "OT" \ No newline at end of file diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index a708331..b557715 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -7,7 +7,7 @@ use std::fmt; use crate::middleware::{ self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod, - Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, + OperationType, Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, }; mod operation; @@ -261,7 +261,11 @@ impl MockMainPod { .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) .collect::>>()?; Self::pad_operation_args(params, &mut args); - operations.push(Operation(op.code(), args)); + let op_code = match op.code() { + OperationType::Native(code) => code, + _ => unimplemented!(), + }; + operations.push(Operation(op_code, args)); } Ok(operations) } diff --git a/src/backends/plonky2/mock_main/operation.rs b/src/backends/plonky2/mock_main/operation.rs index cb5ff3a..f5ae1de 100644 --- a/src/backends/plonky2/mock_main/operation.rs +++ b/src/backends/plonky2/mock_main/operation.rs @@ -2,7 +2,7 @@ use anyhow::Result; use std::fmt; use super::Statement; -use crate::middleware::{self, NativeOperation}; +use crate::middleware::{self, NativeOperation, OperationType}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -29,7 +29,7 @@ impl Operation { OperationArg::Index(i) => Some(statements[*i].clone().try_into()), }) .collect::>>()?; - middleware::Operation::op(self.0, &deref_args) + middleware::Operation::op(OperationType::Native(self.0), &deref_args) } } diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 8a051ff..5cdba95 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -6,7 +6,7 @@ use crate::middleware::{ Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F, }; -/// Argument to an statement template +/// Argument to a statement template pub enum HashOrWildcardStr { Hash(Hash), // represents a literal key Wildcard(String), diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 2a66070..249845f 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -13,6 +13,7 @@ use crate::middleware::{ hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver, PodSigner, SELF, }; +use crate::middleware::{OperationType, Predicate}; mod custom; mod operation; @@ -254,7 +255,7 @@ impl MainPodBuilder { for arg in args.iter_mut() { match arg { OperationArg::Statement(s) => { - if s.0 == NativePredicate::ValueOf { + if s.0 == Predicate::Native(NativePredicate::ValueOf) { st_args.push(s.1[0].clone()) } else { panic!("Invalid statement argument."); @@ -266,7 +267,7 @@ impl MainPodBuilder { let value_of_st = self.op( public, Operation( - NativeOperation::NewEntry, + OperationType::Native(NativeOperation::NewEntry), vec![OperationArg::Entry(k.clone(), v.clone())], ), ); @@ -291,36 +292,49 @@ impl MainPodBuilder { pub fn op(&mut self, public: bool, mut op: Operation) -> Statement { use NativeOperation::*; - let Operation(op_type, ref mut args) = op; + let Operation(op_type, ref mut args) = &mut op; // TODO: argument type checking let st = match op_type { - None => Statement(NativePredicate::None, vec![]), - NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)), - CopyStatement => todo!(), - EqualFromEntries => { - Statement(NativePredicate::Equal, self.op_args_entries(public, args)) - } - NotEqualFromEntries => Statement( - NativePredicate::NotEqual, - self.op_args_entries(public, args), - ), - GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)), - LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)), - TransitiveEqualFromStatements => todo!(), - GtToNotEqual => todo!(), - LtToNotEqual => todo!(), - ContainsFromEntries => Statement( - NativePredicate::Contains, - self.op_args_entries(public, args), - ), - NotContainsFromEntries => Statement( - NativePredicate::NotContains, - self.op_args_entries(public, args), - ), - RenameContainedBy => todo!(), - SumOf => todo!(), - ProductOf => todo!(), - MaxOf => todo!(), + OperationType::Native(o) => match o { + None => Statement(Predicate::Native(NativePredicate::None), vec![]), + NewEntry => Statement( + Predicate::Native(NativePredicate::ValueOf), + self.op_args_entries(public, args), + ), + CopyStatement => todo!(), + EqualFromEntries => Statement( + Predicate::Native(NativePredicate::Equal), + self.op_args_entries(public, args), + ), + NotEqualFromEntries => Statement( + Predicate::Native(NativePredicate::NotEqual), + self.op_args_entries(public, args), + ), + GtFromEntries => Statement( + Predicate::Native(NativePredicate::Gt), + self.op_args_entries(public, args), + ), + LtFromEntries => Statement( + Predicate::Native(NativePredicate::Lt), + self.op_args_entries(public, args), + ), + TransitiveEqualFromStatements => todo!(), + GtToNotEqual => todo!(), + LtToNotEqual => todo!(), + ContainsFromEntries => Statement( + Predicate::Native(NativePredicate::Contains), + self.op_args_entries(public, args), + ), + NotContainsFromEntries => Statement( + Predicate::Native(NativePredicate::NotContains), + self.op_args_entries(public, args), + ), + RenameContainedBy => todo!(), + SumOf => todo!(), + ProductOf => todo!(), + MaxOf => todo!(), + }, + _ => todo!(), }; self.operations.push(op); if public { @@ -440,7 +454,7 @@ impl MainPodCompiler { fn compile_op(&self, op: &Operation) -> middleware::Operation { // TODO - let mop_code: middleware::NativeOperation = op.0.into(); + let mop_code: OperationType = op.0.clone(); let mop_args = op.1.iter() .flat_map(|arg| self.compile_op_arg(arg).map(|s| s.try_into().unwrap())) @@ -496,22 +510,22 @@ pub mod build_utils { #[macro_export] macro_rules! op { (eq, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::EqualFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::EqualFromEntries), crate::op_args!($($arg),*)) }; (ne, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::NotEqualFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotEqualFromEntries), crate::op_args!($($arg),*)) }; (gt, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::GtFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::GtFromEntries), crate::op_args!($($arg),*)) }; (lt, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::LtFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::LtFromEntries), crate::op_args!($($arg),*)) }; (contains, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::ContainsFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::ContainsFromEntries), crate::op_args!($($arg),*)) }; (not_contains, $($arg:expr),+) => { crate::frontend::Operation( - crate::middleware::NativeOperation::NotContainsFromEntries, + crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotContainsFromEntries), crate::op_args!($($arg),*)) }; } } diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 57d6f4f..ef8114f 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,7 +1,7 @@ use std::fmt; use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; -use crate::middleware::{hash_str, NativeOperation, NativePredicate}; +use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -55,7 +55,7 @@ impl From<(&SignedPod, &str)> for OperationArg { // TODO: Actual value, TryFrom. let value = pod.kvs().get(&hash_str(key)).unwrap().clone(); Self::Statement(Statement( - NativePredicate::ValueOf, + Predicate::Native(NativePredicate::ValueOf), vec![ StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), StatementArg::Literal(Value::Raw(value)), @@ -65,7 +65,7 @@ impl From<(&SignedPod, &str)> for OperationArg { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Operation(pub NativeOperation, pub Vec); +pub struct Operation(pub OperationType, pub Vec); impl fmt::Display for Operation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs index 59a75e2..c0dfc25 100644 --- a/src/frontend/statement.rs +++ b/src/frontend/statement.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use std::fmt; use super::{AnchoredKey, Value}; -use crate::middleware::{self, NativePredicate}; +use crate::middleware::{self, NativePredicate, Predicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum StatementArg { @@ -20,7 +20,7 @@ impl fmt::Display for StatementArg { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativePredicate, pub Vec); +pub struct Statement(pub Predicate, pub Vec); impl TryFrom for middleware::Statement { type Error = anyhow::Error; @@ -33,38 +33,50 @@ impl TryFrom for middleware::Statement { s.1.get(1).cloned(), s.1.get(2).cloned(), ); - Ok(match (s.0, args) { - (NP::None, (None, None, None)) => MS::None, - (NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { - MS::ValueOf(ak.into(), (&v).into()) - } - (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Equal(ak1.into(), ak2.into()) - } - (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::NotEqual(ak1.into(), ak2.into()) - } - (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Gt(ak1.into(), ak2.into()) - } - (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Lt(ak1.into(), ak2.into()) - } - (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Contains(ak1.into(), ak2.into()) - } - (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::NotContains(ak1.into(), ak2.into()) - } - (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::SumOf(ak1.into(), ak2.into(), ak3.into()) - } - (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) - } - (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) - } + Ok(match &s.0 { + Predicate::Native(np) => match (np, args) { + (NP::None, (None, None, None)) => MS::None, + (NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { + MS::ValueOf(ak.into(), (&v).into()) + } + (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Equal(ak1.into(), ak2.into()) + } + (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::NotEqual(ak1.into(), ak2.into()) + } + (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Gt(ak1.into(), ak2.into()) + } + (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Lt(ak1.into(), ak2.into()) + } + (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Contains(ak1.into(), ak2.into()) + } + (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::NotContains(ak1.into(), ak2.into()) + } + (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::SumOf(ak1.into(), ak2.into(), ak3.into()) + } + (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) + } + (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) + } + _ => Err(anyhow!("Ill-formed statement: {}", s))?, + }, + Predicate::Custom(cpr) => MS::Custom( + cpr.clone(), + s.1.iter() + .map(|arg| match arg { + StatementArg::Key(ak) => Ok(ak.clone().into()), + _ => Err(anyhow!("Invalid statement arg: {}", arg)), + }) + .collect::>>()?, + ), _ => Err(anyhow!("Ill-formed statement: {}", s))?, }) } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index cf08705..dd22cc3 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -8,6 +8,12 @@ use crate::{ util::hashmap_insert_no_dupe, }; +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OperationType { + Native(NativeOperation), + Custom(CustomPredicateRef), +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NativeOperation { None = 0, @@ -51,26 +57,27 @@ pub enum Operation { } impl Operation { - pub fn code(&self) -> NativeOperation { + pub fn code(&self) -> OperationType { + type OT = OperationType; use NativeOperation::*; match self { - Self::None => None, - Self::NewEntry => NewEntry, - Self::CopyStatement(_) => CopyStatement, - Self::EqualFromEntries(_, _) => EqualFromEntries, - Self::NotEqualFromEntries(_, _) => NotEqualFromEntries, - Self::GtFromEntries(_, _) => GtFromEntries, - Self::LtFromEntries(_, _) => LtFromEntries, - Self::TransitiveEqualFromStatements(_, _) => TransitiveEqualFromStatements, - Self::GtToNotEqual(_) => GtToNotEqual, - Self::LtToNotEqual(_) => LtToNotEqual, - Self::ContainsFromEntries(_, _) => ContainsFromEntries, - Self::NotContainsFromEntries(_, _) => NotContainsFromEntries, - Self::RenameContainedBy(_, _) => RenameContainedBy, - Self::SumOf(_, _, _) => SumOf, - Self::ProductOf(_, _, _) => ProductOf, - Self::MaxOf(_, _, _) => MaxOf, - Self::Custom(_, _) => todo!(), + Self::None => OT::Native(None), + Self::NewEntry => OT::Native(NewEntry), + Self::CopyStatement(_) => OT::Native(CopyStatement), + Self::EqualFromEntries(_, _) => OT::Native(EqualFromEntries), + Self::NotEqualFromEntries(_, _) => OT::Native(NotEqualFromEntries), + Self::GtFromEntries(_, _) => OT::Native(GtFromEntries), + Self::LtFromEntries(_, _) => OT::Native(LtFromEntries), + Self::TransitiveEqualFromStatements(_, _) => OT::Native(TransitiveEqualFromStatements), + Self::GtToNotEqual(_) => OT::Native(GtToNotEqual), + Self::LtToNotEqual(_) => OT::Native(LtToNotEqual), + Self::ContainsFromEntries(_, _) => OT::Native(ContainsFromEntries), + Self::NotContainsFromEntries(_, _) => OT::Native(NotContainsFromEntries), + Self::RenameContainedBy(_, _) => OT::Native(RenameContainedBy), + Self::SumOf(_, _, _) => OT::Native(SumOf), + Self::ProductOf(_, _, _) => OT::Native(ProductOf), + Self::MaxOf(_, _, _) => OT::Native(MaxOf), + Self::Custom(cpr, _) => OT::Custom(cpr.clone()), } } @@ -96,40 +103,45 @@ impl Operation { } } /// Forms operation from op-code and arguments. - pub fn op(op_code: NativeOperation, args: &[Statement]) -> Result { + pub fn op(op_code: OperationType, args: &[Statement]) -> Result { type NO = NativeOperation; let arg_tup = ( args.get(0).cloned(), args.get(1).cloned(), args.get(2).cloned(), ); - Ok(match (op_code, arg_tup, args.len()) { - (NO::None, (None, None, None), 0) => Self::None, - (NO::NewEntry, (None, None, None), 0) => Self::NewEntry, - (NO::CopyStatement, (Some(s), None, None), 1) => Self::CopyStatement(s), - (NO::EqualFromEntries, (Some(s1), Some(s2), None), 2) => Self::EqualFromEntries(s1, s2), - (NO::NotEqualFromEntries, (Some(s1), Some(s2), None), 2) => { - Self::NotEqualFromEntries(s1, s2) - } - (NO::GtFromEntries, (Some(s1), Some(s2), None), 2) => Self::GtFromEntries(s1, s2), - (NO::LtFromEntries, (Some(s1), Some(s2), None), 2) => Self::LtFromEntries(s1, s2), - (NO::ContainsFromEntries, (Some(s1), Some(s2), None), 2) => { - Self::ContainsFromEntries(s1, s2) - } - (NO::NotContainsFromEntries, (Some(s1), Some(s2), None), 2) => { - Self::NotContainsFromEntries(s1, s2) - } - (NO::RenameContainedBy, (Some(s1), Some(s2), None), 2) => { - Self::RenameContainedBy(s1, s2) - } - (NO::SumOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::SumOf(s1, s2, s3), - (NO::ProductOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::ProductOf(s1, s2, s3), - (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::MaxOf(s1, s2, s3), - _ => Err(anyhow!( - "Ill-formed operation {:?} with arguments {:?}.", - op_code, - args - ))?, + Ok(match op_code { + OperationType::Native(o) => match (o, arg_tup, args.len()) { + (NO::None, (None, None, None), 0) => Self::None, + (NO::NewEntry, (None, None, None), 0) => Self::NewEntry, + (NO::CopyStatement, (Some(s), None, None), 1) => Self::CopyStatement(s), + (NO::EqualFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::EqualFromEntries(s1, s2) + } + (NO::NotEqualFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::NotEqualFromEntries(s1, s2) + } + (NO::GtFromEntries, (Some(s1), Some(s2), None), 2) => Self::GtFromEntries(s1, s2), + (NO::LtFromEntries, (Some(s1), Some(s2), None), 2) => Self::LtFromEntries(s1, s2), + (NO::ContainsFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::ContainsFromEntries(s1, s2) + } + (NO::NotContainsFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::NotContainsFromEntries(s1, s2) + } + (NO::RenameContainedBy, (Some(s1), Some(s2), None), 2) => { + Self::RenameContainedBy(s1, s2) + } + (NO::SumOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::SumOf(s1, s2, s3), + (NO::ProductOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::ProductOf(s1, s2, s3), + (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::MaxOf(s1, s2, s3), + _ => Err(anyhow!( + "Ill-formed operation {:?} with arguments {:?}.", + op_code, + args + ))?, + }, + OperationType::Custom(cpr) => Self::Custom(cpr, args.to_vec()), }) } /// Checks the given operation against a statement.