Skip to content

Commit

Permalink
Hide low-level type details when using compile + GarbleProgram
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Jan 23, 2024
1 parent c227ec2 commit 85b783e
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 291 deletions.
5 changes: 5 additions & 0 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ pub enum EvalError {
UnexpectedNumberOfInputsFromParty(usize),
/// An input literal could not be parsed.
LiteralParseError(CompileTimeError),
/// The circuit does not have an input argument with the given index.
InvalidArgIndex(usize),
/// The literal is not of the expected parameter type.
InvalidLiteralType(Literal, Type),
/// The number of output bits does not match the expected type.
Expand All @@ -68,6 +70,9 @@ impl std::fmt::Display for EvalError {
EvalError::LiteralParseError(err) => {
err.fmt(f)
}
EvalError::InvalidArgIndex(i) => {
f.write_fmt(format_args!("The circuit does not an input argument with index {i}"))
}
EvalError::InvalidLiteralType(literal, ty) => {
f.write_fmt(format_args!("The argument literal is not of type {ty}: '{literal}'"))
}
Expand Down
131 changes: 124 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,53 @@
//! A purely functional programming language with a Rust-like syntax that compiles to logic gates
//! for secure multi-party computation.
//!
//! Garble programs always terminate and are compiled into a combination of boolean AND / XOR / NOT
//! gates. These boolean circuits can either be executed directly (mostly for testing purposes) or
//! passed to a multi-party computation engine.
//!
//! ```rust
//! use garble_lang::{compile, literal::Literal, token::UnsignedNumType::U32};
//!
//! // Compile and type-check a simple program to add the inputs of 3 parties:
//! let code = "pub fn main(x: u32, y: u32, z: u32) -> u32 { x + y + z }";
//! let prg = compile(code).map_err(|e| e.prettify(&code)).unwrap();
//!
//! // We can evaluate the circuit directly, useful for testing purposes:
//! let mut eval = prg.evaluator();
//! eval.set_u32(2);
//! eval.set_u32(10);
//! eval.set_u32(100);
//! let output = eval.run().map_err(|e| e.prettify(&code)).unwrap();
//! assert_eq!(u32::try_from(output).map_err(|e| e.prettify(&code)).unwrap(), 2 + 10 + 100);
//!
//! // Or we can run the compiled circuit in an MPC engine, simulated using `prg.circuit.eval()`:
//! let x = prg.parse_arg(0, "2u32").unwrap().as_bits();
//! let y = prg.parse_arg(1, "10u32").unwrap().as_bits();
//! let z = prg.parse_arg(2, "100u32").unwrap().as_bits();
//! let output = prg.circuit.eval(&[x, y, z]); // use your own MPC engine here instead
//! let result = prg.parse_output(&output).unwrap();
//! assert_eq!("112u32", result.to_string());
//!
//! // Input arguments can also be constructed directly as literals:
//! let x = prg.literal_arg(0, Literal::NumUnsigned(2, U32)).unwrap().as_bits();
//! let y = prg.literal_arg(1, Literal::NumUnsigned(10, U32)).unwrap().as_bits();
//! let z = prg.literal_arg(2, Literal::NumUnsigned(100, U32)).unwrap().as_bits();
//! let output = prg.circuit.eval(&[x, y, z]); // use your own MPC engine here instead
//! let result = prg.parse_output(&output).unwrap();
//! assert_eq!(Literal::NumUnsigned(112, U32), result);
//! ```
#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]

use check::TypeError;
use compile::CompilerError;
use eval::EvalError;
use eval::{EvalError, Evaluator};
use literal::Literal;
use parse::ParseError;
use scan::{scan, ScanError};
use std::fmt::Write as _;
use std::fmt::{Display, Write as _};
use token::MetaInfo;

use ast::{Expr, FnDef, Pattern, Program, Stmt, Type, VariantExpr};
Expand Down Expand Up @@ -56,12 +93,92 @@ pub fn check(prg: &str) -> Result<TypedProgram, Error> {
Ok(scan(prg)?.parse()?.type_check()?)
}

/// Scans, parses, type-checks and then compiles a program to a circuit of gates.
pub fn compile(prg: &str, fn_name: &str) -> Result<(TypedProgram, TypedFnDef, Circuit), Error> {
/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
pub fn compile(prg: &str) -> Result<GarbleProgram, Error> {
let program = check(prg)?;
let (circuit, main_fn) = program.compile(fn_name)?;
let main_fn = main_fn.clone();
Ok((program, main_fn, circuit))
let (circuit, main) = program.compile("main")?;
let main = main.clone();
Ok(GarbleProgram {
program,
main,
circuit,
})
}

/// The result of type-checking and compiling a Garble program.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GarbleProgram {
/// The type-checked represenation of the full program.
pub program: TypedProgram,
/// The function to be executed as a circuit.
pub main: TypedFnDef,
/// The compilation output, as a circuit of boolean gates.
pub circuit: Circuit,
}

/// An input argument for a Garble program and circuit.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GarbleArgument<'a>(Literal, &'a TypedProgram);

impl GarbleProgram {
/// Returns an evaluator that can be used to run the compiled circuit.
pub fn evaluator(&self) -> Evaluator<'_> {
Evaluator::new(&self.program, &self.main, &self.circuit)
}

/// Type-checks and uses the literal as the circuit input argument with the given index.
pub fn literal_arg(
&self,
arg_index: usize,
literal: Literal,
) -> Result<GarbleArgument<'_>, EvalError> {
let Some(param) = self.main.params.get(arg_index) else {
return Err(EvalError::InvalidArgIndex(arg_index));
};
if !literal.is_of_type(&self.program, &param.ty) {
return Err(EvalError::InvalidLiteralType(literal, param.ty.clone()));
}
Ok(GarbleArgument(literal, &self.program))
}

/// Tries to parse the string as the circuit input argument with the given index.
pub fn parse_arg(
&self,
arg_index: usize,
literal: &str,
) -> Result<GarbleArgument<'_>, EvalError> {
let Some(param) = self.main.params.get(arg_index) else {
return Err(EvalError::InvalidArgIndex(arg_index));
};
let literal = Literal::parse(&self.program, &param.ty, literal)
.map_err(EvalError::LiteralParseError)?;
Ok(GarbleArgument(literal, &self.program))
}

/// Tries to convert the circuit output back to a Garble literal.
pub fn parse_output(&self, bits: &[bool]) -> Result<Literal, EvalError> {
Literal::from_result_bits(&self.program, &self.main.ty, bits)
}
}

impl GarbleArgument<'_> {
/// Converts the argument to input bits for the compiled circuit.
pub fn as_bits(&self) -> Vec<bool> {
self.0.as_bits(self.1)
}

/// Converts the argument to a Garble literal.
pub fn as_literal(&self) -> Literal {
self.0.clone()
}
}

impl Display for GarbleArgument<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

/// Errors that can occur during compile time, while a program is scanned, parsed or type-checked.
Expand Down
54 changes: 36 additions & 18 deletions tests/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ pub fn main(_x: bool) -> bool {
true
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(unoptimized.gates.len(), optimized.gates.len());
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

Expand All @@ -31,10 +34,13 @@ pub fn main(_x: i32) -> i32 {
10i32
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(unoptimized.gates.len(), optimized.gates.len());
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

Expand All @@ -51,9 +57,12 @@ pub fn main(b: bool, x: i32) -> bool {
if b { y } else { y }
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
assert_eq!(unoptimized.gates.len(), optimized.gates.len());
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

Expand All @@ -69,9 +78,12 @@ pub fn main(b: bool) -> bool {
b
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
assert_eq!(unoptimized.gates.len(), optimized.gates.len());
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

Expand All @@ -98,9 +110,12 @@ pub fn main(arr1: [u8; 8], arr2: [u8; 8], choice: bool) -> [u8; 8] {
arr
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
assert_eq!(unoptimized.gates.len(), optimized.gates.len());
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

Expand All @@ -127,8 +142,11 @@ pub fn main(arr1: [u8; 8], arr2: [u8; 8], choice: bool) -> [u8; 8] {
arr
}
";
let (_, _, unoptimized) = compile(unoptimized, "main").map_err(|e| e.prettify(unoptimized))?;
let (_, _, optimized) = compile(optimized, "main").map_err(|e| e.prettify(optimized))?;
assert_eq!(unoptimized.gates.len(), optimized.gates.len());
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}
Loading

0 comments on commit 85b783e

Please sign in to comment.