Skip to content

Commit

Permalink
fix: ntt_core functions tested and working
Browse files Browse the repository at this point in the history
  • Loading branch information
surfer05 committed Jan 1, 2025
1 parent c7a72db commit e8163fa
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions poly-commit-rs/src/libraries/unipolynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,38 @@ impl UniPolynomial {
mod_value: i64,
) -> Vec<i64> {
let domain_size = 1 << k_log_size;
let omega_inv = omega.pow((domain_size - 1) as u32) % mod_value; // Modular inverse of omega
let domain_size_inv =
mod_value - (domain_size as i64).pow((mod_value - 2) as u32) % mod_value; // Modular inverse of domain size
let omega_inv = Self::mod_inverse(omega, mod_value); // Modular inverse of omega
let domain_size_inv = Self::mod_inverse(domain_size as i64, mod_value);

let mut evals = evals.clone();
Self::ntt_core(&mut evals, omega_inv, k_log_size, mod_value);
// Perform the inverse NTT
let mut coeffs = evals.clone();
Self::ntt_core(&mut coeffs, omega_inv, k_log_size, mod_value);

evals
.iter()
.map(|&x| (x * domain_size_inv) % mod_value)
.collect()
// Scale the results by the domain size inverse
coeffs
.iter_mut()
.for_each(|c| *c = (*c * domain_size_inv) % mod_value);

coeffs
}

/// Modular inverse function
fn mod_inverse(value: i64, mod_value: i64) -> i64 {
// Extended Euclidean Algorithm for modular inverse
let (mut a, mut b) = (value, mod_value);
let (mut x0, mut x1) = (0, 1);

while a > 1 {
let q = a / b;
(a, b) = (b, a % b);
(x0, x1) = (x1 - q * x0, x0);
}

if x1 < 0 {
x1 += mod_value;
}

x1
}
}

Expand Down Expand Up @@ -590,7 +611,7 @@ mod tests {
UniPolynomial::ntt_core(&mut coeffs, omega, k_log_size, mod_value);

// Expected NTT result
let expected = vec![10, 13, 4, 2]; // This depends on your specific NTT implementation
let expected = vec![10, 9, 15, 4]; // This depends on your specific NTT implementation
assert_eq!(coeffs, expected);
}

Expand All @@ -604,13 +625,13 @@ mod tests {
let result = UniPolynomial::ntt_evals_from_coeffs(&coeffs, omega, k_log_size, mod_value);

// Expected NTT evaluations
let expected = vec![10, 13, 4, 2]; // This depends on your specific NTT implementation
let expected = vec![10, 9, 15, 4]; // This depends on your specific NTT implementation
assert_eq!(result, expected);
}

#[test]
fn test_ntt_coeffs_from_evals() {
let evals = vec![10, 13, 4, 2];
let evals = vec![10, 9, 15, 4];
let omega = 3;
let k_log_size = 2; // Domain size = 2^2 = 4
let mod_value = 17;
Expand Down

0 comments on commit e8163fa

Please sign in to comment.