Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ColoCarletti committed Oct 15, 2024
1 parent 5f97990 commit dc4124e
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 54 deletions.
222 changes: 189 additions & 33 deletions math/src/circle/cfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,80 @@ extern crate alloc;
use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field};

#[cfg(feature = "alloc")]
pub fn inplace_cfft(
pub fn cfft(
input: &mut [FieldElement<Mersenne31Field>],
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>,
) {
let mut group_count = 1;
let mut group_size = input.len();
let mut round = 0;
let log_2_size = input.len().trailing_zeros();

(0..log_2_size).for_each(|i| {
let chunk_size = 1 << i + 1;
let half_chunk_size = 1 << i;
input.chunks_mut(chunk_size).for_each(|chunk| {
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size);
hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| {
let temp = *low * twiddles[i as usize][j];
*low = *hi - temp;
*hi = *hi + temp;
});
});
});
}

while group_count < input.len() {
let round_twiddles = &twiddles[round];
#[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm
for group in 0..group_count {
let first_in_group = group * group_size;
let first_in_next_group = first_in_group + group_size / 2;

let w = &round_twiddles[group]; // a twiddle factor is used per group
#[cfg(feature = "alloc")]
pub fn icfft(
input: &mut [FieldElement<Mersenne31Field>],
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>,
) {
let log_2_size = input.len().trailing_zeros();

println!("{:?}", twiddles);

(0..log_2_size).for_each(|i| {
let chunk_size = 1 << log_2_size - i;
let half_chunk_size = chunk_size >> 1;
input.chunks_mut(chunk_size).for_each(|chunk| {
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size);
hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| {
let temp = *hi + *low;
*low = (*hi - *low) * twiddles[i as usize][j];
*hi = temp;
});
});
});
}

for i in first_in_group..first_in_next_group {
let wi = w * input[i + group_size / 2];
pub fn order_cfft_result_naive(input: &mut [FieldElement<Mersenne31Field>]) -> Vec<FieldElement<Mersenne31Field>> {
let mut result = Vec::new();
let length = input.len();
for i in (0..length/2) {
result.push(input[i]);
result.push(input[length - i - 1]);
}
result
}

let y0 = input[i] + wi;
let y1 = input[i] - wi;
pub fn order_icfft_input_naive(input: &mut [FieldElement<Mersenne31Field>]) -> Vec<FieldElement<Mersenne31Field>> {
let mut result = Vec::new();
(0..input.len()).step_by(2).for_each( |i| {
result.push(input[i]);
});
(1..input.len()).step_by(2).rev().for_each( |i| {
result.push(input[i]);
});
result
}

input[i] = y0;
input[i + group_size / 2] = y1;
}
}
group_count *= 2;
group_size /= 2;
round += 1;
pub fn reverse_cfft_index(index: usize, length: usize) -> usize {
if index < (length >> 1) { // index < length / 2
index << 1 // index * 2
} else {
(((length - 1) - index) << 1) + 1
}
}


pub fn cfft_4(
input: &mut [FieldElement<Mersenne31Field>],
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>,
Expand Down Expand Up @@ -104,19 +145,134 @@ pub fn cfft_8(
stage3.into_iter().map(|elem| elem * f).collect()
}

pub fn inplace_order_cfft_values(input: &mut [FieldElement<Mersenne31Field>]) {
for i in 0..input.len() {
let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros());
if cfft_index > i {
input.swap(i, cfft_index);

#[cfg(test)]
mod tests {
use super::*;
type FE = FieldElement<Mersenne31Field>;

#[test]
fn ordering_4() {
let expected_slice = [
FE::from(0),
FE::from(1),
FE::from(2),
FE::from(3),
];

let mut slice = [
FE::from(0),
FE::from(2),
FE::from(3),
FE::from(1),
];

let res = order_cfft_result_naive(&mut slice);

assert_eq!(res, expected_slice)
}

#[test]
fn ordering() {
let expected_slice = [
FE::from(0),
FE::from(1),
FE::from(2),
FE::from(3),
FE::from(4),
FE::from(5),
FE::from(6),
FE::from(7),
FE::from(8),
FE::from(9),
FE::from(10),
FE::from(11),
FE::from(12),
FE::from(13),
FE::from(14),
FE::from(15),
];

let mut slice = [
FE::from(0),
FE::from(2),
FE::from(4),
FE::from(6),
FE::from(8),
FE::from(10),
FE::from(12),
FE::from(14),
FE::from(15),
FE::from(13),
FE::from(11),
FE::from(9),
FE::from(7),
FE::from(5),
FE::from(3),
FE::from(1),
];

let res = order_cfft_result_naive(&mut slice);

assert_eq!(res, expected_slice)
}

#[test]
fn reverse_cfft_index_works() {
let mut reversed: Vec<usize> = Vec::with_capacity(16);
for i in 0..reversed.capacity() {
reversed.push(reverse_cfft_index(i, reversed.capacity()));
}
assert_eq!(
reversed[..],
[0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1]
);
}
}

pub fn reverse_cfft_index(index: usize, log_2_size: u32) -> usize {
let (mut new_index, lsb) = (index >> 1, index & 1);
if (lsb == 1) & (log_2_size > 1) {
new_index = (1 << log_2_size) - new_index - 1;
#[test]
fn from_natural_to_icfft_input_order() {
let mut slice = [
FE::from(0),
FE::from(1),
FE::from(2),
FE::from(3),
FE::from(4),
FE::from(5),
FE::from(6),
FE::from(7),
FE::from(8),
FE::from(9),
FE::from(10),
FE::from(11),
FE::from(12),
FE::from(13),
FE::from(14),
FE::from(15),
];

let expected_slice = [
FE::from(0),
FE::from(2),
FE::from(4),
FE::from(6),
FE::from(8),
FE::from(10),
FE::from(12),
FE::from(14),
FE::from(15),
FE::from(13),
FE::from(11),
FE::from(9),
FE::from(7),
FE::from(5),
FE::from(3),
FE::from(1),
];

let res = order_icfft_input_naive(&mut slice);

assert_eq!(res, expected_slice)
}
new_index.reverse_bits() >> (usize::BITS - log_2_size)


}
45 changes: 31 additions & 14 deletions math/src/circle/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field};
use crate::{
field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field},
fft::cpu::bit_reversing::in_place_bit_reverse_permute
};

use super::{
cfft::{cfft_4, cfft_8, inplace_cfft, inplace_order_cfft_values},
cfft::{cfft, icfft, cfft_4, cfft_8, order_cfft_result_naive, order_icfft_input_naive},
cosets::Coset,
twiddles::{
get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig,
Expand All @@ -14,14 +17,15 @@ use super::{
pub fn evaluate_cfft(
mut coeff: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
in_place_bit_reverse_permute::<FieldElement<Mersenne31Field>>(&mut coeff);
let domain_log_2_size: u32 = coeff.len().trailing_zeros();
let coset = Coset::new_standard(domain_log_2_size);
let config = TwiddlesConfig::Evaluation;
let twiddles = get_twiddles(coset, config);

inplace_cfft(&mut coeff, twiddles);
inplace_order_cfft_values(&mut coeff);
coeff
cfft(&mut coeff, twiddles);
let result = order_cfft_result_naive(&mut coeff);
result
}

/// Interpolates the 2^n evaluations of a two-variables polynomial on the points of the standard coset of size 2^n.
Expand All @@ -30,14 +34,15 @@ pub fn evaluate_cfft(
pub fn interpolate_cfft(
mut eval: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
let mut eval_ordered = order_icfft_input_naive(&mut eval);
let domain_log_2_size: u32 = eval.len().trailing_zeros();
let coset = Coset::new_standard(domain_log_2_size);
let config = TwiddlesConfig::Interpolation;
let twiddles = get_twiddles(coset, config);

inplace_cfft(&mut eval, twiddles);
inplace_order_cfft_values(&mut eval);
eval
icfft(&mut eval_ordered, twiddles);
let result = order_cfft_result_naive(&mut eval);
result
}

pub fn interpolate_4(
Expand Down Expand Up @@ -116,7 +121,7 @@ mod tests {
// We create the coset points and evaluate them without the fft.
let coset = Coset::new_standard(2);
let points = Coset::get_coset_points(&coset);
let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)];
let input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)];
let mut expected_result: Vec<FpE> = Vec::new();
for point in points {
let point_eval = evaluate_poly_4(&input, point.x, point.y);
Expand All @@ -133,7 +138,7 @@ mod tests {
// We create the coset points and evaluate them without the fft.
let coset = Coset::new_standard(3);
let points = Coset::get_coset_points(&coset);
let mut input = [
let input = [
FpE::from(1),
FpE::from(2),
FpE::from(3),
Expand All @@ -158,7 +163,7 @@ mod tests {
fn cfft_evaluation_16_points() {
let coset = Coset::new_standard(4);
let points = Coset::get_coset_points(&coset);
let mut input = [
let input = [
FpE::from(1),
FpE::from(2),
FpE::from(3),
Expand Down Expand Up @@ -231,8 +236,20 @@ mod tests {
}

#[test]
fn cuentas() {
println!("{:?}", FpE::from(32768).inv().unwrap()); // { value: 65536 }
println!("{:?}", FpE::from(2147450879).inv().unwrap()); // { value: 2147418111 }
fn evaluate_and_interpolate() {
let coeff = vec![
FpE::from(1),
FpE::from(2),
FpE::from(3),
FpE::from(4),
FpE::from(5),
FpE::from(6),
FpE::from(7),
FpE::from(8),
];
let evals = evaluate_cfft(coeff.clone());
let new_coeff = interpolate_cfft(evals);

assert_eq!(coeff, new_coeff);
}
}
14 changes: 7 additions & 7 deletions math/src/circle/twiddles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,32 @@ pub fn get_twiddles(
domain: Coset,
config: TwiddlesConfig,
) -> Vec<Vec<FieldElement<Mersenne31Field>>> {
let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone()));
if config == TwiddlesConfig::Evaluation {
in_place_bit_reverse_permute::<CirclePoint<Mersenne31Field>>(&mut half_domain_points[..]);
}
let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone()));

let mut twiddles: Vec<Vec<FieldElement<Mersenne31Field>>> =
vec![half_domain_points.iter().map(|p| p.y).collect()];

if domain.log_2_size >= 2 {
twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect());
twiddles.push(half_domain_points.iter().take(half_domain_points.len() / 2 ).map(|p| p.x).collect());
for _ in 0..(domain.log_2_size - 2) {
let prev = twiddles.last().unwrap();
let cur = prev
.iter()
.step_by(2)
.take(prev.len() / 2 )
.map(|x| x.square().double() - FieldElement::<Mersenne31Field>::one())
.collect();
twiddles.push(cur);
}
}
twiddles.reverse();

if config == TwiddlesConfig::Interpolation {
// For the interpolation, we need to take the inverse element of each twiddle in the default order.
twiddles.iter_mut().for_each(|x| {
FieldElement::<Mersenne31Field>::inplace_batch_inverse(x).unwrap();
});
} else {
// For the evaluation, we need the vector of twiddles but in the inverse order.
twiddles.reverse();
}
twiddles
}
Expand Down

0 comments on commit dc4124e

Please sign in to comment.