Skip to content

Commit

Permalink
Implement min/max for consts
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed May 21, 2024
1 parent 27cf557 commit 362c843
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Program<T> {
/// The external constants that the top level const definitions depend upon.
pub const_deps: HashMap<String, HashMap<String, (String, T)>>,
pub const_deps: HashMap<String, HashMap<String, T>>,
/// Top level const definitions.
pub const_defs: HashMap<String, ConstDef>,
/// Top level struct type definitions.
Expand Down
59 changes: 14 additions & 45 deletions src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,18 +354,18 @@ impl UntypedProgram {
struct_names,
enum_names,
};
let mut const_deps: HashMap<String, HashMap<String, (String, Type)>> = HashMap::new();
let mut const_deps: HashMap<String, HashMap<String, Type>> = HashMap::new();
let mut const_types = HashMap::with_capacity(self.const_defs.len());
let mut const_defs = HashMap::with_capacity(self.const_defs.len());
{
for (const_name, const_def) in self.const_defs.iter() {
fn check_const_expr(
const_name: &String,
value: &ConstExpr,
const_def: &ConstDef,
errors: &mut Vec<Option<TypeError>>,
const_deps: &mut HashMap<String, HashMap<String, (String, Type)>>,
const_deps: &mut HashMap<String, HashMap<String, Type>>,
) {
match &const_def.value {
match value {
ConstExpr::True | ConstExpr::False => {
if const_def.ty != Type::Bool {
let e = TypeErrorEnum::UnexpectedType {
Expand Down Expand Up @@ -396,52 +396,21 @@ impl UntypedProgram {
}
}
ConstExpr::ExternalValue { party, identifier } => {
const_deps.entry(party.clone()).or_default().insert(
identifier.clone(),
(const_name.clone(), const_def.ty.clone()),
);
const_deps
.entry(party.clone())
.or_default()
.insert(identifier.clone(), const_def.ty.clone());
}
ConstExpr::Max(args) | ConstExpr::Min(args) => {
for arg in args {
check_const_expr(arg, const_def, errors, const_deps);
}
}
ConstExpr::Max(args) => for arg in args {},
ConstExpr::Min(_) => todo!(),
}
}
check_const_expr(&const_name, &const_def, &mut errors, &mut const_deps);
check_const_expr(&const_def.value, &const_def, &mut errors, &mut const_deps);
const_defs.insert(const_name.clone(), const_def.clone());
const_types.insert(const_name.clone(), const_def.ty.clone());
// TODO: remove the following:
/*match &const_def.value {
ConstExpr::Literal(expr) => {
match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) {
Ok(mut expr) => {
if let Err(errs) = check_type(&mut expr, &const_def.ty) {
errors.extend(errs);
}
const_defs.insert(
const_name.clone(),
ConstDef {
ty: const_def.ty.clone(),
value: ConstExpr::Literal(expr),
meta: const_def.meta,
},
);
}
Err(errs) => {
for e in errs.into_iter().flatten() {
if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e {
// ignore this error, constant can be provided later during compilation
const_deps
.entry(p)
.or_default()
.insert(n, (const_name.clone(), const_def.ty.clone()));
} else {
errors.push(Some(e));
}
}
}
}
const_types.insert(const_name.clone(), const_def.ty.clone());
}
}*/
}
}
let mut struct_defs = HashMap::with_capacity(self.struct_defs.len());
Expand Down
184 changes: 147 additions & 37 deletions src/compile.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`].
use std::{cmp::max, collections::HashMap};
use std::{
cmp::{max, min},
collections::HashMap,
};

use crate::{
ast::{
Expand Down Expand Up @@ -56,14 +59,26 @@ impl TypedProgram {
) -> Result<(Circuit, &TypedFnDef), CompilerError> {
let mut env = Env::new();
let mut const_sizes = HashMap::new();
let mut consts_unsigned = HashMap::new();
let mut consts_signed = HashMap::new();
for (party, deps) in self.const_deps.iter() {
for (c, (identifier, ty)) in deps {
for (c, ty) in deps {
let Some(party_deps) = consts.get(party) else {
todo!("missing party dep for {party}");
};
let Some(literal) = party_deps.get(c) else {
todo!("missing value {party}::{c}");
};
let identifier = format!("{party}::{c}");
match literal {
Literal::NumUnsigned(n, _) => {
consts_unsigned.insert(identifier.clone(), *n);
}
Literal::NumSigned(n, _) => {
consts_signed.insert(identifier.clone(), *n);
}
_ => {}
}
if literal.is_of_type(self, ty) {
let bits = literal
.as_bits(self, &const_sizes)
Expand All @@ -72,7 +87,7 @@ impl TypedProgram {
.collect();
env.let_in_current_scope(identifier.clone(), bits);
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
const_sizes.insert(identifier.clone(), *size as usize);
const_sizes.insert(identifier, *size as usize);
}
} else {
return Err(CompilerError::InvalidLiteralType(
Expand All @@ -84,58 +99,153 @@ impl TypedProgram {
}
let mut input_gates = vec![];
let mut wire = 2;
if let Some(fn_def) = self.fn_defs.get(fn_name) {
for param in fn_def.params.iter() {
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
let mut wires = Vec::with_capacity(type_size);
for _ in 0..type_size {
wires.push(wire);
wire += 1;
let Some(fn_def) = self.fn_defs.get(fn_name) else {
return Err(CompilerError::FnNotFound(fn_name.to_string()));
};
for param in fn_def.params.iter() {
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
let mut wires = Vec::with_capacity(type_size);
for _ in 0..type_size {
wires.push(wire);
wire += 1;
}
input_gates.push(type_size);
env.let_in_current_scope(param.name.clone(), wires);
}
fn resolve_const_expr_unsigned(
expr: &ConstExpr,
consts_unsigned: &HashMap<String, u64>,
) -> u64 {
match expr {
ConstExpr::NumUnsigned(n, _) => *n,
ConstExpr::ExternalValue { party, identifier } => *consts_unsigned
.get(&format!("{party}::{identifier}"))
.unwrap(),
ConstExpr::Max(args) => {
let mut result = 0;
for arg in args {
result = max(result, resolve_const_expr_unsigned(arg, consts_unsigned));
}
result
}
ConstExpr::Min(args) => {
let mut result = u64::MAX;
for arg in args {
result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned));
}
result
}
expr => panic!("Not an unsigned const expr: {expr:?}"),
}
}
fn resolve_const_expr_signed(
expr: &ConstExpr,
consts_signed: &HashMap<String, i64>,
) -> i64 {
match expr {
ConstExpr::NumSigned(n, _) => *n,
ConstExpr::ExternalValue { party, identifier } => *consts_signed
.get(&format!("{party}::{identifier}"))
.unwrap(),
ConstExpr::Max(args) => {
let mut result = 0;
for arg in args {
result = max(result, resolve_const_expr_signed(arg, consts_signed));
}
result
}
ConstExpr::Min(args) => {
let mut result = i64::MAX;
for arg in args {
result = min(result, resolve_const_expr_signed(arg, consts_signed));
}
result
}
expr => panic!("Not an unsigned const expr: {expr:?}"),
}
}
for (const_name, const_def) in self.const_defs.iter() {
if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty {
if let ConstExpr::ExternalValue { party, identifier } = &const_def.value {
let identifier = format!("{party}::{identifier}");
const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap());
}
input_gates.push(type_size);
env.let_in_current_scope(param.name.clone(), wires);
let n = resolve_const_expr_unsigned(&const_def.value, &consts_unsigned);
const_sizes.insert(const_name.clone(), n as usize);
}
let mut circuit = CircuitBuilder::new(input_gates, const_sizes);
for (identifier, const_def) in self.const_defs.iter() {
match &const_def.value {
ConstExpr::True => env.let_in_current_scope(identifier.clone(), vec![1]),
ConstExpr::False => env.let_in_current_scope(identifier.clone(), vec![0]),
ConstExpr::NumUnsigned(n, ty) => {
let ty = Type::Unsigned(*ty);
}
let mut circuit = CircuitBuilder::new(input_gates, const_sizes);
for (const_name, const_def) in self.const_defs.iter() {
match &const_def.value {
ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]),
ConstExpr::False => env.let_in_current_scope(const_name.clone(), vec![0]),
ConstExpr::NumUnsigned(n, ty) => {
let ty = Type::Unsigned(*ty);
let mut bits =
Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes()));
unsigned_to_bits(
*n,
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
&mut bits,
);
let bits = bits.into_iter().map(|b| b as usize).collect();
env.let_in_current_scope(const_name.clone(), bits);
}
ConstExpr::NumSigned(n, ty) => {
let ty = Type::Signed(*ty);
let mut bits =
Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes()));
signed_to_bits(
*n,
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
&mut bits,
);
let bits = bits.into_iter().map(|b| b as usize).collect();
env.let_in_current_scope(const_name.clone(), bits);
}
ConstExpr::ExternalValue { party, identifier } => {
let bits = env.get(&format!("{party}::{identifier}")).unwrap();
env.let_in_current_scope(const_name.clone(), bits);
}
expr @ (ConstExpr::Max(_) | ConstExpr::Min(_)) => {
if let Type::Unsigned(_) = const_def.ty {
let result = resolve_const_expr_unsigned(expr, &consts_unsigned);
let mut bits = Vec::with_capacity(
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
const_def
.ty
.size_in_bits_for_defs(self, circuit.const_sizes()),
);
unsigned_to_bits(
*n,
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
result,
const_def
.ty
.size_in_bits_for_defs(self, circuit.const_sizes()),
&mut bits,
);
let bits = bits.into_iter().map(|b| b as usize).collect();
env.let_in_current_scope(identifier.clone(), bits);
}
ConstExpr::NumSigned(n, ty) => {
let ty = Type::Signed(*ty);
env.let_in_current_scope(const_name.clone(), bits);
} else {
let result = resolve_const_expr_signed(expr, &consts_signed);
let mut bits = Vec::with_capacity(
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
const_def
.ty
.size_in_bits_for_defs(self, circuit.const_sizes()),
);
signed_to_bits(
*n,
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
result,
const_def
.ty
.size_in_bits_for_defs(self, circuit.const_sizes()),
&mut bits,
);
let bits = bits.into_iter().map(|b| b as usize).collect();
env.let_in_current_scope(identifier.clone(), bits);
env.let_in_current_scope(const_name.clone(), bits);
}
ConstExpr::ExternalValue { .. } => {}
ConstExpr::Max(_) => todo!("compile max"),
ConstExpr::Min(_) => todo!("compile min"),
}
}
let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit);
Ok((circuit.build(output_gates), fn_def))
} else {
Err(CompilerError::FnNotFound(fn_name.to_string()))
}
let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit);
Ok((circuit.build(output_gates), fn_def))
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ impl<'a> Evaluator<'a> {
) -> Self {
let mut const_sizes = HashMap::new();
for (party, deps) in program.const_deps.iter() {
for (c, (identifier, _)) in deps {
for (c, _) in deps {
let Some(party_deps) = consts.get(party) else {
todo!("missing party dep for {party}");
};
let Some(literal) = party_deps.get(c) else {
todo!("missing value {party}::{c}");
};
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
const_sizes.insert(identifier.clone(), *size as usize);
let identifier = format!("{party}::{c}");
const_sizes.insert(identifier, *size as usize);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,16 @@ pub fn compile_with_constants(
let main = main.clone();
let mut const_sizes = HashMap::new();
for (party, deps) in program.const_deps.iter() {
for (c, (identifier, _)) in deps {
for (c, _) in deps {
let Some(party_deps) = consts.get(party) else {
todo!("missing party dep for {party}");
};
let Some(literal) = party_deps.get(c) else {
todo!("missing value {party}::{c}");
};
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
const_sizes.insert(identifier.clone(), *size as usize);
let identifier = format!("{party}::{c}");
const_sizes.insert(identifier, *size as usize);
}
}
}
Expand Down
Loading

0 comments on commit 362c843

Please sign in to comment.