Skip to content

Commit

Permalink
Add custom op enum variant and wildcard matching procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
ax0 committed Feb 25, 2025
1 parent f98f297 commit 01f93be
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 7 deletions.
68 changes: 66 additions & 2 deletions src/middleware/custom.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::sync::Arc;
use std::{fmt, hash as h};
use std::{fmt, hash as h, iter::zip};

use super::{hash_str, Hash, NativePredicate, ToFields, Value, F};
use anyhow::{anyhow, Result};

use super::{
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
F,
};

// BEGIN Custom 1b

Expand All @@ -11,6 +16,19 @@ pub enum HashOrWildcard {
Wildcard(usize),
}

impl HashOrWildcard {
/// Matches a hash or wildcard against a value, returning a pair
/// representing a wildcard binding (if any) or an error if no
/// match is possible.
pub fn match_against(&self, v: &Value) -> Result<Option<(usize, Value)>> {
match self {
HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None),
HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))),
_ => Err(anyhow!("Failed to match {} against {}.", self, v)),
}
}
}

impl fmt::Display for HashOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand All @@ -27,6 +45,25 @@ pub enum StatementTmplArg {
Key(HashOrWildcard, HashOrWildcard),
}

impl StatementTmplArg {
/// Matches a statement template argument against a statement
/// argument, returning a wildcard correspondence in the case of
/// one or more wildcard matches, nothing in the case of a
/// literal/hash match, and an error otherwise.
pub fn match_against(&self, s_arg: &StatementArg) -> Result<Vec<(usize, Value)>> {
match (self, s_arg) {
(Self::None, StatementArg::None) => Ok(vec![]),
(Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]),
(Self::Key(tmpl_o, tmpl_k), StatementArg::Key(AnchoredKey(PodId(o), k))) => {
let o_corr = tmpl_o.match_against(&o.clone().into())?;
let k_corr = tmpl_k.match_against(&k.clone().into())?;
Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect())
}
_ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)),
}
}
}

impl fmt::Display for StatementTmplArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand All @@ -53,6 +90,33 @@ impl fmt::Display for StatementTmplArg {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StatementTmpl(pub Predicate, pub Vec<StatementTmplArg>);

impl StatementTmpl {
pub fn pred(&self) -> &Predicate {
&self.0
}
pub fn args(&self) -> &[StatementTmplArg] {
&self.1
}
/// Matches a statement template against a statement, returning
/// the variable bindings as an association list. Returns an error
/// if there is type or argument mismatch.
pub fn match_against(&self, s: &Statement) -> Result<Vec<(usize, Value)>> {
type P = Predicate;
if matches!(self, Self(P::BatchSelf(_), _)) {
Err(anyhow!(
"Cannot check self-referencing statement templates."
))
} else if self.pred() != &s.code() {
Err(anyhow!("Type mismatch between {:?} and {}.", self, s))
} else {
zip(self.args(), s.args())
.map(|(t_arg, s_arg)| t_arg.match_against(&s_arg))
.collect::<Result<Vec<_>>>()
.map(|v| v.concat())
}
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicate {
/// true for "and", false for "or"
Expand Down
9 changes: 8 additions & 1 deletion src/middleware/operation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};

use super::Statement;
use super::{CustomPredicateRef, Statement};
use crate::middleware::{AnchoredKey, SELF};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -42,6 +42,7 @@ pub enum Operation {
SumOf(Statement, Statement, Statement),
ProductOf(Statement, Statement, Statement),
MaxOf(Statement, Statement, Statement),
Custom(CustomPredicateRef, Vec<Statement>),
}

impl Operation {
Expand All @@ -64,6 +65,7 @@ impl Operation {
Self::SumOf(_, _, _) => SumOf,
Self::ProductOf(_, _, _) => ProductOf,
Self::MaxOf(_, _, _) => MaxOf,
Self::Custom(_, _) => todo!(),
}
}

Expand All @@ -85,6 +87,7 @@ impl Operation {
Self::SumOf(s1, s2, s3) => vec![s1, s2, s3],
Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3],
Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3],
Self::Custom(_, args) => args,
}
}
/// Forms operation from op-code and arguments.
Expand Down Expand Up @@ -171,6 +174,10 @@ impl Operation {
let v3: i64 = v3.clone().try_into()?;
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
}
(
Self::Custom(CustomPredicateRef(cpb, i), _args),
Custom(CustomPredicateRef(s_cpb, s_i), _s_args),
) if cpb == s_cpb && i == s_i => todo!(),
_ => Err(anyhow!(
"Invalid deduction: {:?} ⇏ {:#}",
self,
Expand Down
7 changes: 3 additions & 4 deletions src/middleware/statement.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};
use plonky2::field::types::Field;
use std::fmt;
use std::{collections::HashMap, fmt};
use strum_macros::FromRepr;

use super::{AnchoredKey, CustomPredicateRef, Hash, Predicate, ToFields, Value, F};
Expand Down Expand Up @@ -30,7 +30,6 @@ impl ToFields for NativePredicate {
}
}

// TODO: Incorporate custom statements into this enum.
/// Type encapsulating statements with their associated arguments.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Statement {
Expand All @@ -45,7 +44,7 @@ pub enum Statement {
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
Custom(CustomPredicateRef, Vec<Hash>),
Custom(CustomPredicateRef, Vec<AnchoredKey>),
}

impl Statement {
Expand Down Expand Up @@ -83,7 +82,7 @@ impl Statement {
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Literal(h.into()))),
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Key(h))),
}
}
}
Expand Down

0 comments on commit 01f93be

Please sign in to comment.