Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ChickenLover committed Oct 1, 2024
1 parent ff8d4eb commit 8e166df
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 49 deletions.
2 changes: 1 addition & 1 deletion icicle/include/curves/macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
typedef Affine<point_field_t> affine_t;

#define G2_CURVE_DEFINITIONS \
typedef ExtensionField<fq_config, point_field_t> g2_point_field_t; \
typedef ComplexExtensionField<fq_config, point_field_t> g2_point_field_t; \
static constexpr g2_point_field_t g2_generator_x = \
g2_point_field_t{point_field_t{g2_gen_x_re}, point_field_t{g2_gen_x_im}}; \
static constexpr g2_point_field_t g2_generator_y = \
Expand Down
86 changes: 43 additions & 43 deletions icicle/include/fields/quartic_extension.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "gpu-utils/sharedmem.cuh"

template <typename CONFIG, class T>
class QuartExtensionField
class QuarticExtensionField
{
private:
typedef typename T::Wide FWide;
Expand Down Expand Up @@ -36,88 +36,88 @@ public:
FF im2;
FF im3;

static constexpr HOST_DEVICE_INLINE QuartExtensionField zero()
static constexpr HOST_DEVICE_INLINE QuarticExtensionField zero()
{
return QuartExtensionField{FF::zero(), FF::zero(), FF::zero(), FF::zero()};
return QuarticExtensionField{FF::zero(), FF::zero(), FF::zero(), FF::zero()};
}

static constexpr HOST_DEVICE_INLINE QuartExtensionField one()
static constexpr HOST_DEVICE_INLINE QuarticExtensionField one()
{
return QuartExtensionField{FF::one(), FF::zero(), FF::zero(), FF::zero()};
return QuarticExtensionField{FF::one(), FF::zero(), FF::zero(), FF::zero()};
}

static constexpr HOST_DEVICE_INLINE QuartExtensionField to_montgomery(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField to_montgomery(const QuarticExtensionField& xs)
{
return QuartExtensionField{
return QuarticExtensionField{
FF::to_montgomery(xs.real), FF::to_montgomery(xs.im1), FF::to_montgomery(xs.im2), FF::to_montgomery(xs.im3)};
}

static constexpr HOST_DEVICE_INLINE QuartExtensionField from_montgomery(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField from_montgomery(const QuarticExtensionField& xs)
{
return QuartExtensionField{
return QuarticExtensionField{
FF::from_montgomery(xs.real), FF::from_montgomery(xs.im1), FF::from_montgomery(xs.im2),
FF::from_montgomery(xs.im3)};
}

static HOST_INLINE QuartExtensionField rand_host()
static HOST_INLINE QuarticExtensionField rand_host()
{
return QuartExtensionField{FF::rand_host(), FF::rand_host(), FF::rand_host(), FF::rand_host()};
return QuarticExtensionField{FF::rand_host(), FF::rand_host(), FF::rand_host(), FF::rand_host()};
}

static void rand_host_many(QuartExtensionField* out, int size)
static void rand_host_many(QuarticExtensionField* out, int size)
{
for (int i = 0; i < size; i++)
out[i] = rand_host();
}

template <unsigned REDUCTION_SIZE = 1>
static constexpr HOST_DEVICE_INLINE QuartExtensionField sub_modulus(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField sub_modulus(const QuarticExtensionField& xs)
{
return QuartExtensionField{
return QuarticExtensionField{
FF::sub_modulus<REDUCTION_SIZE>(&xs.real), FF::sub_modulus<REDUCTION_SIZE>(&xs.im1),
FF::sub_modulus<REDUCTION_SIZE>(&xs.im2), FF::sub_modulus<REDUCTION_SIZE>(&xs.im3)};
}

friend std::ostream& operator<<(std::ostream& os, const QuartExtensionField& xs)
friend std::ostream& operator<<(std::ostream& os, const QuarticExtensionField& xs)
{
os << "{ Real: " << xs.real << " }; { Im1: " << xs.im1 << " }; { Im2: " << xs.im2 << " }; { Im3: " << xs.im3
<< " };";
return os;
}

friend HOST_DEVICE_INLINE QuartExtensionField operator+(QuartExtensionField xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator+(QuarticExtensionField xs, const QuarticExtensionField& ys)
{
return QuartExtensionField{xs.real + ys.real, xs.im1 + ys.im1, xs.im2 + ys.im2, xs.im3 + ys.im3};
return QuarticExtensionField{xs.real + ys.real, xs.im1 + ys.im1, xs.im2 + ys.im2, xs.im3 + ys.im3};
}

friend HOST_DEVICE_INLINE QuartExtensionField operator-(QuartExtensionField xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator-(QuarticExtensionField xs, const QuarticExtensionField& ys)
{
return QuartExtensionField{xs.real - ys.real, xs.im1 - ys.im1, xs.im2 - ys.im2, xs.im3 - ys.im3};
return QuarticExtensionField{xs.real - ys.real, xs.im1 - ys.im1, xs.im2 - ys.im2, xs.im3 - ys.im3};
}

friend HOST_DEVICE_INLINE QuartExtensionField operator+(FF xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator+(FF xs, const QuarticExtensionField& ys)
{
return QuartExtensionField{xs + ys.real, ys.im1, ys.im2, ys.im3};
return QuarticExtensionField{xs + ys.real, ys.im1, ys.im2, ys.im3};
}

friend HOST_DEVICE_INLINE QuartExtensionField operator-(FF xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator-(FF xs, const QuarticExtensionField& ys)
{
return QuartExtensionField{xs - ys.real, FF::neg(ys.im1), FF::neg(ys.im2), FF::neg(ys.im3)};
return QuarticExtensionField{xs - ys.real, FF::neg(ys.im1), FF::neg(ys.im2), FF::neg(ys.im3)};
}

friend HOST_DEVICE_INLINE QuartExtensionField operator+(QuartExtensionField xs, const FF& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator+(QuarticExtensionField xs, const FF& ys)
{
return QuartExtensionField{xs.real + ys, xs.im1, xs.im2, xs.im3};
return QuarticExtensionField{xs.real + ys, xs.im1, xs.im2, xs.im3};
}

friend HOST_DEVICE_INLINE QuartExtensionField operator-(QuartExtensionField xs, const FF& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator-(QuarticExtensionField xs, const FF& ys)
{
return QuartExtensionField{xs.real - ys, xs.im1, xs.im2, xs.im3};
return QuarticExtensionField{xs.real - ys, xs.im1, xs.im2, xs.im3};
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE ExtensionWide
mul_wide(const QuartExtensionField& xs, const QuartExtensionField& ys)
mul_wide(const QuarticExtensionField& xs, const QuarticExtensionField& ys)
{
if (CONFIG::nonresidue_is_negative)
return ExtensionWide{
Expand All @@ -144,74 +144,74 @@ public:
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const QuartExtensionField& xs, const FF& ys)
static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const QuarticExtensionField& xs, const FF& ys)
{
return ExtensionWide{
FF::mul_wide(xs.real, ys), FF::mul_wide(xs.im1, ys), FF::mul_wide(xs.im2, ys), FF::mul_wide(xs.im3, ys)};
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const FF& xs, const QuartExtensionField& ys)
static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const FF& xs, const QuarticExtensionField& ys)
{
return ExtensionWide{
FF::mul_wide(xs, ys.real), FF::mul_wide(xs, ys.im1), FF::mul_wide(xs, ys.im2), FF::mul_wide(xs, ys.im3)};
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE QuartExtensionField reduce(const ExtensionWide& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField reduce(const ExtensionWide& xs)
{
return QuartExtensionField{
return QuarticExtensionField{
FF::template reduce<MODULUS_MULTIPLE>(xs.real), FF::template reduce<MODULUS_MULTIPLE>(xs.im1),
FF::template reduce<MODULUS_MULTIPLE>(xs.im2), FF::template reduce<MODULUS_MULTIPLE>(xs.im3)};
}

template <class T1, class T2>
friend HOST_DEVICE_INLINE QuartExtensionField operator*(const T1& xs, const T2& ys)
friend HOST_DEVICE_INLINE QuarticExtensionField operator*(const T1& xs, const T2& ys)
{
ExtensionWide xy = mul_wide(xs, ys);
return reduce(xy);
}

friend HOST_DEVICE_INLINE bool operator==(const QuartExtensionField& xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE bool operator==(const QuarticExtensionField& xs, const QuarticExtensionField& ys)
{
return (xs.real == ys.real) && (xs.im1 == ys.im1) && (xs.im2 == ys.im2) && (xs.im3 == ys.im3);
}

friend HOST_DEVICE_INLINE bool operator!=(const QuartExtensionField& xs, const QuartExtensionField& ys)
friend HOST_DEVICE_INLINE bool operator!=(const QuarticExtensionField& xs, const QuarticExtensionField& ys)
{
return !(xs == ys);
}

template <uint32_t multiplier, unsigned REDUCTION_SIZE = 1>
static constexpr HOST_DEVICE_INLINE QuartExtensionField mul_unsigned(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField mul_unsigned(const QuarticExtensionField& xs)
{
return {
FF::template mul_unsigned<multiplier>(xs.real), FF::template mul_unsigned<multiplier>(xs.im1),
FF::template mul_unsigned<multiplier>(xs.im2), FF::template mul_unsigned<multiplier>(xs.im3)};
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE ExtensionWide sqr_wide(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE ExtensionWide sqr_wide(const QuarticExtensionField& xs)
{
// TODO: change to a more efficient squaring
return mul_wide<MODULUS_MULTIPLE>(xs, xs);
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE QuartExtensionField sqr(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField sqr(const QuarticExtensionField& xs)
{
// TODO: change to a more efficient squaring
return xs * xs;
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE QuartExtensionField neg(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField neg(const QuarticExtensionField& xs)
{
return {FF::neg(xs.real), FF::neg(xs.im1), FF::neg(xs.im2), FF::neg(xs.im3)};
}

// the inverse of zero is defined to be zero, which is what we want most of the time
static constexpr HOST_DEVICE_INLINE QuartExtensionField inverse(const QuartExtensionField& xs)
static constexpr HOST_DEVICE_INLINE QuarticExtensionField inverse(const QuarticExtensionField& xs)
{
FF x, x0, x2;
if (CONFIG::nonresidue_is_negative) {
Expand Down Expand Up @@ -251,10 +251,10 @@ public:
};

template <class CONFIG, class T>
struct SharedMemory<QuartExtensionField<CONFIG, T>> {
__device__ QuartExtensionField<CONFIG, T>* getPointer()
struct SharedMemory<QuarticExtensionField<CONFIG, T>> {
__device__ QuarticExtensionField<CONFIG, T>* getPointer()
{
extern __shared__ QuartExtensionField<CONFIG, T> s_ext4_scalar_[];
extern __shared__ QuarticExtensionField<CONFIG, T> s_ext4_scalar_[];
return s_ext4_scalar_;
}
};
2 changes: 1 addition & 1 deletion icicle/include/fields/stark_fields/babybear.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ namespace babybear {
/**
* Extension field of `scalar_t`, enabled if the `-DEXT_FIELD` env variable is set.
*/
typedef QuartExtensionField<fp_config, scalar_t> extension_t;
typedef QuarticExtensionField<fp_config, scalar_t> extension_t;
} // namespace babybear
2 changes: 1 addition & 1 deletion icicle/include/fields/stark_fields/m31.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ namespace m31 {
/**
* Extension field of `scalar_t`, enabled if the `-DEXT_FIELD` env variable is set.
*/
typedef QuartExtensionField<fp_config, scalar_t> q_extension_t;
typedef QuarticExtensionField<fp_config, scalar_t> q_extension_t;
} // namespace m31

template <typename CONFIG>
Expand Down
8 changes: 5 additions & 3 deletions icicle/src/ntt/kernel_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1187,10 +1187,13 @@ namespace mxntt {
return CHK_LAST();
}

#ifdef DCCT
uint32_t twiddles_offset = 0;
#endif

// general case:
uint32_t nof_blocks = (1UL << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
if (dit) {
uint32_t twiddles_offset = 0;
for (int i = 0; i < 5; i++) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
Expand Down Expand Up @@ -1232,16 +1235,15 @@ namespace mxntt {
}
} else { // dif
bool first_run = false, prev_stage = false;
uint32_t twiddles_offset = 0;
for (int i = 4; i >= 0; i--) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
for (int j = 0; j < i; j++)
stride_log += fast_tw ? STAGE_SIZES_HOST_FT[log_size][j] : STAGE_SIZES_HOST[log_size][j];
first_run = stage_size && !prev_stage;

uint32_t nof_ntt_blocks = (1 << log_size - stage_size) * (columns_batch ? 1 : batch_size);
#ifdef DCCT
uint32_t nof_ntt_blocks = (1 << log_size - stage_size) * (columns_batch ? 1 : batch_size);
if (stage_size == 6)
ntt64_dcct<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, basic_twiddles, log_size, tw_log_size, columns_batch ? batch_size : 0,
Expand Down

0 comments on commit 8e166df

Please sign in to comment.