From 362c84323b24b5ba167f40d1d107e00fc568b457 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 21 May 2024 16:51:16 +0100 Subject: [PATCH] Implement min/max for consts --- src/ast.rs | 2 +- src/check.rs | 59 ++++----------- src/compile.rs | 184 +++++++++++++++++++++++++++++++++++++---------- src/eval.rs | 5 +- src/lib.rs | 5 +- tests/compile.rs | 36 ++++++++++ 6 files changed, 204 insertions(+), 87 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 40b546d..f14631e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { /// The external constants that the top level const definitions depend upon. - pub const_deps: HashMap>, + pub const_deps: HashMap>, /// Top level const definitions. pub const_defs: HashMap, /// Top level struct type definitions. diff --git a/src/check.rs b/src/check.rs index c0a4ed1..52e6eaf 100644 --- a/src/check.rs +++ b/src/check.rs @@ -354,18 +354,18 @@ impl UntypedProgram { struct_names, enum_names, }; - let mut const_deps: HashMap> = HashMap::new(); + let mut const_deps: HashMap> = HashMap::new(); let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { for (const_name, const_def) in self.const_defs.iter() { fn check_const_expr( - const_name: &String, + value: &ConstExpr, const_def: &ConstDef, errors: &mut Vec>, - const_deps: &mut HashMap>, + const_deps: &mut HashMap>, ) { - match &const_def.value { + match value { ConstExpr::True | ConstExpr::False => { if const_def.ty != Type::Bool { let e = TypeErrorEnum::UnexpectedType { @@ -396,52 +396,21 @@ impl UntypedProgram { } } ConstExpr::ExternalValue { party, identifier } => { - const_deps.entry(party.clone()).or_default().insert( - identifier.clone(), - (const_name.clone(), const_def.ty.clone()), - ); + const_deps + .entry(party.clone()) + .or_default() + .insert(identifier.clone(), const_def.ty.clone()); + } + ConstExpr::Max(args) | ConstExpr::Min(args) => { + for arg in args { + check_const_expr(arg, const_def, errors, const_deps); + } } - ConstExpr::Max(args) => for arg in args {}, - ConstExpr::Min(_) => todo!(), } } - check_const_expr(&const_name, &const_def, &mut errors, &mut const_deps); + check_const_expr(&const_def.value, &const_def, &mut errors, &mut const_deps); const_defs.insert(const_name.clone(), const_def.clone()); const_types.insert(const_name.clone(), const_def.ty.clone()); - // TODO: remove the following: - /*match &const_def.value { - ConstExpr::Literal(expr) => { - match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) { - Ok(mut expr) => { - if let Err(errs) = check_type(&mut expr, &const_def.ty) { - errors.extend(errs); - } - const_defs.insert( - const_name.clone(), - ConstDef { - ty: const_def.ty.clone(), - value: ConstExpr::Literal(expr), - meta: const_def.meta, - }, - ); - } - Err(errs) => { - for e in errs.into_iter().flatten() { - if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e { - // ignore this error, constant can be provided later during compilation - const_deps - .entry(p) - .or_default() - .insert(n, (const_name.clone(), const_def.ty.clone())); - } else { - errors.push(Some(e)); - } - } - } - } - const_types.insert(const_name.clone(), const_def.ty.clone()); - } - }*/ } } let mut struct_defs = HashMap::with_capacity(self.struct_defs.len()); diff --git a/src/compile.rs b/src/compile.rs index 746c0ca..c76d2ce 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -1,6 +1,9 @@ //! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`]. -use std::{cmp::max, collections::HashMap}; +use std::{ + cmp::{max, min}, + collections::HashMap, +}; use crate::{ ast::{ @@ -56,14 +59,26 @@ impl TypedProgram { ) -> Result<(Circuit, &TypedFnDef), CompilerError> { let mut env = Env::new(); let mut const_sizes = HashMap::new(); + let mut consts_unsigned = HashMap::new(); + let mut consts_signed = HashMap::new(); for (party, deps) in self.const_deps.iter() { - for (c, (identifier, ty)) in deps { + for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; let Some(literal) = party_deps.get(c) else { todo!("missing value {party}::{c}"); }; + let identifier = format!("{party}::{c}"); + match literal { + Literal::NumUnsigned(n, _) => { + consts_unsigned.insert(identifier.clone(), *n); + } + Literal::NumSigned(n, _) => { + consts_signed.insert(identifier.clone(), *n); + } + _ => {} + } if literal.is_of_type(self, ty) { let bits = literal .as_bits(self, &const_sizes) @@ -72,7 +87,7 @@ impl TypedProgram { .collect(); env.let_in_current_scope(identifier.clone(), bits); if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + const_sizes.insert(identifier, *size as usize); } } else { return Err(CompilerError::InvalidLiteralType( @@ -84,58 +99,153 @@ impl TypedProgram { } let mut input_gates = vec![]; let mut wire = 2; - if let Some(fn_def) = self.fn_defs.get(fn_name) { - for param in fn_def.params.iter() { - let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); - let mut wires = Vec::with_capacity(type_size); - for _ in 0..type_size { - wires.push(wire); - wire += 1; + let Some(fn_def) = self.fn_defs.get(fn_name) else { + return Err(CompilerError::FnNotFound(fn_name.to_string())); + }; + for param in fn_def.params.iter() { + let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); + let mut wires = Vec::with_capacity(type_size); + for _ in 0..type_size { + wires.push(wire); + wire += 1; + } + input_gates.push(type_size); + env.let_in_current_scope(param.name.clone(), wires); + } + fn resolve_const_expr_unsigned( + expr: &ConstExpr, + consts_unsigned: &HashMap, + ) -> u64 { + match expr { + ConstExpr::NumUnsigned(n, _) => *n, + ConstExpr::ExternalValue { party, identifier } => *consts_unsigned + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExpr::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + ConstExpr::Min(args) => { + let mut result = u64::MAX; + for arg in args { + result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + fn resolve_const_expr_signed( + expr: &ConstExpr, + consts_signed: &HashMap, + ) -> i64 { + match expr { + ConstExpr::NumSigned(n, _) => *n, + ConstExpr::ExternalValue { party, identifier } => *consts_signed + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExpr::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + ConstExpr::Min(args) => { + let mut result = i64::MAX; + for arg in args { + result = min(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + for (const_name, const_def) in self.const_defs.iter() { + if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty { + if let ConstExpr::ExternalValue { party, identifier } = &const_def.value { + let identifier = format!("{party}::{identifier}"); + const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap()); } - input_gates.push(type_size); - env.let_in_current_scope(param.name.clone(), wires); + let n = resolve_const_expr_unsigned(&const_def.value, &consts_unsigned); + const_sizes.insert(const_name.clone(), n as usize); } - let mut circuit = CircuitBuilder::new(input_gates, const_sizes); - for (identifier, const_def) in self.const_defs.iter() { - match &const_def.value { - ConstExpr::True => env.let_in_current_scope(identifier.clone(), vec![1]), - ConstExpr::False => env.let_in_current_scope(identifier.clone(), vec![0]), - ConstExpr::NumUnsigned(n, ty) => { - let ty = Type::Unsigned(*ty); + } + let mut circuit = CircuitBuilder::new(input_gates, const_sizes); + for (const_name, const_def) in self.const_defs.iter() { + match &const_def.value { + ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]), + ConstExpr::False => env.let_in_current_scope(const_name.clone(), vec![0]), + ConstExpr::NumUnsigned(n, ty) => { + let ty = Type::Unsigned(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExpr::NumSigned(n, ty) => { + let ty = Type::Signed(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExpr::ExternalValue { party, identifier } => { + let bits = env.get(&format!("{party}::{identifier}")).unwrap(); + env.let_in_current_scope(const_name.clone(), bits); + } + expr @ (ConstExpr::Max(_) | ConstExpr::Min(_)) => { + if let Type::Unsigned(_) = const_def.ty { + let result = resolve_const_expr_unsigned(expr, &consts_unsigned); let mut bits = Vec::with_capacity( - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), ); unsigned_to_bits( - *n, - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), &mut bits, ); let bits = bits.into_iter().map(|b| b as usize).collect(); - env.let_in_current_scope(identifier.clone(), bits); - } - ConstExpr::NumSigned(n, ty) => { - let ty = Type::Signed(*ty); + env.let_in_current_scope(const_name.clone(), bits); + } else { + let result = resolve_const_expr_signed(expr, &consts_signed); let mut bits = Vec::with_capacity( - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), ); signed_to_bits( - *n, - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), &mut bits, ); let bits = bits.into_iter().map(|b| b as usize).collect(); - env.let_in_current_scope(identifier.clone(), bits); + env.let_in_current_scope(const_name.clone(), bits); } - ConstExpr::ExternalValue { .. } => {} - ConstExpr::Max(_) => todo!("compile max"), - ConstExpr::Min(_) => todo!("compile min"), } } - let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); - Ok((circuit.build(output_gates), fn_def)) - } else { - Err(CompilerError::FnNotFound(fn_name.to_string())) } + let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); + Ok((circuit.build(output_gates), fn_def)) } } diff --git a/src/eval.rs b/src/eval.rs index aa3cad3..9a1b43c 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -44,7 +44,7 @@ impl<'a> Evaluator<'a> { ) -> Self { let mut const_sizes = HashMap::new(); for (party, deps) in program.const_deps.iter() { - for (c, (identifier, _)) in deps { + for (c, _) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; @@ -52,7 +52,8 @@ impl<'a> Evaluator<'a> { todo!("missing value {party}::{c}"); }; if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + let identifier = format!("{party}::{c}"); + const_sizes.insert(identifier, *size as usize); } } } diff --git a/src/lib.rs b/src/lib.rs index f740ad8..15b684f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,7 +122,7 @@ pub fn compile_with_constants( let main = main.clone(); let mut const_sizes = HashMap::new(); for (party, deps) in program.const_deps.iter() { - for (c, (identifier, _)) in deps { + for (c, _) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; @@ -130,7 +130,8 @@ pub fn compile_with_constants( todo!("missing value {party}::{c}"); }; if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + let identifier = format!("{party}::{c}"); + const_sizes.insert(identifier, *size as usize); } } } diff --git a/tests/compile.rs b/tests/compile.rs index 91069fa..57f8013 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1947,3 +1947,39 @@ pub fn main(x: u16) -> u16 { ); Ok(()) } + +#[test] +fn compile_const_aggregated_min() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = min(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(3, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +}