Skip to content

Commit

Permalink
merge develop branch
Browse files Browse the repository at this point in the history
  • Loading branch information
unisolate committed Jun 30, 2021
1 parent 49a0339 commit 4232259
Show file tree
Hide file tree
Showing 3 changed files with 1,167 additions and 57 deletions.
200 changes: 143 additions & 57 deletions src/gsv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@
#include "support.h"

#define SM2 // enable optimization for SM2 (a=-3)
// #define BIT256 // use 256-bit integer instead of 512-bit

// #define BIT256

#ifdef BIT256 // use 256-bit integer instead of 512-bit
#define SM2_BITS 256
#else
#define SM2_BITS 512
#endif

#define TABLE_SIZE 512

// The CGBN context uses the following three parameters:
// TBP - threads per block (zero means to use the blockDim.x)
Expand All @@ -22,6 +31,9 @@
// TPI - threads per instance
// BITS - number of bits per instance

__constant__ cgbn_mem_t<SM2_BITS> d_mul_table[TABLE_SIZE];
__constant__ cgbn_mem_t<SM2_BITS> d_mul_table2[TABLE_SIZE];

template <uint32_t tpi, uint32_t bits>
class gsv_params_t {
public:
Expand All @@ -44,14 +56,14 @@ class gsv_t {
cgbn_mem_t<params::BITS> r; // sig->r
cgbn_mem_t<params::BITS> s; // sig->s
cgbn_mem_t<params::BITS> e; // digest
cgbn_mem_t<params::BITS> key_x; // public key
cgbn_mem_t<params::BITS> key_y; // public key
// cgbn_mem_t<params::BITS> key_x; // public key
// cgbn_mem_t<params::BITS> key_y; // public key
} instance_t;

typedef struct {
cgbn_mem_t<params::BITS> order; // group order
cgbn_mem_t<params::BITS> g_x; // base point (generator)
cgbn_mem_t<params::BITS> g_y; // base point (generator)
// cgbn_mem_t<params::BITS> g_x; // base point (generator)
// cgbn_mem_t<params::BITS> g_y; // base point (generator)
cgbn_mem_t<params::BITS> field; // prime p
cgbn_mem_t<params::BITS> g_a;
} ec_t;
Expand Down Expand Up @@ -502,6 +514,46 @@ class gsv_t {
const bn_t &p_z, const bn_t &d, const bn_t &field, const bn_t &g_a,
const uint32_t np0) {}

// Fixed-base scalar multiplication driven by the precomputed table d_mul_table
// (intended for the generator point).
// Expected complexity: n/2 * A (one point addition per set bit of the scalar)
// Table size: 32 KB (512 bn at 512 bits each)
//
// Table layout: bit i of the scalar selects d_mul_table[2*i] (x) and
// d_mul_table[2*i + 1] (y), which are fed to point_add_ipp with Z = `one`.
// NOTE(review): presumably entry i holds the affine coordinates of 2^i * G in
// the representation point_add_ipp expects -- confirm against the table generator.
//
// r_x, r_y, r_z : accumulator / result. On entry r_z must hold the value to be
//                 used as the Z coordinate of the table points (it is copied
//                 into `one`); r_z is then reset to 0 before accumulation
//                 starts (presumably the point-at-infinity encoding).
// k             : scalar; CONSUMED -- it is shifted right in place and is 0 on
//                 return whenever it fits in TABLE_SIZE / 2 bits.
// field, g_a, np0 : forwarded unchanged to point_add_ipp.
//
// Only TABLE_SIZE / 2 scalar bits are backed by the table. Higher bits are
// ignored instead of indexing past the end of d_mul_table, so callers must
// range-check the scalar beforehand.
__device__ __forceinline__ void fixed_point_mult(bn_t &r_x, bn_t &r_y, bn_t &r_z, bn_t &k, const bn_t &field, const bn_t &g_a, const uint32_t np0) {
  // Each scalar bit consumes two table entries (x and y).
  const int max_bits = TABLE_SIZE / 2;
  bn_t q_x, q_y, one;

  _env.set(one, r_z);
  _env.set_ui32(r_z, 0);

  for (int i = 0; i < max_bits && _env.compare_ui32(k, 0) > 0; i++) {
    if (_env.ctz(k) == 0) {  // lowest remaining bit is set: k_i = 1
      _env.load(q_x, &d_mul_table[i * 2]);
      _env.load(q_y, &d_mul_table[i * 2 + 1]);
      point_add_ipp(r_x, r_y, r_z, r_x, r_y, r_z, q_x, q_y, one, field, g_a, np0);
    }
    _env.shift_right(k, k, 1);
  }
}

// Fixed-base scalar multiplication driven by the second precomputed table,
// d_mul_table2. Same algorithm as fixed_point_mult, except that the scalar `d`
// is left untouched (a private copy is consumed instead).
//
// Table layout: bit i of the scalar selects d_mul_table2[2*i] (x) and
// d_mul_table2[2*i + 1] (y), which are fed to point_add_ipp with Z = `one`.
// NOTE(review): presumably entry i holds the affine coordinates of 2^i * P for
// the fixed public key P -- confirm against the table generator.
//
// r_x, r_y, r_z : accumulator / result. On entry r_z must hold the value to be
//                 used as the Z coordinate of the table points (it is copied
//                 into `one`); r_z is then reset to 0 before accumulation
//                 starts (presumably the point-at-infinity encoding).
// d             : scalar (not modified).
// field, g_a, np0 : forwarded unchanged to point_add_ipp.
//
// Only TABLE_SIZE / 2 scalar bits are backed by the table. Higher bits are
// ignored instead of indexing past the end of d_mul_table2, so callers must
// range-check the scalar beforehand.
__device__ __forceinline__ void fixed_point_mult2(bn_t &r_x, bn_t &r_y, bn_t &r_z, const bn_t &d, const bn_t &field, const bn_t &g_a, const uint32_t np0) {
  // Each scalar bit consumes two table entries (x and y).
  const int max_bits = TABLE_SIZE / 2;
  bn_t k, q_x, q_y, one;

  _env.set(one, r_z);
  _env.set(k, d);  // work on a copy so the caller's scalar survives
  _env.set_ui32(r_z, 0);

  for (int i = 0; i < max_bits && _env.compare_ui32(k, 0) > 0; i++) {
    if (_env.ctz(k) == 0) {  // lowest remaining bit is set: k_i = 1
      _env.load(q_x, &d_mul_table2[i * 2]);
      _env.load(q_y, &d_mul_table2[i * 2 + 1]);
      point_add_ipp(r_x, r_y, r_z, r_x, r_y, r_z, q_x, q_y, one, field, g_a, np0);
    }
    _env.shift_right(k, k, 1);
  }
}

// transform (X, Y, Z) into (x, y) := (X/Z^2, Y/Z^3)
__device__ __forceinline__ void conv_affine_x_y(bn_t &a_x, bn_t &a_y, const bn_t &j_x, const bn_t &j_y, const bn_t &j_z,
const bn_t &field, const uint32_t np0) {
Expand Down Expand Up @@ -546,7 +598,7 @@ class gsv_t {
_env.bn2mont(Z_1, Z_1, field);
_env.mont_sqr(Z_2, Z_1, field, np0);
_env.mont_mul(a_x, j_x, Z_2, field, np0);
_env.mont2bn(a_x, a_x, field, np0);
// _env.mont2bn(a_x, a_x, field, np0);
}

// _env.modular_inverse(Z_1, j_z, field);
Expand All @@ -572,22 +624,32 @@ class gsv_t {
mod(y1, field);
_env.bn2mont(y1, y1, field);

_env.set(x2, key_x);
_env.set(y2, key_y);
mod(x2, field);
_env.bn2mont(x2, x2, field);
mod(y2, field);
_env.bn2mont(y2, y2, field);
// __syncthreads();

// point_dbl_ipp(x1, y1, z1, x1, y1, z1, field, g_a, np0);

// __syncthreads();

// conv_affine_x(x1, x1, z1, field, np0);

_env.set(tmp, x1);

// _env.set(x2, key_x);
// _env.set(y2, key_y);
// mod(x2, field);
// _env.bn2mont(x2, x2, field);
// mod(y2, field);
// _env.bn2mont(y2, y2, field);

// point_add(x1, y1, z1, x1, y1, one, x2, y2, one, field, g_a, np0);
point_add(x1, y1, z1, x2, y2, one, x1, y1, one, field, g_a, np0);
// point_add(x1, y1, z1, x2, y2, one, x1, y1, one, field, g_a, np0);
// point_add_ipp(x1, y1, z1, x1, y1, one, x2, y2, one, field, g_a, np0);
// point_add_ipp(x1, y1, z1, x2, y2, one, x1, y1, one, field, g_a, np0);
// point_add(x1, y1, z1, x1, y1, one, one, one, zero, field, g_a, np0);
// point_add(x1, y1, z1, one, one, zero, x1, y1, one, field, g_a, np0);

// _env.set(tmp, z1);
conv_affine_x(tmp, x1, z1, field, np0);
// conv_affine_x(tmp, x1, z1, field, np0);

// point_add(x1, y1, z1, x1, y1, z1, r, s, one, field, g_a, np0);
// point_add_ipp(x1, y1, z1, x1, y1, z1, r, s, one, field, g_a, np0);
Expand All @@ -610,9 +672,7 @@ class gsv_t {
const bn_t &order, const bn_t &g_x, const bn_t &g_y, const bn_t &field,
bn_t &g_a, bn_t &tmp)
#else
__device__ __forceinline__ int32_t sig_verify(const bn_t &r, const bn_t &s, const bn_t &e, const bn_t &key_x,
const bn_t &key_y, const bn_t &order, const bn_t &g_x, const bn_t &g_y,
const bn_t &field, bn_t &g_a)
__device__ __forceinline__ int32_t sig_verify(const bn_t &r, bn_t &s, const bn_t &e, const bn_t &order, const bn_t &field, bn_t &g_a)
#endif
{
bn_t t, x1, y1, z1, x2, y2, z2;
Expand All @@ -633,35 +693,39 @@ class gsv_t {
np0 = _env.bn2mont(z1, z1, field);
_env.set(z2, z1);

mod(g_a, field);
// mod(g_a, field); // unnecessary
_env.bn2mont(g_a, g_a, field);

// s * generator + t * pkey
_env.set(x1, g_x);
_env.set(y1, g_y);
mod(x1, field);
_env.bn2mont(x1, x1, field);
mod(y1, field);
_env.bn2mont(y1, y1, field);
point_mult_naf(x1, y1, z1, x1, y1, z1, s, field, g_a, np0);
// _env.set(x1, g_x);
// _env.set(y1, g_y);
// mod(x1, field);
// _env.bn2mont(x1, x1, field);
// mod(y1, field);
// _env.bn2mont(y1, y1, field);

fixed_point_mult(x1, y1, z1, s, field, g_a, np0);

__syncthreads(); // TODO: temp fix of wrong answer, need to test on different input

_env.set(x2, key_x);
_env.set(y2, key_y);
mod(x2, field);
_env.bn2mont(x2, x2, field);
mod(y2, field);
_env.bn2mont(y2, y2, field);
point_mult_naf(x2, y2, z2, x2, y2, z2, t, field, g_a, np0);
// _env.set(x2, key_x);
// _env.set(y2, key_y);
// mod(x2, field);
// _env.bn2mont(x2, x2, field);
// mod(y2, field);
// _env.bn2mont(y2, y2, field);

point_add(x1, y1, z1, x1, y1, z1, x2, y2, z2, field, g_a, np0);
fixed_point_mult2(x2, y2, z2, t, field, g_a, np0);

conv_affine_x(x1, x1, z1, field, np0);
point_add_ipp(x1, y1, z1, x1, y1, z1, x2, y2, z2, field, g_a, np0);

mod_add(t, e, x1, order);
// avoid coordinate transformation by converting (r-e) to Jacobian

return _env.compare(r, t);
mod_sub(t, r, e, order);
_env.bn2mont(t, t, field);
_env.mont_sqr(z1, z1, field, np0);
_env.mont_mul(t, t, z1, field, np0);
return _env.compare(x1, t);
}

__host__ static instance_t *generate_instances(uint32_t count) {
Expand All @@ -673,10 +737,10 @@ class gsv_t {
params::BITS / 32);
set_words(instances[index].s._limbs, "E11F5909F947D5BE08C84A22CE9F7C338F7CF4A5B941B9268025495D7D433071",
params::BITS / 32);
set_words(instances[index].key_x._limbs, "D5548C7825CBB56150A3506CD57464AF8A1AE0519DFAF3C58221DC810CAF28DD",
params::BITS / 32);
set_words(instances[index].key_y._limbs, "921073768FE3D59CE54E79A49445CF73FED23086537027264D168946D479533E",
params::BITS / 32);
// set_words(instances[index].key_x._limbs, "D5548C7825CBB56150A3506CD57464AF8A1AE0519DFAF3C58221DC810CAF28DD",
// params::BITS / 32);
// set_words(instances[index].key_y._limbs, "921073768FE3D59CE54E79A49445CF73FED23086537027264D168946D479533E",
// params::BITS / 32);
set_words(instances[index].e._limbs, "10D51CB90C0C0522E94875A2BEA7AB72299EBE7192E64EFE0573B1C77110E5C9",
params::BITS / 32);
#else
Expand Down Expand Up @@ -708,6 +772,22 @@ class gsv_t {
return instances;
}

// Allocates the host-side table of TABLE_SIZE big numbers that is later copied
// into the __constant__ array d_mul_table, and returns it. The caller owns the
// buffer and releases it with free().
//
// The table contents come from the textually included "sm2_base_512.table".
// NOTE(review): that file presumably assigns into the local `mul_table`, so the
// variable must keep this exact name -- confirm against the table file.
static cgbn_mem_t<params::BITS> *prepare_table() {
  cgbn_mem_t<params::BITS> *mul_table = (cgbn_mem_t<params::BITS>*)malloc(sizeof(cgbn_mem_t<params::BITS>) * TABLE_SIZE);

  // The included table is expected to write through mul_table right away;
  // stop here rather than dereference a failed allocation.
  if (mul_table == nullptr) {
    abort();
  }

#include "sm2_base_512.table"

  return mul_table;
}

// Allocates the host-side table of TABLE_SIZE big numbers that is later copied
// into the __constant__ array d_mul_table2, and returns it. The caller owns the
// buffer and releases it with free().
//
// The table contents come from the textually included "sm2_pkey_512.table".
// NOTE(review): that file presumably assigns into the local `mul_table`, so the
// variable must keep this exact name -- confirm against the table file.
static cgbn_mem_t<params::BITS> *prepare_table2() {
  cgbn_mem_t<params::BITS> *mul_table = (cgbn_mem_t<params::BITS>*)malloc(sizeof(cgbn_mem_t<params::BITS>) * TABLE_SIZE);

  // The included table is expected to write through mul_table right away;
  // stop here rather than dereference a failed allocation.
  if (mul_table == nullptr) {
    abort();
  }

#include "sm2_pkey_512.table"

  return mul_table;
}

__host__ static void verify_results(instance_t *instances, uint32_t count, int32_t *results) {
for (int index = 0; index < count; index++) {
int openssl_result = -1;
Expand Down Expand Up @@ -736,7 +816,7 @@ __global__ void kernel_sig_verify(cgbn_error_report_t *report, typename gsv_t<pa
typedef gsv_t<params> local_gsv_t;

local_gsv_t gsv(cgbn_report_monitor, report, instance);
typename local_gsv_t::bn_t r, s, e, key_x, key_y, order, g_x, g_y, field, g_a;
typename local_gsv_t::bn_t r, s, e, order, field, g_a;

#ifdef DEBUG
typename local_gsv_t::bn_t tmp;
Expand All @@ -745,21 +825,21 @@ __global__ void kernel_sig_verify(cgbn_error_report_t *report, typename gsv_t<pa
cgbn_load(gsv._env, r, &(instances[instance].r));
cgbn_load(gsv._env, s, &(instances[instance].s));
cgbn_load(gsv._env, e, &(instances[instance].e));
cgbn_load(gsv._env, key_x, &(instances[instance].key_x));
cgbn_load(gsv._env, key_y, &(instances[instance].key_y));
// cgbn_load(gsv._env, key_x, &(instances[instance].key_x));
// cgbn_load(gsv._env, key_y, &(instances[instance].key_y));

cgbn_load(gsv._env, order, &(ec.order));
cgbn_load(gsv._env, g_x, &(ec.g_x));
cgbn_load(gsv._env, g_y, &(ec.g_y));
// cgbn_load(gsv._env, g_x, &(ec.g_x));
// cgbn_load(gsv._env, g_y, &(ec.g_y));
cgbn_load(gsv._env, field, &(ec.field));
cgbn_load(gsv._env, g_a, &(ec.g_a));

#ifdef DEBUG
results[instance] = gsv.sig_verify(r, s, e, key_x, key_y, order, g_x, g_y, field, g_a, tmp);
// results[instance] = gsv.debug_kernel(r, s, e, key_x, key_y, order, g_x, g_y, field, g_a, tmp);
// results[instance] = gsv.sig_verify(r, s, e, key_x, key_y, order, g_x, g_y, field, g_a, tmp);
results[instance] = gsv.debug_kernel(r, s, e, key_x, key_y, order, g_x, g_y, field, g_a, tmp);
cgbn_store(gsv._env, &(instances[instance].r), tmp);
#else
results[instance] = gsv.sig_verify(r, s, e, key_x, key_y, order, g_x, g_y, field, g_a);
results[instance] = gsv.sig_verify(r, s, e, order, field, g_a);
#endif
}

Expand All @@ -771,17 +851,21 @@ void test_sig_verify(uint32_t instance_count, typename gsv_t<params>::instance_t

instance_t *instances;
ec_t sm2;
cgbn_mem_t<params::BITS> *mul_table;
cgbn_mem_t<params::BITS> *mul_table2;
int32_t *results; // signature verification result, 0 is true, 1 is false
int32_t TPB = (params::TPB == 0) ? 128 : params::TPB; // default threads per block is 128
int32_t TPI = params::TPI, IPB = TPB / TPI; // IPB: instances per block

results = (int32_t *)malloc(sizeof(int32_t) * instance_count);
instances = gsv_t<params>::generate_instances(instance_count);
mul_table = gsv_t<params>::prepare_table();
mul_table2 = gsv_t<params>::prepare_table2();

#ifdef SM2
set_words(sm2.order._limbs, "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", params::BITS / 32);
set_words(sm2.g_x._limbs, "32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", params::BITS / 32);
set_words(sm2.g_y._limbs, "BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", params::BITS / 32);
// set_words(sm2.g_x._limbs, "32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", params::BITS / 32);
// set_words(sm2.g_y._limbs, "BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", params::BITS / 32);
set_words(sm2.field._limbs, "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", params::BITS / 32);
set_words(sm2.g_a._limbs, "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", params::BITS / 32);
#else
Expand All @@ -801,6 +885,9 @@ void test_sig_verify(uint32_t instance_count, typename gsv_t<params>::instance_t

CUDA_CHECK(cudaMemcpy(d_instances, instances, sizeof(instance_t) * instance_count, cudaMemcpyHostToDevice));

CUDA_CHECK(cudaMemcpyToSymbol(d_mul_table, mul_table, sizeof(cgbn_mem_t<params::BITS>) * TABLE_SIZE));
CUDA_CHECK(cudaMemcpyToSymbol(d_mul_table2, mul_table2, sizeof(cgbn_mem_t<params::BITS>) * TABLE_SIZE));

auto k_start = std::chrono::high_resolution_clock::now();

kernel_sig_verify<params><<<(instance_count + IPB - 1) / IPB, TPB>>>(report, d_instances, instance_count, sm2, d_results);
Expand Down Expand Up @@ -828,9 +915,11 @@ void test_sig_verify(uint32_t instance_count, typename gsv_t<params>::instance_t

free(instances);
free(results);
free(mul_table);
free(mul_table2);
}

#define MAX_INS 262144
#define MAX_INS 1048576

int main(int argc, char **argv) {
int device_id = 0;
Expand All @@ -839,11 +928,8 @@ int main(int argc, char **argv) {
}
CUDA_CHECK(cudaSetDevice(device_id));

#ifdef BIT256
typedef gsv_params_t<16, 256> params; // threads per instance, instance size
#else
typedef gsv_params_t<16, 512> params; // threads per instance, instance size
#endif
typedef gsv_params_t<4, SM2_BITS> params; // threads per instance, instance size

typedef typename gsv_t<params>::instance_t instance_t;

instance_t *d_instances;
Expand All @@ -854,7 +940,7 @@ int main(int argc, char **argv) {
CUDA_CHECK(cudaMalloc((void **)&d_results, sizeof(int32_t) * MAX_INS));
CUDA_CHECK(cgbn_error_report_alloc(&report));

test_sig_verify<params>(256, d_instances, d_results, report);
test_sig_verify<params>(4096, d_instances, d_results, report);

// test_sig_verify<params>(32768, d_instances, d_results, report);

Expand Down
Loading

0 comments on commit 4232259

Please sign in to comment.