diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..9acb4be --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rd/hyrax-bls12-381"] + path = 3rd/hyrax-bls12-381 + url = git@github.com:TAMUCrypto/hyrax-bls12-381.git diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..73f69e0 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/.name b/.idea/.name new file mode 100644 index 0000000..3ba80a4 --- /dev/null +++ b/.idea/.name @@ -0,0 +1 @@ +zkCNN \ No newline at end of file diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..b9fc2ad --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,182 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..f1c67df --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..1d88d5d --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..dfb26aa --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/zkCNN-quant.iml b/.idea/zkCNN-quant.iml new file mode 100644 index 0000000..6d70257 --- /dev/null +++ b/.idea/zkCNN-quant.iml @@ -0,0 
+1,2 @@ + + \ No newline at end of file diff --git a/3rd/hyrax-bls12-381 b/3rd/hyrax-bls12-381 new file mode 160000 index 0000000..baedf71 --- /dev/null +++ b/3rd/hyrax-bls12-381 @@ -0,0 +1 @@ +Subproject commit baedf71d215549e8b7147809c7fb807d315a8843 diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..e49db18 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.10) +project(zkCNN) +set(CMAKE_CXX_STANDARD 14) + +link_directories(3rd/hyrax-bls12-381) + +include_directories(src) +include_directories(3rd) +include_directories(3rd/hyrax-bls12-381/3rd/mcl/include) + +add_subdirectory(src) +add_subdirectory(3rd/hyrax-bls12-381) \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..12dc852 --- /dev/null +++ b/README.md @@ -0,0 +1,83 @@ +# zkCNN + +## Introduction + +This is a GKR-based zero-knowledge proof system for CNN inference, covering widely used networks such as LeNet5, vgg11 and vgg16. + + + +## Requirements + +- C++14 +- cmake >= 3.10 +- GMP library + + + +## Input Format + +The input has two parts, the data and the weights, both given in matrix form. + +### Data Part + +There are two cases supported in this repo. + +- **Single picture** + + The picture is a vector reshaped from its original matrix of size + + ![formula1](https://render.githubusercontent.com/render/math?math=ch_{in}%20%5Ccdot%20h\times%20w) + + where ![formula2](https://render.githubusercontent.com/render/math?math=ch_{in}) is the number of channels, ![formula3](https://render.githubusercontent.com/render/math?math=h) is the height, ![formula4](https://render.githubusercontent.com/render/math?math=w) is the width. + + + +- **Multiple pictures** + + This handles the case where the user wants to run inference on multiple pictures with the same network. 
The pictures form a vector reshaped from the original matrix of size + + ![formula5](https://render.githubusercontent.com/render/math?math=n_{pic}%20\times%20ch_{in}%20\times%20h%20\times%20w) + + where ![formula6](https://render.githubusercontent.com/render/math?math=n_{pic}) is the number of pictures, ![formula7](https://render.githubusercontent.com/render/math?math=ch_{in}) is the number of channels, ![formula8](https://render.githubusercontent.com/render/math?math=h) is the height, ![formula9](https://render.githubusercontent.com/render/math?math=w) is the width. + +### Weight Part + +This part holds the weights of the neural network, which consist of + +- convolution kernel of size ![formula10](https://render.githubusercontent.com/render/math?math=ch_{out}%20\times%20ch_{in}%20\times%20m%20\times%20m) + + where ![formula11](https://render.githubusercontent.com/render/math?math=ch_{out}) and ![formula12](https://render.githubusercontent.com/render/math?math=ch_{in}) are the numbers of output and input channels, ![formula13](https://render.githubusercontent.com/render/math?math=m) is the side length of the kernel (only square kernels are supported). + +- convolution bias of size ![formula16](https://render.githubusercontent.com/render/math?math=ch_{out}) + +- fully-connected kernel of size ![formula14](https://render.githubusercontent.com/render/math?math=ch_{in}\times%20ch_{out}) + + +- fully-connected bias of size ![formula15](https://render.githubusercontent.com/render/math?math=ch_{out}) + + +All the inputs above are read sequentially. + +## Experiment Script +### Clone the repo +To run the code, make sure you clone the repo and fetch its submodules with +``` bash +git clone git@github.com:TAMUCrypto/zkCNN.git +git submodule update --init --recursive +``` +since the polynomial commitment is included as a submodule. + +### Run a demo of vgg11 +The script below runs the vgg11 model (please run it from the ``script/`` directory). +``` bash +./demo.sh +``` + +- The input data is in ``data/vgg11/``. 
+- The experiment evaluation is written to ``output/single/demo-result.txt``. +- The inference result is written to ``output/single/vgg11.cifar.relu-1-infer.csv``. + +## Polynomial Commitment + +Here we implement the Hyrax polynomial commitment over the BLS12-381 elliptic curve. It is included as a submodule; interested readers can refer to the [hyrax-bls12-381](https://github.com/TAMUCrypto/hyrax-bls12-381) repo. + diff --git a/data.tar.gz b/data.tar.gz new file mode 100644 index 0000000..a62959b Binary files /dev/null and b/data.tar.gz differ diff --git a/script/build.sh b/script/build.sh new file mode 100644 index 0000000..3d02834 --- /dev/null +++ b/script/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash +cd .. +mkdir -p cmake-build-release +cd cmake-build-release +/usr/bin/cmake -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - Unix Makefiles" .. +cd .. + +if [ ! -d "./data" ] +then + tar -xzvf data.tar.gz +fi +cd script \ No newline at end of file diff --git a/script/demo.sh b/script/demo.sh new file mode 100644 index 0000000..91ccc00 --- /dev/null +++ b/script/demo.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +set -x + +./build.sh +/usr/bin/cmake --build ../cmake-build-release --target demo_run -- -j 6 + +run_file=../cmake-build-release/src/demo_run +out_file=../output/single/demo-result.txt + +mkdir -p ../output/single +mkdir -p ../log/single + +vgg11_i=../data/vgg11/vgg11.cifar.relu-1-images-weights-qint8.csv +vgg11_c=../data/vgg11/vgg11.cifar.relu-1-scale-zeropoint-uint8.csv +vgg11_o=../output/single/vgg11.cifar.relu-1-infer.csv +vgg11_n=../data/vgg11/vgg11-config.csv + +${run_file} ${vgg11_i} ${vgg11_c} ${vgg11_o} ${vgg11_n} 1 > ${out_file} \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..a2f2cae --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,7 @@ +aux_source_directory(. 
// Rewrites every gate operand that points at layer 0 so that it indexes a compact,
// per-layer "subset" of layer 0 instead, and fills in the operand sizes / bit lengths
// of every layer i >= 1.
//
// For each layer `cur`:
//   * size_u[0] / size_v[0] count the distinct layer-0 gates used as u / v operands,
//     and ori_id_u / ori_id_v map a subset index back to the original layer-0 index;
//   * size_u[1] / size_v[1] describe the operand range taken from the previous layer
//     (0 with bit length -1 when no gate reads from it).
// NOTE(review): the element types of the four scratch vectors are not visible in this
// view of the file; the logic below relies on them being value-initialised to 0.
void layeredCircuit::initSubset() {
    cerr << "begin subset init." << endl;
    // visited_*[j] holds the id of the last layer that already registered layer-0 gate j,
    // so the arrays never need to be cleared between layers (layer ids start at 1).
    vector visited_uidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer
    vector subset_uidx(circuit[0].size); // the subset index of the i-th layer, j-th gate
    vector visited_vidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer
    vector subset_vidx(circuit[0].size); // the subset index of the i-th layer, j-th gate

    for (u8 i = 1; i < size; ++i) {
        auto &cur = circuit[i], &lst = circuit[i - 1];
        // FFT / IFFT layers are always treated as reading from the previous layer.
        bool has_pre_layer_u = circuit[i].ty == layerType::FFT || circuit[i].ty == layerType::IFFT;
        bool has_pre_layer_v = false;

        for (auto &gate: cur.uni_gates) {
            // lu == 0: the operand lives in layer 0, so remap it into this layer's subset.
            if (!gate.lu) {
                if (visited_uidx[gate.u] != i) {
                    // First use of this layer-0 gate in layer i: give it the next subset slot.
                    visited_uidx[gate.u] = i;
                    subset_uidx[gate.u] = cur.size_u[0];
                    cur.ori_id_u.push_back(gate.u);
                    ++cur.size_u[0];
                }
                gate.u = subset_uidx[gate.u];
            }
            has_pre_layer_u |= (gate.lu != 0);
        }

        for (auto &gate: cur.bin_gates) {
            // getLayerIdU / getLayerIdV return 0 when the operand comes from layer 0.
            if (!gate.getLayerIdU(i)) {
                if (visited_uidx[gate.u] != i) {
                    visited_uidx[gate.u] = i;
                    subset_uidx[gate.u] = cur.size_u[0];
                    cur.ori_id_u.push_back(gate.u);
                    ++cur.size_u[0];
                }
                gate.u = subset_uidx[gate.u];
            }
            if (!gate.getLayerIdV(i)) {
                if (visited_vidx[gate.v] != i) {
                    visited_vidx[gate.v] = i;
                    subset_vidx[gate.v] = cur.size_v[0];
                    cur.ori_id_v.push_back(gate.v);
                    ++cur.size_v[0];
                }
                gate.v = subset_vidx[gate.v];
            }
            has_pre_layer_u |= (gate.getLayerIdU(i) != 0);
            has_pre_layer_v |= (gate.getLayerIdV(i) != 0);
        }

        cur.bit_length_u[0] = ceilPow2BitLength(cur.size_u[0]);
        cur.bit_length_v[0] = ceilPow2BitLength(cur.size_v[0]);

        if (has_pre_layer_u) switch (cur.ty) {
                case layerType::FFT:
                    // `-` binds tighter than `<<`: this is 1 << (fft_bit_length - 1),
                    // matching the bit length stored on the next line.
                    cur.size_u[1] = 1ULL << cur.fft_bit_length - 1;
                    cur.bit_length_u[1] = cur.fft_bit_length - 1;
                    break;
                case layerType::IFFT:
                    cur.size_u[1] = 1ULL << cur.fft_bit_length;
                    cur.bit_length_u[1] = cur.fft_bit_length;
                    break;
                default:
                    // Any other layer type reads the whole previous layer.
                    cur.size_u[1] = lst.size ;
                    cur.bit_length_u[1] = lst.bit_length;
                    break;
            } else {
            // No u operand comes from the previous layer.
            cur.size_u[1] = 0;
            cur.bit_length_u[1] = -1;
        }

        if (has_pre_layer_v) {
            if (cur.ty == layerType::DOT_PROD) {
                // DOT_PROD layers use the previous layer's size divided by 2^fft_bit_length for v.
                cur.size_v[1] = lst.size >> cur.fft_bit_length;
                cur.bit_length_v[1] = lst.bit_length - cur.fft_bit_length;
            } else {
                cur.size_v[1] = lst.size;
                cur.bit_length_v[1] = lst.bit_length;
            }
        } else {
            // No v operand comes from the previous layer.
            cur.size_v[1] = 0;
            cur.bit_length_v[1] = -1;
        }
        // Refresh max_bl_u / max_bl_v from the bit lengths computed above.
        cur.updateSize();
    }
    cerr << "begin subset finish." << endl;
}
// One layer of the layered arithmetic circuit: its gates plus the size / bit-length
// bookkeeping for the two operand sources (index 0 and index 1 of the *_u / *_v arrays).
// NOTE(review): in layeredCircuit::initSubset index 0 describes the subset of layer-0
// gates used by this layer and index 1 the previous layer — confirm against other users.
class layer {
public:
	// Kind of layer; NOT initialised by the constructor below.
	layerType ty;
	// Number of gates in the layer, and sizes of the u / v operand ranges.
	u32 size{}, size_u[2]{}, size_v[2]{};
	// Bit lengths matching the sizes above; -1 marks "unused" (see constructor).
	i8 bit_length_u[2]{}, bit_length_v[2]{}, bit_length{};
	// Maxima over the two operand sources, maintained by updateSize().
	i8 max_bl_u{}, max_bl_v{};

	// When false, updateSize() leaves max_bl_v at 0.
	bool need_phase2;

    // bit decomp related
    u32 zero_start_id;

    // Gate lists (element types are not visible in this view; presumably uniGate / binGate).
    std::vector uni_gates;
    std::vector bin_gates;

	// Subset index -> original gate index, filled by layeredCircuit::initSubset.
	vector ori_id_u, ori_id_v;
	// FFT bit length for FFT-related layers; -1 when unset.
	i8 fft_bit_length;

	// iFFT or avg pooling.
	F scale;

	// Starts with empty operand ranges (size 0, bit length -1), no second phase,
	// no FFT length and a scale of one.
	layer() {
		bit_length_u[0] = bit_length_v[0] = -1;
		size_u[0] = size_v[0] = 0;
		bit_length_u[1] = bit_length_v[1] = -1;
		size_u[1] = size_v[1] = 0;
		need_phase2 = false;
		zero_start_id = 0;
		fft_bit_length = -1;
		scale = F_ONE;
	}

	// Recomputes max_bl_u, and max_bl_v only when need_phase2 is set (otherwise 0).
	void updateSize() {
		max_bl_u = std::max(bit_length_u[0], bit_length_u[1]);
		max_bl_v = 0;
		if (!need_phase2) return;

		max_bl_v = std::max(bit_length_v[0], bit_length_v[1]);
	}
};
/**
 * Converts a value to a string using fixed-point notation.
 *
 * The template parameter list is spelled out explicitly: `T` is used as the
 * parameter type, so the function has to be declared as a template over it.
 *
 * @tparam T      any type streamable into a std::ostream
 * @param a_value value to format
 * @param n       number of digits after the decimal point for floating-point
 *                values (default 4); integral values are printed unchanged
 * @return the formatted text
 */
template <typename T>
std::string to_string_wp(const T a_value, const int n = 4) {
    std::ostringstream out;
    out.precision(n);
    // std::fixed makes `precision` mean "digits after the decimal point".
    out << std::fixed << a_value;
    return out.str();
}
+// + +#include "circuit.h" +#include "neuralNetwork.hpp" +#include "verifier.hpp" +#include "models.hpp" +#include "global_var.hpp" + +// the arguments' format +#define INPUT_FILE_ID 1 // the input filename +#define CONFIG_FILE_ID 2 // the config filename +#define OUTPUT_FILE_ID 3 // the input filename +#define NETWORK_FILE_ID 4 // the configuration of vgg +#define PIC_CNT 5 // the number of picture paralleled + +char QSIZE; +char SCALE; +vector output_tb(16, ""); + +int main(int argc, char **argv) { + initPairing(mcl::BLS12_381); + + char i_filename[500], c_filename[500], o_filename[500], n_filename[500]; + + strcpy(i_filename, argv[INPUT_FILE_ID]); + strcpy(c_filename, argv[CONFIG_FILE_ID]); + strcpy(o_filename, argv[OUTPUT_FILE_ID]); + strcpy(n_filename, argv[NETWORK_FILE_ID]); + + int pic_cnt = atoi(argv[PIC_CNT]); + + output_tb[MO_INFO_OUT_ID] ="vgg (relu)"; + output_tb[PCNT_OUT_ID] = std::to_string(pic_cnt); + + prover p; + vgg nn(32, 32, 3, pic_cnt, i_filename, c_filename, o_filename, n_filename); + nn.create(p, false); + verifier v(&p, p.C); + v.verify(); + + for (auto &s: output_tb) printf("%s, ", s.c_str()); + puts(""); +} \ No newline at end of file diff --git a/src/models.cpp b/src/models.cpp new file mode 100644 index 0000000..c9f80cc --- /dev/null +++ b/src/models.cpp @@ -0,0 +1,365 @@ +// +// Created by 69029 on 3/16/2021. 
+// + +#include +#include +#include "models.hpp" +#include "utils.hpp" +#undef USE_VIRGO + + +vgg::vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, + const string &c_filename, const std::string &o_filename, const std::string &n_filename): + neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { + assert(psize_x == psize_y); + conv_section.resize(5); + + convType conv_ty = NAIVE_FAST; + + ifstream config_in(n_filename); + string con; + i64 kernel_size = 3, ch_in = pic_channel, ch_out, new_nx = pic_size_x, new_ny = pic_size_y; + + int idx = 0; + while (config_in >> con) { + if (con[0] != 'M' && con[0] != 'A') { + ch_out = stoi(con, nullptr, 10); + conv_section[idx].emplace_back(conv_ty, ch_out, ch_in, kernel_size); + ch_in = ch_out; + } else { + ++idx; + pool.emplace_back(con[0] == 'M' ? MAX : AVG, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + } + } + + assert(pic_size_x == 32); + full_conn.emplace_back(512, new_nx * new_ny * ch_in); + full_conn.emplace_back(512, 512); + full_conn.emplace_back(10, 512); +} + +vgg16::vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty, + const std::string &i_filename, const string &c_filename, const std::string &o_filename) + : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { + assert(psize_x == psize_y); + conv_section.resize(5); + + i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y; + + conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size); + conv_section[0].emplace_back(conv_ty, start, start, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[1].emplace_back(conv_ty, start << 1, 
start, kernel_size); + conv_section[1].emplace_back(conv_ty, start << 1, start << 1, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size); + conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); + conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3); + conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); + conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); + conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); + conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); + + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + if (pic_size_x == 224) { + full_conn.emplace_back(4096, new_nx * new_ny * (start << 3)); + full_conn.emplace_back(4096, 4096); + full_conn.emplace_back(1000, 4096); + } else { + assert(pic_size_x == 32); + full_conn.emplace_back(512, new_nx * new_ny * (start << 3)); + full_conn.emplace_back(512, 512); + full_conn.emplace_back(10, 512); + } +} + +vgg11::vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty, + const std::string &i_filename, const string &c_filename, const std::string 
&o_filename) + : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { + assert(psize_x == psize_y); + conv_section.resize(5); + + i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y; + + conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[1].emplace_back(conv_ty, start << 1, start, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size); + conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3); + conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); + conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); + + pool.emplace_back(pool_ty, 2, 1); + new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; + new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; + + if (pic_size_x == 224) { + full_conn.emplace_back(4096, new_nx * new_ny * (start << 3)); + full_conn.emplace_back(4096, 4096); + full_conn.emplace_back(1000, 4096); + } else { + assert(pic_size_x == 32); + full_conn.emplace_back(512, new_nx * new_ny * (start << 3)); + 
full_conn.emplace_back(512, 512); + full_conn.emplace_back(10, 512); + } +} + +ccnn::ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty) : + neuralNetwork(psize_x, psize_y, pchannel, pparallel, "", "", "") { + conv_section.resize(1); + + conv_section[0].emplace_back(NAIVE_FAST, 2, pchannel, 3, 0, 0); + pool.emplace_back(pool_ty, 2, 1); + +// conv_section[1].emplace_back(FFT, 64, 4, 3); +// conv_section[1].emplace_back(NAIVE, 64, 64, 3); +// pool.emplace_back(pool_ty, 2, 1); + +// conv_section[0].emplace_back(FFT, 2, pic_channel, 3); +// conv_section[1].emplace_back(NAIVE, 1, 2, 3); +// pool.emplace_back(pool_ty, 2, 1); +} + +lenet::lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty, + const std::string &i_filename, const string &c_filename, const std::string &o_filename) + : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { + conv_section.emplace_back(); + if (psize_x == 28 && psize_y == 28) + conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 2); + else conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 0); + pool.emplace_back(pool_ty, 2, 1); + + conv_section.emplace_back(); + conv_section[1].emplace_back(conv_ty, 16, 6, 5, 0, 0); + pool.emplace_back(pool_ty, 2, 1); + + full_conn.emplace_back(120, 400); + full_conn.emplace_back(84, 120); + full_conn.emplace_back(10, 84); +} + +lenetCifar::lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty, + const std::string &i_filename, const string &c_filename, const std::string &o_filename) + : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { + conv_section.resize(3); + + conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 0); + pool.emplace_back(pool_ty, 2, 1); + + conv_section[1].emplace_back(conv_ty, 16, 6, 5, 0, 0); + pool.emplace_back(pool_ty, 2, 1); + + conv_section[2].emplace_back(conv_ty, 120, 16, 5, 
0, 0); + + full_conn.emplace_back(84, 120); + full_conn.emplace_back(10, 84); +} + +void singleConv::createConv(prover &p) { + initParamConv(); + p.C.init(Q_BIT_SIZE, SIZE); + + p.val.resize(SIZE); + val = p.val.begin(); + two_mul = p.C.two_mul.begin(); + + i64 layer_id = 0; + inputLayer(p.C.circuit[layer_id++]); + + new_nx_in = pic_size_x; + new_ny_in = pic_size_y; + pool_ty = NONE; + for (i64 i = 0; i < conv_section.size(); ++i) { + auto &sec = conv_section[i]; + for (i64 j = 0; j < sec.size(); ++j) { + auto &conv = sec[j]; + refreshConvParam(new_nx_in, new_ny_in, conv); + + switch (conv.ty) { + case FFT: + paddingLayer(p.C.circuit[layer_id], layer_id, conv.weight_start_id); + fftLayer(p.C.circuit[layer_id], layer_id); + dotProdLayer(p.C.circuit[layer_id], layer_id); + ifftLayer(p.C.circuit[layer_id], layer_id); + break; + case NAIVE_FAST: + naiveConvLayerFast(p.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id); + break; + default: + naiveConvLayerMul(p.C.circuit[layer_id], layer_id, conv.weight_start_id); + naiveConvLayerAdd(p.C.circuit[layer_id], layer_id, conv.bias_start_id); + } + } + } + p.C.initSubset(); +// for (i64 i = 0; i < SIZE; ++i) { +// cerr << i << "(" << p.C.circuit[i].zero_start_id << ", " << p.C.circuit[i].size << "):\t"; +// for (i64 j = 0; j < std::min(100u, p.C.circuit[i].size); ++j) +// cerr << p.val[i][j] << ' '; +// cerr << endl; +// bool flag = false; +// for (i64 j = 0; j < p.C.circuit[i].size; ++j) +// if (p.val[i][j] != F_ZERO) flag = true; +// if (flag) cerr << "not all zero: " << i << endl; +// for (i64 j = p.C.circuit[i].zero_start_id; j < p.C.circuit[i].size; ++j) +// if (p.val[i][j] != F_ZERO) { cerr << "WRONG! " << i << ' ' << j << ' ' << p.val[i][j] << endl; exit(EXIT_FAILURE); } +// } + cerr << "finish creating circuit." 
<< endl; +} + +void singleConv::initParamConv() { + i64 conv_layer_cnt = 0; + total_in_size = 0; + total_para_size = total_relu_in_size = total_ave_in_size = total_max_in_size = 0; + + // data + i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel; + + new_nx_in = pic_size_x; + new_ny_in = pic_size_y; + for (i64 i = 0; i < conv_section.size(); ++i) { + auto &sec = conv_section[i]; + for (i64 j = 0; j < sec.size(); ++j) { + refreshConvParam(new_nx_in, new_ny_in, sec[j]); + conv_layer_cnt += sec[j].ty == FFT ? FFT_SIZE - 1 : sec[j].ty == NAIVE ? NCONV_SIZE : NCONV_FAST_SIZE; + // conv_kernel + sec[j].weight_start_id = pos; + pos += sqr(m) * channel_in * channel_out; + total_para_size += sqr(m) * channel_in * channel_out; + sec[j].bias_start_id = -1; + } + } + total_in_size = pos; + + SIZE = 1 + conv_layer_cnt; + cerr << "SIZE: " << SIZE << endl; +} + +vector singleConv::getFFTAns(const vector &output) { + vector res; + res.resize(nx_out * ny_out * channel_out * pic_channel * pic_parallel); + + i64 lst_fft_lenh = getFFTLen() >> 1; + i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { + i64 idx = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out); + i64 i = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lst_fft_lenh); + res[idx] = output[i]; + } + return res; +} + +double singleConv::calcRawFFT() { + auto in = val[0].begin(); + auto conv = val[0].begin() + conv_section[0][0].weight_start_id; + auto bias = val[0].begin() + conv_section[0][0].bias_start_id; + + timer tm; + int logn = ceilPow2BitLength(nx_padded_in * ny_padded_in) + 1; + vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO); + vector arr1(1 << logn, F_ZERO); + vector arr2(arr1.size(), F_ZERO); + + 
// Times a direct (nested-loop) evaluation of the first convolution on the raw values
// of layer 0. The computed result is discarded; only the elapsed time is returned.
//
// @return elapsed wall time in seconds as reported by `timer`.
double singleConv::calcRawNaive() {
    auto in = val[0].begin();
    auto conv = val[0].begin() + conv_section[0][0].weight_start_id;
    // NOTE(review): singleConv::initParamConv stores -1 in bias_start_id, which would make
    // `bias` point one element before val[0].begin() — confirm before trusting bias[i].
    auto bias = val[0].begin() + conv_section[0][0].bias_start_id;

    timer tm;
    // One accumulator per (picture, output channel, output x, output y).
    vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO);
    tm.start();
    // p: picture, i: output channel, j: input channel, (x, y): top-left corner of the
    // kernel window in padded coordinates, advancing by the stride 2^log_stride.
    for (i64 p = 0; p < pic_parallel; ++p)
    for (i64 i = 0; i < channel_out; ++i)
    for (i64 j = 0; j < channel_in; ++j)
    for (i64 x = -padding; x + m <= nx_in + padding; x += (1 << log_stride))
    for (i64 y = -padding; y + m <= ny_in + padding; y += (1 << log_stride)) {
        i64 idx = tesIdx(p, i, (x + padding) >> log_stride, (y + padding) >> log_stride, channel_out, nx_out, ny_out);
        // The bias is added once per output element, on the first input channel.
        if (j == 0) res[idx] = res[idx] + bias[i];
        for (i64 tx = x; tx < x + m; ++tx)
            for (i64 ty = y; ty < y + m; ++ty)
                // Positions rejected by check() (presumably the padding border) contribute nothing.
                if (check(tx, ty, nx_in, ny_in)) {
                    i64 u = tesIdx(p, j, tx, ty, channel_in, nx_in, ny_in);
                    i64 v = tesIdx(i, j, tx - x, ty - y, channel_in, m, m);
                    res.at(idx) = res.at(idx) + in[u] * conv[v];
                }
    }
    tm.stop();
    return tm.elapse_sec();
}
//
// Created by 69029 on 3/16/2021.
//
// Concrete network models built on top of neuralNetwork. Each constructor only
// describes the layer layout; the circuit itself is created later.
//

#ifndef ZKCNN_VGG_HPP
#define ZKCNN_VGG_HPP

#include "neuralNetwork.hpp"

// VGG-style network whose convolution / pooling layout is read from the text file
// `n_filename`.
class vgg: public neuralNetwork {

public:
    explicit vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const std::string &i_filename, const string &c_filename, const std::string &o_filename, const std::string &n_filename);

};

// Hard-coded VGG16 layout: five sections with 2, 2, 3, 3 and 3 convolutions.
class vgg16: public neuralNetwork {

public:
    explicit vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
                   const std::string &i_filename, const string &c_filename, const std::string &o_filename);

};

// Hard-coded VGG11 layout: five sections with 1, 1, 2, 2 and 2 convolutions.
class vgg11: public neuralNetwork {

public:
    explicit vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
                   const std::string &i_filename, const string &c_filename, const std::string &o_filename);

};

// LeNet layout: two 5x5 convolution sections with pooling, then 120/84/10 fully-connected layers.
class lenet: public neuralNetwork {
public:
    explicit lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
                   const std::string &i_filename, const string &c_filename, const std::string &o_filename);
};

// LeNet variant with a third 5x5 convolution section (120 channels) and 84/10 fully-connected layers.
class lenetCifar: public neuralNetwork {
public:
    explicit lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
                        const std::string &i_filename, const string &c_filename, const std::string &o_filename);
};

// Minimal single-section network without input files.
// Note the parameter order: `pparallel` comes BEFORE `pchannel` here, unlike the other models.
class ccnn: public neuralNetwork {
public:
    explicit ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty);
};

// Convolution-only network, with helpers that time raw FFT / naive convolution.
class singleConv: public neuralNetwork {
public:
    explicit singleConv(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, i64 kernel_size, i64 channel_out,
                        i64 log_stride, i64 padding, convType conv_ty);

    // Builds the convolution-only circuit inside the prover `p`.
    void createConv(prover &p);

    // Computes layer count and weight offsets for the convolution-only circuit.
    void initParamConv();

    // Extracts the convolution outputs from the output of the FFT-based pipeline.
    vector getFFTAns(const vector &output);

    // Returns the time in seconds taken by a raw FFT-based convolution.
    double calcRawFFT();

    // Returns the time in seconds taken by a raw nested-loop convolution.
    double calcRawNaive();
};

#endif //ZKCNN_VGG_HPP
mode 100644 index 0000000..7d77cdc --- /dev/null +++ b/src/neuralNetwork.cpp @@ -0,0 +1,1016 @@ +// +// Created by 69029 on 3/16/2021. +// + +#include "neuralNetwork.hpp" +#include "utils.hpp" +#include "global_var.hpp" +#include +#include +#include +#include + +using std::cerr; +using std::endl; +using std::max; +using std::ifstream; +using std::ofstream; + +ifstream in; +ifstream conf; +ofstream out; + +neuralNetwork::neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, + const string &c_filename, const string &o_filename) : + pic_size_x(psize_x), pic_size_y(psize_y), pic_channel(pchannel), pic_parallel(pparallel), + SIZE(0), NCONV_FAST_SIZE(1), NCONV_SIZE(2), FFT_SIZE(5), + AVE_POOL_SIZE(1), FC_SIZE(1), RELU_SIZE(1), act_ty(RELU_ACT) { + + in.open(i_filename); + if (!in.is_open()) + fprintf(stderr, "Can't find the input file!!!\n"); + conf.open(c_filename); + if (!conf.is_open()) + fprintf(stderr, "Can't find the config file!!!\n"); + + if (!o_filename.empty()) out.open(o_filename); +} + +neuralNetwork::neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size, + i64 start_channel, convType conv_ty, poolType pool_ty) + : neuralNetwork(psize, psize, pchannel, pparallel, "", "", "") { + pool_bl = 2; + pool_stride_bl = pool_bl >> 1; + conv_section.resize(sec_size); + + i64 start = start_channel; + for (i64 i = 0; i < sec_size; ++i) { + conv_section[i].emplace_back(conv_ty, start << i, i ? (start << (i - 1)) : pic_channel, kernel_size); + conv_section[i].emplace_back(conv_ty, start << i, start << i, kernel_size); + pool.emplace_back(pool_ty, 2, 1); + } + + i64 new_nx = (pic_size_x >> pool_stride_bl * conv_section.size()); + i64 new_ny = (pic_size_y >> pool_stride_bl * conv_section.size()); + for (i64 i = 0; i < fc_size; ++i) + full_conn.emplace_back(i == fc_size - 1 ? 1000 : 4096, i ? 
4096 : new_nx * new_ny * (start << (sec_size - 1))); +} + +void neuralNetwork::create(prover &pr, bool only_compute) { + assert(pool.size() >= conv_section.size() - 1); + + initParam(); + pr.C.init(Q_BIT_SIZE, SIZE); + + pr.val.resize(SIZE); + val = pr.val.begin(); + two_mul = pr.C.two_mul.begin(); + + i64 layer_id = 0; + inputLayer(pr.C.circuit[layer_id++]); + + new_nx_in = pic_size_x; + new_ny_in = pic_size_y; + for (i64 i = 0; i < conv_section.size(); ++i) { + auto &sec = conv_section[i]; + for (i64 j = 0; j < sec.size(); ++j) { + auto &conv = sec[j]; + refreshConvParam(new_nx_in, new_ny_in, conv); + pool_ty = i < pool.size() && j == sec.size() - 1 ? pool[i].ty : NONE; + x_bit = x_next_bit; + switch (conv.ty) { + case FFT: + paddingLayer(pr.C.circuit[layer_id], layer_id, conv.weight_start_id); + fftLayer(pr.C.circuit[layer_id], layer_id); + dotProdLayer(pr.C.circuit[layer_id], layer_id); + ifftLayer(pr.C.circuit[layer_id], layer_id); + addBiasLayer(pr.C.circuit[layer_id], layer_id, conv.bias_start_id); + break; + case NAIVE_FAST: + naiveConvLayerFast(pr.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id); + break; + default: + naiveConvLayerMul(pr.C.circuit[layer_id], layer_id, conv.weight_start_id); + naiveConvLayerAdd(pr.C.circuit[layer_id], layer_id, conv.bias_start_id); + } + + // update the scale bit + x_next_bit = getNextBit(layer_id - 1); + T = x_bit + w_bit - x_next_bit; + Q_MAX = Q + T; + if (pool_ty != MAX) + reluActConvLayer(pr.C.circuit[layer_id], layer_id); + } + + if (i >= pool.size()) continue; + calcSizeAfterPool(pool[i]); + switch (pool[i].ty) { + case AVG: avgPoolingLayer(pr.C.circuit[layer_id], layer_id); break; + case MAX: maxPoolingLayer(pr.C, layer_id, pool[i].dcmp_start_id, pool[i].max_start_id, + pool[i].max_dcmp_start_id); break; + } + } + + pool_ty = NONE; + for (int i = 0; i < full_conn.size(); ++i) { + auto &fc = full_conn[i]; + refreshFCParam(fc); + x_bit = x_next_bit; + fullyConnLayer(pr.C.circuit[layer_id], 
layer_id, fc.weight_start_id, fc.bias_start_id); + if (i == full_conn.size() - 1) break; + + // update the scale bit + x_next_bit = getNextBit(layer_id - 1); + T = x_bit + w_bit - x_next_bit; + Q_MAX = Q + T; + reluActFconLayer(pr.C.circuit[layer_id], layer_id); + } + + assert(SIZE == layer_id); + + total_in_size += total_max_in_size + total_ave_in_size + total_relu_in_size; + initLayer(pr.C.circuit[0], total_in_size, layerType::INPUT); + assert(total_in_size == pr.val[0].size()); + + printInfer(pr); +// printLayerValues(pr); + + if (only_compute) return; + pr.C.initSubset(); + cerr << "finish creating circuit." << endl; +} + +void neuralNetwork::inputLayer(layer &circuit) { + initLayer(circuit, total_in_size, layerType::INPUT); + + for (i64 i = 0; i < total_in_size; ++i) + circuit.uni_gates.emplace_back(i, 0, 0, 0); + + calcInputLayer(circuit); + printLayerInfo(circuit, 0); +} + +void +neuralNetwork::paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id) { + i64 lenh = getFFTLen() >> 1; + i64 size = lenh * channel_in * (pic_parallel + channel_out); + initLayer(circuit, size, layerType::PADDING); + circuit.fft_bit_length = getFFTBitLen(); + + // data matrix + i64 L = -padding; + i64 Rx = nx_in + padding, Ry = ny_in + padding; + + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = L; x < Rx; ++x) + for (i64 y = L; y < Ry; ++y) + if (check(x, y, nx_in, ny_in)) { + i64 g = cubIdx(p, ci, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_in, lenh); + i64 u = tesIdx(p, ci, x, y, channel_in, nx_in, ny_in); + circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); + } + + // kernel matrix + i64 first = pic_parallel * channel_in * lenh; + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = 0; x < nx_padded_in; ++x) + for (i64 y = 0; y < ny_padded_in; ++y) + if (check(x, y, m, m)) { + i64 g = first + cubIdx(co, ci, matIdx(x, y, ny_padded_in), channel_in, lenh) ; + i64 u = 
first_conv_id + tesIdx(co, ci, x, y, channel_in, m, m); + circuit.uni_gates.emplace_back(g, u, 0, 0); + } + + readConvWeight(first_conv_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::fftLayer(layer &circuit, i64 &layer_id) { + i64 size = getFFTLen() * channel_in * (pic_parallel + channel_out); + initLayer(circuit, size, layerType::FFT); + circuit.fft_bit_length = getFFTBitLen(); + + calcFFTLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::dotProdLayer(layer &circuit, i64 &layer_id) { + i64 len = getFFTLen(); + i64 size = len * channel_out * pic_parallel; + initLayer(circuit, size, layerType::DOT_PROD); + circuit.need_phase2 = true; + circuit.fft_bit_length = getFFTBitLen(); + + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) { + i64 g = matIdx(p, co, channel_out); + i64 u = matIdx(p, ci, channel_in); + i64 v = matIdx(pic_parallel + co, ci, channel_in); + circuit.bin_gates.emplace_back(g, u, v, 0, 1); + } + + calcDotProdLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::ifftLayer(layer &circuit, i64 &layer_id) { + i64 len = getFFTLen(), lenh = len >> 1; + i64 size = lenh * channel_out * pic_parallel; + initLayer(circuit, size, layerType::IFFT); + circuit.fft_bit_length = getFFTBitLen(); + F::inv(circuit.scale, F(1ULL << circuit.fft_bit_length)); + + calcFFTLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id) { + i64 len = getFFTLen(); + i64 size = nx_out * ny_out * channel_out * pic_parallel; + initLayer(circuit, size, layerType::ADD_BIAS); + + i64 lenh = len >> 1; + i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 x = L; x + m <= Rx; x += (1 << 
log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { + i64 u = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lenh); + i64 g = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out); + circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); + circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); + } + + readBias(first_bias_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id) { + i64 size = nx_out * ny_out * channel_out * pic_parallel; + initLayer(circuit, size, layerType::NCONV); + circuit.need_phase2 = true; + + i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; + i64 mat_in_size = nx_in * ny_in; + i64 m_sqr = sqr(m); + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { + i64 g = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out); + if (ci == 0 && ~first_bias_id) circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); + for (i64 tx = x; tx < x + m; ++tx) + for (i64 ty = y; ty < y + m; ++ty) + if (check(tx, ty, nx_in, ny_in)) { + i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in); + i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m); + circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1)); + } + } + + readConvWeight(first_conv_id); + if (~first_bias_id) readBias(first_bias_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::naiveConvLayerMul(layer &circuit, i64 &layer_id, i64 first_conv_id) { + i64 mat_out_size = nx_out * ny_out; + i64 mat_in_size = nx_in * ny_in; + i64 m_sqr = sqr(m); + i64 L = -padding, Rx = 
nx_in + padding, Ry = ny_in + padding; + + i64 g = 0; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) + for (i64 tx = x; tx < x + m; ++tx) + for (i64 ty = y; ty < y + m; ++ty) + if (check(tx, ty, nx_in, ny_in)) { + i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in); + i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m); + circuit.bin_gates.emplace_back(g++, u, v, 0, 2 * (u8) (layer_id > 1)); + } + + initLayer(circuit, g, layerType::NCONV_MUL); + circuit.need_phase2 = true; + readConvWeight(first_conv_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id) { + i64 size = nx_out * ny_out * channel_out * pic_parallel; + initLayer(circuit, size, layerType::NCONV_ADD); + + i64 mat_in_size = nx_in * ny_in; + i64 m_sqr = sqr(m); + i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; + + i64 u = 0; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { + i64 g = tesIdx(p, co, ((x - L) >> log_stride),( (y - L) >> log_stride), channel_out, nx_out, ny_out); + i64 cnt = 0; + if (ci == 0 && ~first_bias_id) { + circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); + ++cnt; + } + for (i64 tx = x; tx < x + m; ++tx) + for (i64 ty = y; ty < y + m; ++ty) + if (check(tx, ty, nx_in, ny_in)) { + circuit.uni_gates.emplace_back(g, u++, layer_id - 1, 0); + ++cnt; + } + } + + if (~first_bias_id) readBias(first_bias_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::reluActConvLayer(layer &circuit, i64 &layer_id) 
{ + i64 mat_out_size = nx_out * ny_out; + i64 size = 1L * mat_out_size * channel_out * (2 + Q_MAX) * pic_parallel; + i64 block_len = mat_out_size * channel_out * pic_parallel; + + i64 dcmp_cnt = block_len * Q_MAX; + i64 first_dcmp_id = val[0].size(); + val[0].resize(val[0].size() + dcmp_cnt); + total_relu_in_size += dcmp_cnt; + + initLayer(circuit, size, layerType::RELU); + circuit.need_phase2 = true; + + circuit.zero_start_id = block_len; + + for (i64 g = 0; g < block_len; ++g) { + i64 sign_u = first_dcmp_id + g * Q_MAX; + for (i64 s = 1; s < Q; ++s) { + i64 v = sign_u + s; + circuit.uni_gates.emplace_back(g, v, 0, Q - 1 - s); + circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0); + } + } + + i64 len = getFFTLen(); + i64 lenh = len >> 1; + i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++ co) + for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) + for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { + i64 u = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out); + i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX; + circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1); + circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1)); + prepareSignBit(layer_id - 1, u, sign_v); + for (i64 s = 1; s < Q_MAX; ++s) { + i64 v = sign_v + s; + circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1); + prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1); + } + } + + for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) { + i64 u = first_dcmp_id + g - (block_len << 1); + circuit.bin_gates.emplace_back(g, u, u, 0, 0); + circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); + } + + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::reluActFconLayer(layer &circuit, i64 &layer_id) { + i64 block_len = channel_out * pic_parallel; + 
i64 size = block_len * (2 + Q_MAX); + initLayer(circuit, size, layerType::RELU); + circuit.zero_start_id = block_len; + circuit.need_phase2 = true; + + i64 dcmp_cnt = block_len * Q_MAX; + i64 first_dcmp_id = val[0].size(); + val[0].resize(val[0].size() + dcmp_cnt); + total_relu_in_size += dcmp_cnt; + + for (i64 g = 0; g < block_len; ++g) { + i64 sign_u = first_dcmp_id + g * Q_MAX; + for (i64 s = 1; s < Q; ++s) { + i64 v = sign_u + s; + circuit.uni_gates.emplace_back(g, v, 0, (Q - s - 1)); + circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0); + } + } + + for (i64 u = 0; u < block_len; ++u) { + i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX; + circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1); + circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1)); + prepareSignBit(layer_id - 1, u, sign_v); + + for (i64 s = 1; s < Q_MAX; ++s) { + i64 v = sign_v + s; + circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1); + prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1); + } + } + + for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) { + i64 u = first_dcmp_id + g - (block_len << 1); + circuit.bin_gates.emplace_back(g, u, u, 0, 0); + circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); + } + + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void neuralNetwork::avgPoolingLayer(layer &circuit, i64 &layer_id) { + i64 mat_out_size = nx_out * ny_out; + i64 zero_start_id = new_nx_in * new_ny_in * channel_out * pic_parallel; + i64 size = zero_start_id + getPoolDecmpSize(); + u8 dpool_bl = pool_bl << 1; + i64 pool_sz_sqr = sqr(pool_sz); + initLayer(circuit, size, layerType::AVG_POOL); + F::inv(circuit.scale, pool_sz_sqr); + circuit.zero_start_id = zero_start_id; + circuit.need_phase2 = true; + + i64 first_gate_id = val[0].size(); + val[0].resize(val[0].size() + zero_start_id * dpool_bl); + total_ave_in_size += zero_start_id * dpool_bl; + + // [0 .. 
zero_start_id] + // [zero_start_id .. zero_start_id + (g = 0..channel_out * mat_new_size) * dpool_bl + rm_i .. channel_out * mat_new_size * (1 + dpool_bl)] + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) + for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride) + for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) { + i64 g = tesIdx(p, co, (x >> pool_stride_bl), (y >> pool_stride_bl), channel_out, new_nx_in, new_ny_in); + F data = F_ZERO; + for (i64 tx = x; tx < x + pool_sz; ++tx) + for (i64 ty = y; ty < y + pool_sz; ++ty) { + i64 u = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out); + circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); + data = data + val[layer_id - 1][u]; + } + + for (i64 rm_i = 0; rm_i < dpool_bl; ++rm_i) { + i64 idx = matIdx(g, rm_i, dpool_bl), u = first_gate_id + idx, g_bit = zero_start_id + idx; + circuit.uni_gates.emplace_back(g, u, 0, dpool_bl - rm_i + Q_BIT_SIZE); + prepareFieldBit(F(data), u, dpool_bl - rm_i - 1); + + // check bit + circuit.bin_gates.emplace_back(g_bit, u, u, 0, 0); + circuit.uni_gates.emplace_back(g_bit, u, 0, Q_BIT_SIZE + 1); + } + } + + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void +neuralNetwork::maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id, + i64 first_max_dcmp_id) { + i64 mat_out_size = nx_out * ny_out; + i64 tot_out_size = mat_out_size * channel_out * pic_parallel; + i64 mat_new_size = new_nx_in * new_ny_in; + i64 tot_new_size = mat_new_size * channel_out * pic_parallel; + i64 pool_sz_sqr = sqr(pool_sz); + + i64 dcmp_cnt = getPoolDecmpSize(); + first_dcmp_id = val[0].size(); + val[0].resize(val[0].size() + dcmp_cnt); + total_max_in_size += dcmp_cnt; + + i64 max_cnt = tot_new_size; + first_max_id = val[0].size(); + val[0].resize(val[0].size() + max_cnt); + total_max_in_size += max_cnt; + + i64 max_dcmp_cnt = tot_new_size * (Q_MAX - 1); + first_max_dcmp_id = val[0].size(); + 
val[0].resize(val[0].size() + max_dcmp_cnt); + total_max_in_size += max_dcmp_cnt; + + // 0: max - everyone & max - (max bits) == 0 + // [0..tot_new_size * sqr(pool_sz)][tot_new_size * sqr(pool_sz)..tot_new_size * sqr(pool_sz) + tot_new_size] + i64 size_0 = tot_new_size * pool_sz_sqr + tot_new_size; + layer &circuit = C.circuit[layer_id]; + initLayer(circuit, size_0, layerType::MAX_POOL); + circuit.zero_start_id = tot_new_size * pool_sz_sqr; + i64 fft_len = getFFTLen(), fft_lenh = fft_len >> 1; + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) { + for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride) + for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) { + i64 i_max = tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in); + i64 u_max = first_max_id + i_max; + for (i64 tx = x; tx < x + pool_sz; ++tx) + for (i64 ty = y; ty < y + pool_sz; ++ty) { + i64 g = cubIdx(tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in), tx - x, ty - y, pool_sz, pool_sz); + i64 u_g = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out); + circuit.uni_gates.emplace_back(g, u_max, 0, 0); + circuit.uni_gates.emplace_back(g, u_g, layer_id - 1, Q_BIT_SIZE + 1); + prepareMax(layer_id - 1, u_g, u_max); + } + } + } + + for (i64 i_new = 0; i_new < tot_new_size; ++i_new) { + i64 g_new = circuit.zero_start_id + i_new; + i64 u_new = first_max_id + i_new; + circuit.uni_gates.emplace_back(g_new, u_new, 0, Q_BIT_SIZE + 1); + for (i64 i_new_bit = 0; i_new_bit < Q_MAX - 1; ++i_new_bit) { + i64 u_new_bit = first_max_dcmp_id + matIdx(i_new, i_new_bit, Q_MAX - 1); + circuit.uni_gates.emplace_back(g_new, u_new_bit, 0, Q_MAX - 2 - i_new_bit); + prepareDecmpBit(0, u_new, u_new_bit, Q_MAX - 2 - i_new_bit); + } + } + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); + + // 1: (max - someone)^2 & max - everyone - ((max - everyone) bits) == 0 + // [0..tot_new_size * 
(sqr(pool_sz) + 1 >> 1)][tot_new_size * (sqr(pool_sz) + 1 >> 1)..tot_new_size * (sqr(pool_sz) + 1 >> 1) + tot_new_size * sqr(pool_sz)] + // 2: (max - someone)^4 + // ?: (max - someone)^(2^?) + // [0..(((tot_out_size + 1) / 2 + 1) / 2...+ 1) / 2] + // f: new tensor & (max - someone)^(pool_sz^2 + 1) & all (include minus and max) bits check + // [0..tot_new_size] + // [tot_new_size..tot_new_size * 2] + // [tot_new_size * 2..tot_new_size * (Q + 1)] + // [tot_new_size * (Q + 1) + // ..tot_new_size * (Q + 1) + (g = 0..tot_out_size) * (Q - 1) + bit_i + // ..tot_new_size * (Q + 1) + tot_new_size * (pool_sz^2) * (Q - 1)] + i64 contain_max_ly = 1, ksize = pool_sz_sqr; + while (!(ksize & 1)) { ksize >>= 1; ++contain_max_ly; } + ksize = pool_sz_sqr; + + for (int i = 1; i < pool_layer_cnt; ++i) { + layer &circuit = C.circuit[layer_id]; + i64 size = tot_new_size * ( ((ksize + 1 )>> 1) + (i64) (i == 1) * ksize ) + + (i64) (i == pool_layer_cnt - 1) * tot_new_size * Q_MAX + + (i64) (i == pool_layer_cnt - 1) * tot_new_size * pool_sz_sqr * (Q_MAX - 1); + initLayer(circuit, size, layerType::MAX_POOL); + circuit.need_phase2 = true; + + // new tensor + i64 before_mul = 0; + if (i == pool_layer_cnt - 1) { + before_mul = tot_new_size; + for (i64 g = 0; g < tot_new_size; ++g) + for (i64 j = 0; j < Q - 1; ++j) { + i64 u = first_max_dcmp_id + matIdx(g, j, Q_MAX - 1); + circuit.uni_gates.emplace_back(g, u, 0, Q - 2 - j); + } + } + + // multiplications of subtraction + for (i64 cnt = 0; cnt < tot_new_size; ++cnt) { + i64 v_max = first_max_id + cnt; + for (i64 j = 0; (j << 1) < ksize; ++j) { + i64 idx = matIdx(cnt, j, (ksize + 1) >> 1); + i64 g = before_mul + idx; + i64 u = matIdx(cnt, (j << 1), ksize); + if ((j << 1 | 1) < ksize) { + i64 v = matIdx(cnt, (j << 1 | 1), ksize); + circuit.bin_gates.emplace_back(g, u, v, 0, layer_id > 1); + } else if (i == contain_max_ly) + circuit.bin_gates.emplace_back(g, u, v_max, 0, 2 * (u8) (layer_id > 1)); + else + circuit.uni_gates.emplace_back(g, u, 
layer_id - 1, 0); + } + } + + if (i == 1) { + i64 minus_cnt = tot_new_size * ksize; + i64 minus_new_cnt = tot_new_size * ((ksize + 1) >> 1); + circuit.zero_start_id = minus_new_cnt; + for (i64 v = 0; v < minus_cnt; ++v) { + i64 g = minus_new_cnt + v; + circuit.uni_gates.emplace_back(g, v, layer_id - 1, Q_BIT_SIZE + 1); + for (i64 bit_j = 0; bit_j < Q_MAX - 1; ++bit_j) { + i64 u = first_dcmp_id + matIdx(v, bit_j, Q_MAX - 1); + circuit.uni_gates.emplace_back(g, u, 0, Q_MAX - 2 - bit_j); + prepareDecmpBit(layer_id - 1, v, u, Q_MAX - 2 - bit_j); + } + } + } else if (i == pool_layer_cnt - 1) { + i64 minus_cnt = tot_new_size * pool_sz_sqr; + circuit.zero_start_id = before_mul; + for (i64 j = 0; j < minus_cnt; ++j) { + i64 g = before_mul + tot_new_size + j; + i64 u = first_dcmp_id + j; + circuit.bin_gates.emplace_back(g, u, u, 0, 0); + circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); + } + } + ksize = (ksize + 1) >> 1; + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); + } + +} + +void neuralNetwork::fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id) { + i64 size = channel_out * pic_parallel; + initLayer(circuit, size, layerType::FCONN); + circuit.need_phase2 = true; + + for (i64 p = 0; p < pic_parallel; ++p) + for (i64 co = 0; co < channel_out; ++co) { + i64 g = matIdx(p, co, channel_out); + circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); + for (i64 ci = 0; ci < channel_in; ++ci) { + i64 u = matIdx(p, ci, channel_in); + i64 v = first_fc_id + matIdx(co, ci, channel_in); + circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1)); + } + } + + readFconWeight(first_fc_id); + readBias(first_bias_id); + calcNormalLayer(circuit, layer_id); + printLayerInfo(circuit, layer_id++); +} + +void +neuralNetwork::refreshConvParam(i64 new_nx, i64 new_ny, const convKernel &conv) { + nx_in = new_nx; + ny_in = new_ny; + padding = conv.padding; + nx_padded_in = nx_in + (conv.padding * 2); + ny_padded_in = 
ny_in + (conv.padding * 2); + + m = conv.size; + channel_in = conv.channel_in; + channel_out = conv.channel_out; + log_stride = conv.stride_bl; + + nx_out = ((nx_padded_in - m) >> log_stride) + 1; + ny_out = ((ny_padded_in - m) >> log_stride) + 1; + + new_nx_in = nx_out; + new_ny_in = ny_out; + conv_layer_cnt = conv.ty == FFT ? FFT_SIZE : conv.ty == NAIVE ? NCONV_SIZE : NCONV_FAST_SIZE; +} + +void neuralNetwork::refreshFCParam(const fconKernel &fc) { + nx_in = nx_out = m = 1; + ny_in = ny_out = 1; + channel_in = fc.channel_in; + channel_out = fc.channel_out; +} + +i64 neuralNetwork::getFFTLen() const { + return 1L << getFFTBitLen(); +} + +i8 neuralNetwork::getFFTBitLen() const { + return ceilPow2BitLength( (u32)nx_padded_in * ny_padded_in ) + 1; +} + +// input: [data] +// [[conv_kernel || relu_conv_bit_decmp]{sec.size()}[max_pool]{if maxPool}[pool_bit_decmp]]{conv_section.size()} +// [fc_kernel || relu_fc_bit_decmp] +void neuralNetwork::initParam() { + act_layer_cnt = RELU_SIZE; + i64 total_conv_layer_cnt = 0, total_pool_layer_cnt = 0; + total_in_size = 0; + total_para_size = 0; + total_relu_in_size = 0; + total_ave_in_size = 0; + total_max_in_size = 0; + + // data + i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel; + + new_nx_in = pic_size_x; + new_ny_in = pic_size_y; + for (i64 i = 0; i < conv_section.size(); ++i) { + auto &sec = conv_section[i]; + for (i64 j = 0; j < sec.size(); ++j) { + refreshConvParam(new_nx_in, new_ny_in, sec[j]); + // conv_kernel + sec[j].weight_start_id = pos; + u32 para_size = sqr(m) * channel_in * channel_out; + pos += para_size; + total_para_size += para_size; + fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size); + + sec[j].bias_start_id = pos; + pos += channel_out; + total_para_size += channel_out; + fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size); + } + + total_conv_layer_cnt += sec.size() * (conv_layer_cnt + act_layer_cnt); + + if (i >= pool.size()) continue; + 
calcSizeAfterPool(pool[i]); + total_pool_layer_cnt += pool_layer_cnt; + if (pool[i].ty == MAX) + if (act_ty == RELU_ACT) total_conv_layer_cnt -= act_layer_cnt; + } + + for (int i = 0; i < full_conn.size(); ++i) { + auto &fc = full_conn[i]; + refreshFCParam(fc); + // fc_kernel + fc.weight_start_id = pos; + u32 para_size = channel_out * channel_in; + pos += para_size; + total_para_size += para_size; + fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size); + fc.bias_start_id = pos; + pos += channel_out; + total_para_size += channel_out; + fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size); + if (i == full_conn.size() - 1) break; + } + total_in_size = pos; + + SIZE = 1 + total_conv_layer_cnt + total_pool_layer_cnt + (FC_SIZE + RELU_SIZE) * full_conn.size(); + if (!full_conn.empty()) SIZE -= RELU_SIZE; + cerr << "SIZE: " << SIZE << endl; +} + +void neuralNetwork::printLayerInfo(const layer &circuit, i64 layer_id) { +// fprintf(stderr, "+ %2lld " , layer_id); +// switch (circuit.ty) { +// case layerType::INPUT: fprintf(stderr, "inputLayer "); break; +// case layerType::PADDING: fprintf(stderr, "paddingLayer "); break; +// case layerType::FFT: fprintf(stderr, "fftLayer "); break; +// case layerType::DOT_PROD: fprintf(stderr, "dotProdLayer "); break; +// case layerType::IFFT: fprintf(stderr, "ifftLayer "); break; +// case layerType::ADD_BIAS: fprintf(stderr, "addBiasLayer "); break; +// case layerType::RELU: fprintf(stderr, "reluActLayer "); break; +// case layerType::Sqr: fprintf(stderr, "squareActLayer "); break; +// case layerType::OPT_AVG_POOL: fprintf(stderr, "avgOptPoolingLayer "); break; +// case layerType::AVG_POOL: fprintf(stderr, "avgPoolingLayer "); break; +// case layerType::MAX_POOL: fprintf(stderr, "maxPoolingLayer "); break; +// case layerType::FCONN: fprintf(stderr, "fullyConnLayer "); break; +// case layerType::NCONV: fprintf(stderr, "naiveConvFast "); break; +// case layerType::NCONV_MUL: fprintf(stderr,
"naiveConvMul "); break; +// case layerType::NCONV_ADD: fprintf(stderr, "naiveConvAdd "); break; +//m +// } +// fprintf(stderr, "%11u (2^%2d)\n", circuit.size, (int) circuit.bit_length); +} + +void neuralNetwork::printWitnessInfo(const layer &circuit) const { + assert(circuit.size == total_in_size); + u32 total_data_in_size = total_in_size - total_relu_in_size - total_ave_in_size - total_max_in_size; + fprintf(stderr,"%u (2^%2d) = %u (%.2f%% data) + %lld (%.2f%% relu) + %lld (%.2f%% ave) + %lld (%.2f%% max), ", + circuit.size, circuit.bit_length, total_data_in_size, 100.0 * total_data_in_size / (double) total_in_size, + total_relu_in_size, 100.0 * total_relu_in_size / (double) total_in_size, + total_ave_in_size, 100.0 * total_ave_in_size / (double) total_in_size, + total_max_in_size, 100.0 * total_max_in_size / (double) total_in_size); + output_tb[WS_OUT_ID] = std::to_string(circuit.size) + "(2^" + std::to_string(ceilPow2BitLength(circuit.size)) + ")"; +} + +// Size of the bit-decomposition block of the current pooling layer, selected by pool_ty: +// AVG: (pool_bl << 1) cells per pooled output element; +// MAX: (Q_MAX - 1) cells per element of every pooling window (sqr(pool_sz) elements per pooled output). +// Any other pool_ty is a caller bug. The default branch used to be assert(false) only, so with NDEBUG defined +// control ran off the end of this non-void function (undefined behaviour); it now always terminates explicitly. +i64 neuralNetwork::getPoolDecmpSize() const { + switch (pool_ty) { + case AVG: return new_nx_in * new_ny_in * (pool_bl << 1) * channel_out * pic_parallel; + case MAX: return new_nx_in * new_ny_in * sqr(pool_sz) * channel_out * pic_parallel * (Q_MAX - 1); + default: + assert(false); + fprintf(stderr, "getPoolDecmpSize: pool type has no bit decomposition\n"); + exit(EXIT_FAILURE); + } +} + +void neuralNetwork::calcSizeAfterPool(const poolKernel &p) { + pool_sz = p.size; + pool_bl = ceilPow2BitLength(pool_sz); + pool_stride_bl = p.stride_bl; + pool_stride = 1 << p.stride_bl; + pool_layer_cnt = p.ty == MAX ?
1 + ceilPow2BitLength(sqr(p.size) + 1) : AVE_POOL_SIZE; + new_nx_in = ((nx_out - pool_sz) >> pool_stride_bl) + 1; + new_ny_in = ((ny_out - pool_sz) >> pool_stride_bl) + 1; +} + +void neuralNetwork::calcInputLayer(layer &circuit) { + val[0].resize(circuit.size); + + assert(val[0].size() == total_in_size); + auto val_0 = val[0].begin(); + + double num, mx = -10000, mn = 10000; + vector input_dat; + for (i64 ci = 0; ci < pic_channel; ++ci) + for (i64 x = 0; x < pic_size_x; ++x) + for (i64 y = 0; y < pic_size_y; ++y) { + in >> num; + input_dat.push_back(num); + mx = max(mx, num); + mn = min(mn, num); + } + + // (mx - mn) * 2^i <= 2^Q - 1 + // quant_shr = i + x_next_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); + if ((int) ((mx - mn) * exp2(x_next_bit)) > (1 << (Q - 1)) - 1) --x_next_bit; + + for (i64 p = 0; p < pic_parallel; ++p) { + i64 i = 0; + for (i64 ci = 0; ci < pic_channel; ++ci) + for (i64 x = 0; x < pic_size_x; ++x) + for (i64 y = 0; y < pic_size_y; ++y) + *val_0++ = F((i64)(input_dat[i++] * exp2(x_next_bit))); + } + for (; val_0 < val[0].begin() + circuit.size; ++val_0) val_0 -> clear(); +} + + +void neuralNetwork::readConvWeight(i64 first_conv_id) { + auto val_0 = val[0].begin() + first_conv_id; + + double num, mx = -10000, mn = 10000; + vector input_dat; + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) + for (i64 x = 0; x < m; ++x) + for (i64 y = 0; y < m; ++y) { + in >> num; + input_dat.push_back(num); + mx = max(mx, num); + mn = min(mn, num); + } + + // (mx - mn) * 2^i <= 2^Q - 1 + // quant_shr = i + w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); + if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit; + + for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit))); + +} + +void neuralNetwork::readBias(i64 first_bias_id) { + auto val_0 = val[0].begin() + first_bias_id; + + double num, mx = -10000, mn = 10000; + vector input_dat; + for (i64 co = 0; co < channel_out; 
++co) { + in >> num; + input_dat.push_back(num); + mx = max(mx, num); + mn = min(mn, num); + } + + for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit + x_bit))); + +} + +void neuralNetwork::readFconWeight(i64 first_fc_id) { + double num, mx = -10000, mn = 10000; + auto val_0 = val[0].begin() + first_fc_id; + + vector input_dat; + for (i64 co = 0; co < channel_out; ++co) + for (i64 ci = 0; ci < channel_in; ++ci) { + in >> num; + input_dat.push_back(num); + mx = max(mx, num); + mn = min(mn, num); + } + + // (mx - mn) * 2^i <= 2^Q - 1 + // quant_shr = i + w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); + if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit; + + for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit))); +} + +void neuralNetwork::prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift) { + auto data = abs(val[layer_id].at(idx).getInt64()); + val[0].at(dcmp_id) = (data >> bit_shift) & 1; +} + +void neuralNetwork::prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift) { + auto tmp = abs(data.getInt64()); + val[0].at(dcmp_id) = (tmp >> bit_shift) & 1; +} + +void neuralNetwork::prepareSignBit(i64 layer_id, i64 idx, i64 dcmp_id) { + val[0].at(dcmp_id) = val[layer_id].at(idx).isNegative() ? F_ONE : F_ZERO; +} + +void neuralNetwork::prepareMax(i64 layer_id, i64 idx, i64 max_id) { + auto data = val[layer_id].at(idx).isNegative() ? 
F_ZERO : val[layer_id].at(idx); + if (data > val[0].at(max_id)) val[0].at(max_id) = data; +} + +// Evaluates layer `layer_id` from its gate lists and stores the result in val[layer_id]: +// every output cell starts at zero, each unary gate adds val[lu][u] * two_mul[sc] to cell g, +// each binary gate adds val[layer of u][u] * val[layer of v][v] * two_mul[sc] to cell g, +// and finally every cell is multiplied by circuit.scale. +void neuralNetwork::calcNormalLayer(const layer &circuit, i64 layer_id) { + val[layer_id].resize(circuit.size); + for (auto &x: val[layer_id]) x.clear(); + + for (auto &gate: circuit.uni_gates) { + val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[gate.lu].at(gate.u) * two_mul[gate.sc]; + } + + + for (auto &gate: circuit.bin_gates) { + // `auto` keeps the ids at whatever width the gate reports; the former u8 locals would wrap any id above 255. + const auto bin_lu = gate.getLayerIdU(layer_id); + const auto bin_lv = gate.getLayerIdV(layer_id); + // both operands are bounds-checked now (the v operand used unchecked operator[] before) + val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[bin_lu].at(gate.u) * val[bin_lv].at(gate.v) * two_mul[gate.sc]; + } + + // val[layer_id] holds exactly circuit.size cells after the resize above + for (auto &x: val[layer_id]) x = x * circuit.scale; +} + +// Evaluates a dot-product layer: cells are grouped in blocks of 2^fft_bit_length, and for every binary gate +// block g accumulates, position by position, the product of blocks u and v of the previous layer. +void neuralNetwork::calcDotProdLayer(const layer &circuit, i64 layer_id) { + val[layer_id].resize(circuit.size); + for (auto &x: val[layer_id]) x.clear(); + + char fft_bit = circuit.fft_bit_length; + u32 fft_len = 1u << fft_bit; + // full-width index of the previous layer; the former u8 local wrapped once layer_id exceeded 256 + i64 prev = layer_id - 1; + for (auto &gate: circuit.bin_gates) + for (u32 s = 0; s < fft_len; ++s) + val[layer_id][gate.g << fft_bit | s] = val[layer_id][gate.g << fft_bit | s] + + val[prev][gate.u << fft_bit | s] * val[prev][gate.v << fft_bit | s]; +} + +void neuralNetwork::calcFFTLayer(const layer &circuit, i64 layer_id) { + i64 fft_len = 1ULL << circuit.fft_bit_length; + i64 fft_lenh = fft_len >> 1; + val[layer_id].resize(circuit.size); + std::vector arr(fft_len, F_ZERO); + if (circuit.ty == layerType::FFT) for (i64 c = 0, d = 0; d < circuit.size; c += fft_lenh, d += fft_len) { + for (i64 j = c; j < c + fft_lenh; ++j) arr[j - c] = val[layer_id - 1].at(j); + for (i64 j = fft_lenh; j < fft_len; ++j) arr[j].clear(); + fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT); + for (i64 j = d; j < d + fft_len; ++j) val[layer_id].at(j) = arr[j - d]; + } else for (u32 c = 0, d = 0; c < circuit.size; c += fft_lenh, d += fft_len) { + for (i64 j =
d; j < d + fft_len; ++j) arr[j - d] = val[layer_id - 1].at(j); + fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT); + for (i64 j = c; j < c + fft_lenh; ++j) val[layer_id].at(j) = arr[j - c]; + } +} + +int neuralNetwork::getNextBit(int layer_id) { + F mx = F_ZERO, mn = F_ZERO; + for (const auto &x: val[layer_id]) { + if (!x.isNegative()) mx = max(mx, x); + else mn = max(mn, -x); + } + i64 x = (mx + mn).getInt64(); + double real_scale = x / exp2(x_bit + w_bit); + int res = (int) log2( ((1 << (Q - 1)) - 1) / real_scale ); + return res; +} + +void neuralNetwork::printLayerValues(prover &pr) { + for (i64 i = 0; i < SIZE; ++i) { +// if (pr.C.circuit[i].ty == layerType::FCONN || pr.C.circuit[i].ty == layerType::ADD_BIAS || i && i < SIZE - 1 && pr.C.circuit[i + 1].ty == layerType::PADDING) { + cerr << i << "(" << pr.C.circuit[i].zero_start_id << ", " << pr.C.circuit[i].size << "):\t"; + for (i64 j = 0; j < std::min(200u, pr.C.circuit[i].size); ++j) + if (!pr.val[i][j].isZero()) cerr << pr.val[i][j] << ' '; + cerr << endl; + for (i64 j = pr.C.circuit[i].zero_start_id; j < pr.C.circuit[i].size; ++j) + if (pr.val[i].at(j) != F_ZERO) { + cerr << "WRONG! 
" << i << ' ' << j << ' ' << (-pr.val[i][j] * F_ONE) << endl; + exit(EXIT_FAILURE); + } + } +} + +void neuralNetwork::printInfer(prover &pr) { + // output the inference result with the size of (pic_parallel x n_class) + if (out.is_open()) { + int n_class = full_conn.back().channel_out; + for (int p = 0; p < pic_parallel; ++p) { + int k = -1; + F v; + for (int c = 0; c < n_class; ++c) { + auto tmp = val[SIZE - 1].at(matIdx(p, c, n_class)); + if (!tmp.isNegative() && (k == -1 || v < tmp)) { + k = c; + v = tmp; + } + } + out << k << endl; + + // output one-hot +// for (int c = 0; c < n_class; ++c) out << (k == c) << ' '; +// out << endl; + } + } + out.close(); + printWitnessInfo(pr.C.circuit[0]); +} \ No newline at end of file diff --git a/src/neuralNetwork.hpp b/src/neuralNetwork.hpp new file mode 100644 index 0000000..acfcc99 --- /dev/null +++ b/src/neuralNetwork.hpp @@ -0,0 +1,169 @@ +// +// Created by 69029 on 3/16/2021. +// + +#ifndef ZKCNN_NEURALNETWORK_HPP +#define ZKCNN_NEURALNETWORK_HPP + +#include +#include +#include "circuit.h" +#include "prover.hpp" + +using std::vector; +using std::tuple; +using std::pair; + +enum convType { + FFT, NAIVE, NAIVE_FAST +}; + +struct convKernel { + convType ty; + i64 channel_out, channel_in, size, stride_bl, padding, weight_start_id, bias_start_id; + convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size, i64 _log_stride, i64 _padding) : + ty(_ty), channel_out(_channel_out), channel_in(_channel_in), size(_size), stride_bl(_log_stride), padding(_padding){} + + convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size): + convKernel(_ty, _channel_out, _channel_in, _size, 0, _size >> 1) {} +}; + +struct fconKernel { + i64 channel_out, channel_in, weight_start_id, bias_start_id; + fconKernel(i64 _channel_out, i64 _channel_in): + channel_out(_channel_out), channel_in(_channel_in) {} +}; + +enum poolType { + AVG, MAX, NONE +}; + +enum actType { + RELU_ACT +}; + +struct poolKernel { + poolType ty; + i64 
size, stride_bl, dcmp_start_id, max_start_id, max_dcmp_start_id; + poolKernel(poolType _ty, i64 _size, i64 _log_stride): + ty(_ty), size(_size), stride_bl(_log_stride) {} +}; + + +class neuralNetwork { +public: + explicit neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, + const string &c_filename, const string &o_filename); + + neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size, + i64 start_channel, convType conv_ty, poolType pool_ty); + + void create(prover &pr, bool only_compute); + +protected: + + void initParam(); + + int getNextBit(int layer_id); + + void refreshConvParam(i64 new_nx, i64 new_ny, const convKernel &conv); + + void calcSizeAfterPool(const poolKernel &p); + + void refreshFCParam(const fconKernel &fc); + + [[nodiscard]] i64 getFFTLen() const; + + [[nodiscard]] i8 getFFTBitLen() const; + + [[nodiscard]] i64 getPoolDecmpSize() const; + + void prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift); + + void prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift); + + void prepareSignBit(i64 layer_id, i64 idx, i64 dcmp_id); + + void prepareMax(i64 layer_id, i64 idx, i64 max_id); + + void calcInputLayer(layer &circuit); + + void calcNormalLayer(const layer &circuit, i64 layer_id); + + void calcDotProdLayer(const layer &circuit, i64 layer_id); + + void calcFFTLayer(const layer &circuit, i64 layer_id); + + vector> conv_section; + vector pool; + poolType pool_ty; + i64 pool_bl, pool_sz; + i64 pool_stride_bl, pool_stride; + i64 pool_layer_cnt, act_layer_cnt, conv_layer_cnt; + actType act_ty; + + vector full_conn; + + i64 pic_size_x, pic_size_y, pic_channel, pic_parallel; + i64 SIZE; + const i64 NCONV_FAST_SIZE, NCONV_SIZE, FFT_SIZE, AVE_POOL_SIZE, FC_SIZE, RELU_SIZE; + i64 T; + const i64 Q = 9; + i64 Q_MAX; + const i64 Q_BIT_SIZE = 220; + + i64 nx_in, nx_out, ny_in, ny_out, m, channel_in, channel_out, log_stride, padding; + i64 new_nx_in, new_ny_in; + 
i64 nx_padded_in, ny_padded_in; + i64 total_in_size, total_para_size, total_relu_in_size, total_ave_in_size, total_max_in_size; + int x_bit, w_bit, x_next_bit; + + vector>::iterator val; + vector::iterator two_mul; + + void inputLayer(layer &circuit); + + void paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id); + + void fftLayer(layer &circuit, i64 &layer_id); + + void dotProdLayer(layer &circuit, i64 &layer_id); + + void ifftLayer(layer &circuit, i64 &layer_id); + + void addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id); + + void naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id); + + void naiveConvLayerMul(layer &circuit, i64 &layer_id, i64 first_conv_id); + + void naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id); + + void reluActConvLayer(layer &circuit, i64 &layer_id); + + void reluActFconLayer(layer &circuit, i64 &layer_id); + + void avgPoolingLayer(layer &circuit, i64 &layer_id); + + void + maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id, i64 first_max_dcmp_id); + + void fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id); + + static void printLayerInfo(const layer &circuit, i64 layer_id); + + void readBias(i64 first_bias_id); + + void readConvWeight(i64 first_conv_id); + + void readFconWeight(i64 first_fc_id); + + void printWitnessInfo(const layer &circuit) const; + + void printLayerValues(prover &pr); + + void printInfer(prover &pr); +}; + + +#endif //ZKCNN_NEURALNETWORK_HPP diff --git a/src/polynomial.cpp b/src/polynomial.cpp new file mode 100644 index 0000000..5b0fe3f --- /dev/null +++ b/src/polynomial.cpp @@ -0,0 +1,130 @@ +#include +#include "polynomial.h" + +quintuple_poly::quintuple_poly() { a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); f.clear();} +quintuple_poly::quintuple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee, const F &ff) { + a = aa; + b = bb; + c = cc; + d = 
dd; + e = ee; + f = ff; +} + +quintuple_poly quintuple_poly::operator + (const quintuple_poly &x) const { + return quintuple_poly(a + x.a, b + x.b, c + x.c, d + x.d, e + x.e, f + x.f); +} + +F quintuple_poly::eval(const F &x) const { + return (((((a * x) + b) * x + c) * x + d) * x + e) * x + f; +} + +void quintuple_poly::clear() { + a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); f.clear(); +} + +quadruple_poly::quadruple_poly() {a.clear(); b.clear(); c.clear(); d.clear(); e.clear();} +quadruple_poly::quadruple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee) { + a = aa; + b = bb; + c = cc; + d = dd; + e = ee; +} + +quadruple_poly quadruple_poly::operator + (const quadruple_poly &x) const { + return quadruple_poly(a + x.a, b + x.b, c + x.c, d + x.d, e + x.e); +} + +F quadruple_poly::eval(const F &x) const { + return ((((a * x) + b) * x + c) * x + d) * x + e; +} + +void quadruple_poly::clear() { + a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); +} + +cubic_poly::cubic_poly() {a.clear(); b.clear(); c.clear(); d.clear();} +cubic_poly::cubic_poly(const F &aa, const F &bb, const F &cc, const F &dd) { + a = aa; + b = bb; + c = cc; + d = dd; +} + +cubic_poly cubic_poly::operator + (const cubic_poly &x) const { + return cubic_poly(a + x.a, b + x.b, c + x.c, d + x.d); +} + +F cubic_poly::eval(const F &x) const { + return (((a * x) + b) * x + c) * x + d; +} + +quadratic_poly::quadratic_poly() {a.clear(); b.clear(); c.clear();} +quadratic_poly::quadratic_poly(const F &aa, const F &bb, const F &cc) { + a = aa; + b = bb; + c = cc; +} + +quadratic_poly quadratic_poly::operator + (const quadratic_poly &x) const { + return quadratic_poly(a + x.a, b + x.b, c + x.c); +} + +quadratic_poly quadratic_poly::operator+(const linear_poly &x) const { + return quadratic_poly(a, b + x.a, c + x.b); +} + +cubic_poly quadratic_poly::operator * (const linear_poly &x) const { + return cubic_poly(a * x.a, a * x.b + b * x.a, b * x.b + c * x.a, c * x.b); +} + 
+cubic_poly cubic_poly::operator * (const F &x) const { + return cubic_poly(a * x, b * x, c * x, d * x); +} + +void cubic_poly::clear() { + a.clear(); b.clear(); c.clear(); d.clear(); +} + +quadratic_poly quadratic_poly::operator*(const F &x) const { + return quadratic_poly(a * x, b * x, c * x); +} + +F quadratic_poly::eval(const F &x) const { + return ((a * x) + b) * x + c; +} + +void quadratic_poly::clear() { + a.clear(); b.clear(); c.clear(); +} + +linear_poly::linear_poly() {a.clear(); b.clear();} +linear_poly::linear_poly(const F &aa, const F &bb) { + a = aa; + b = bb; +} +linear_poly::linear_poly(const F &x) { + a.clear(); + b = x; +} + +linear_poly linear_poly::operator + (const linear_poly &x) const { + return linear_poly(a + x.a, b + x.b); +} + +quadratic_poly linear_poly::operator * (const linear_poly &x) const { + return quadratic_poly(a * x.a, a * x.b + b * x.a, b * x.b); +} + +linear_poly linear_poly::operator*(const F &x) const { + return linear_poly(a * x, b * x); +} + +F linear_poly::eval(const F &x) const { + return a * x + b; +} + +void linear_poly::clear() { + a.clear(); b.clear(); +} diff --git a/src/polynomial.h b/src/polynomial.h new file mode 100644 index 0000000..cc147f2 --- /dev/null +++ b/src/polynomial.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include "global_var.hpp" + +class linear_poly; + +//ax^3 + bx^2 + cx + d +class cubic_poly { +public: + F a, b, c, d; + cubic_poly(); + cubic_poly(const F &, const F &, const F &, const F &); + cubic_poly operator + (const cubic_poly &) const; + cubic_poly operator * (const F &) const; + F eval(const F &) const; + void clear(); +}; + +//ax^2 + bx + c +class quadratic_poly { +public: + F a, b, c; + quadratic_poly(); + quadratic_poly(const F &, const F &, const F &); + quadratic_poly operator + (const quadratic_poly &) const; + quadratic_poly operator + (const linear_poly &) const; + cubic_poly operator * (const linear_poly &) const; + quadratic_poly operator * (const F &) const; + F 
eval(const F &) const; + void clear(); +}; + + +//ax + b +class linear_poly { +public: + F a, b; + linear_poly(); + linear_poly(const F &, const F &); + linear_poly(const F &); + linear_poly operator + (const linear_poly &) const; + quadratic_poly operator * (const linear_poly &) const; + linear_poly operator * (const F &) const; + F eval(const F &) const; + void clear(); +}; + + + +//ax^4 + bx^3 + cx^2 + dx + e +class quadruple_poly { +public: + F a, b, c, d, e; + quadruple_poly(); + quadruple_poly(const F &, const F &, const F &, const F &, const F &); + quadruple_poly operator + (const quadruple_poly &) const; + F eval(const F &) const; + void clear(); +}; + +//ax^5 + bx^4 + cx^3 + dx^2 + ex + f +class quintuple_poly { +public: + F a, b, c, d, e, f; + quintuple_poly(); + quintuple_poly(const F &, const F &, const F &, const F &, const F &, const F &); + quintuple_poly operator + (const quintuple_poly &) const; + F eval(const F &) const; + void clear(); +}; \ No newline at end of file diff --git a/src/prover.cpp b/src/prover.cpp new file mode 100644 index 0000000..0c99714 --- /dev/null +++ b/src/prover.cpp @@ -0,0 +1,511 @@ +// +// Created by 69029 on 3/9/2021. +// + +#include "prover.hpp" +#include +#include + +static vector beta_gs, beta_u; + +using std::unique_ptr; + +linear_poly interpolate(const F &zero_v, const F &one_v) { + return {one_v - zero_v, zero_v}; +} + +void prover::init() { + proof_size = 0; + r_u.resize(C.size + 1); + r_v.resize(C.size + 1); +} + +/** + * This is to initialize all process. + * + * @param the random point to be evaluated at the output layer + */ +void prover::sumcheckInitAll(const vector::const_iterator &r_0_from_v) { + sumcheck_id = C.size; + i8 last_bl = C.circuit[sumcheck_id - 1].bit_length; + r_u[sumcheck_id].resize(last_bl); + + prove_timer.start(); + for (int i = 0; i < last_bl; ++i) r_u[sumcheck_id][i] = r_0_from_v[i]; + prove_timer.stop(); +} + +/** + * This is to initialize before the process of a single layer. 
+ * + * @param the random combination coefficiants for multiple reduction points + */ +void prover::sumcheckInit(const F &alpha_0, const F &beta_0) { + prove_timer.start(); + auto &cur = C.circuit[sumcheck_id]; + alpha = alpha_0; + beta = beta_0; + r_0 = r_u[sumcheck_id].begin(); + r_1 = r_v[sumcheck_id].begin(); + --sumcheck_id; + prove_timer.stop(); +} + +/** + * This is to initialize before the phase 1 of a single inner production layer. + */ +void prover::sumcheckDotProdInitPhase1() { + fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id); + + auto &cur = C.circuit[sumcheck_id]; + i8 fft_bl = cur.fft_bit_length; + i8 cnt_bl = cur.bit_length - fft_bl; + total[0] = 1ULL << fft_bl; + total[1] = 1ULL << cur.bit_length_u[1]; + total_size[1] = cur.size_u[1]; + u32 fft_len = total[0]; + + r_u[sumcheck_id].resize(cur.max_bl_u); + V_mult[0].resize(total[1]); + V_mult[1].resize(total[1]); + mult_array[1].resize(total[0]); + beta_gs.resize(1ULL << fft_bl); + + prove_timer.start(); + + initBetaTable(beta_gs, fft_bl, r_0, F_ONE); + + for (u32 t = 0; t < fft_len; ++t) + mult_array[1][t] = beta_gs[t]; + for (u32 u = 0; u < total[1]; ++u) { + V_mult[0][u].clear(); + if (u >= cur.size_u[1]) V_mult[1][u].clear(); + else V_mult[1][u] = val[sumcheck_id - 1][u]; + } + + for (auto &gate: cur.bin_gates) + for (u32 t = 0; t < fft_len; ++t) { + u32 idx_u = gate.u << fft_bl | t; + u32 idx_v = gate.v << fft_bl | t; + V_mult[0][idx_u] = V_mult[0][idx_u] + beta_g[gate.g] * val[sumcheck_id - 1][idx_v]; + } + + round = 0; + prove_timer.stop(); +} + +/** + * This is the one-step reduction within a sumcheck process of a single inner production layer. 
+ * + * @param the random point of the reduction of the previous step + * @return the reducted cubic degree polynomial of the current variable from prover to verifier + */ +cubic_poly prover::sumcheckDotProdUpdate1(const F &previous_random) { + prove_timer.start(); + + if (round) r_u[sumcheck_id].at(round - 1) = previous_random; + ++round; + + auto &tmp_mult = mult_array[1]; + auto &tmp_v0 = V_mult[0], &tmp_v1 = V_mult[1]; + + if (total[0] == 1) + tmp_mult[0] = tmp_mult[0].eval(previous_random); + else for (u32 i = 0; i < (total[0] >> 1); ++i) { + u32 g0 = i << 1, g1 = i << 1 | 1; + tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random)); + } + total[0] >>= 1; + + cubic_poly ret; + for (u32 i = 0; i < (total[1] >> 1); ++i) { + u32 g0 = i << 1, g1 = i << 1 | 1; + if (g0 >= total_size[1]) { + tmp_v0[i].clear(); + tmp_v1[i].clear(); + continue; + } + if (g1 >= total_size[1]) { + tmp_v0[g1].clear(); + tmp_v1[g1].clear(); + } + tmp_v0[i] = interpolate(tmp_v0[g0].eval(previous_random), tmp_v0[g1].eval(previous_random)); + tmp_v1[i] = interpolate(tmp_v1[g0].eval(previous_random), tmp_v1[g1].eval(previous_random)); + if (total[0]) ret = ret + tmp_mult[i & total[0] - 1] * tmp_v1[i] * tmp_v0[i]; + else ret = ret + tmp_mult[0] * tmp_v1[i] * tmp_v0[i]; + } + proof_size += F_BYTE_SIZE * (3 + (!ret.a.isZero())); + + total[1] >>= 1; + total_size[1] = (total_size[1] + 1) >> 1; + + prove_timer.stop(); + return ret; +} + +void prover::sumcheckDotProdFinalize1(const F &previous_random, F &claim_1) { + prove_timer.start(); + r_u[sumcheck_id].at(round - 1) = previous_random; + claim_1 = V_mult[1][0].eval(previous_random); + V_u1 = V_mult[1][0].eval(previous_random) * mult_array[1][0].eval(previous_random); + prove_timer.stop(); + proof_size += F_BYTE_SIZE * 1; +} + +void prover::sumcheckInitPhase1(const F &relu_rou_0) { + fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id); + + auto &cur = C.circuit[sumcheck_id]; + total[0] = 
~cur.bit_length_u[0] ? 1ULL << cur.bit_length_u[0] : 0; + total_size[0] = cur.size_u[0]; + total[1] = ~cur.bit_length_u[1] ? 1ULL << cur.bit_length_u[1] : 0; + total_size[1] = cur.size_u[1]; + + r_u[sumcheck_id].resize(cur.max_bl_u); + V_mult[0].resize(total[0]); + V_mult[1].resize(total[1]); + mult_array[0].resize(total[0]); + mult_array[1].resize(total[1]); + beta_g.resize(1ULL << cur.bit_length); + if (cur.ty == layerType::PADDING) beta_gs.resize(1ULL << cur.fft_bit_length); + if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) + beta_gs.resize(total[1]); + + prove_timer.start(); + + relu_rou = relu_rou_0; + add_term.clear(); + for (int b = 0; b < 2; ++b) + for (u32 u = 0; u < total[b]; ++u) + mult_array[b][u].clear(); + + if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) { + i8 fft_bl = cur.fft_bit_length; + i8 fft_blh = cur.fft_bit_length - 1; + i8 cnt_bl = cur.ty == layerType::FFT ? cur.bit_length - fft_bl : cur.bit_length - fft_blh; + u32 cnt_len = cur.size >> (cur.ty == layerType::FFT ? fft_bl : fft_blh); + if (cur.ty == layerType::FFT) + initBetaTable(beta_g, cnt_bl, r_0 + fft_bl, r_1, alpha, beta); + else initBetaTable(beta_g, cnt_bl, r_0 + fft_blh, alpha); + for (u32 u = 0, l = sumcheck_id - 1; u < total[1]; ++u) { + V_mult[1][u].clear(); + if (u >= cur.size_u[1]) continue; + for (u32 g = 0; g < cnt_len; ++g) { + u32 idx = g << cur.max_bl_u | u; + V_mult[1][u] = V_mult[1][u] + val[l][idx] * beta_g[g]; + } + } + + beta_gs.resize(total[1]); + phiGInit(beta_gs, r_0, cur.scale, fft_bl, cur.ty == layerType::IFFT); + for (u32 u = 0; u < total[1] ; ++u) { + mult_array[1][u] = beta_gs[u]; + } + } else { + for (int b = 0; b < 2; ++b) { + auto dep = !b ? 
0 : sumcheck_id - 1; + for (u32 u = 0; u < total[b]; ++u) { + if (u >= cur.size_u[b]) + V_mult[b][u].clear(); + else V_mult[b][u] = getCirValue(dep, cur.ori_id_u, u); + } + } + + if (cur.ty == layerType::PADDING) { + i8 fft_blh = cur.fft_bit_length - 1; + u32 fft_lenh = 1ULL << fft_blh; + initBetaTable(beta_gs, fft_blh, r_0, F_ONE); + for (long g = (1L << cur.bit_length) - 1; g >= 0; --g) + beta_g[g] = beta_g[g >> fft_blh] * beta_gs[g & fft_lenh - 1]; + } else initBetaTable(beta_g, cur.bit_length, r_0, r_1, alpha * cur.scale, beta * cur.scale); + if (cur.zero_start_id < cur.size) + for (u32 g = cur.zero_start_id; g < 1ULL << cur.bit_length; ++g) beta_g[g] = beta_g[g] * relu_rou; + + for (auto &gate: cur.uni_gates) { + bool idx = gate.lu != 0; + mult_array[idx][gate.u] = mult_array[idx][gate.u] + beta_g[gate.g] * C.two_mul[gate.sc]; + } + + for (auto &gate: cur.bin_gates) { + bool idx = gate.getLayerIdU(sumcheck_id) != 0; + auto val_lv = getCirValue(gate.getLayerIdV(sumcheck_id), cur.ori_id_v, gate.v); + mult_array[idx][gate.u] = mult_array[idx][gate.u] + val_lv * beta_g[gate.g] * C.two_mul[gate.sc]; + } + } + + round = 0; + prove_timer.stop(); + fprintf(stderr, "sumcheck level %d, phase1 init finished\n", sumcheck_id); +} + +void prover::sumcheckInitPhase2() { + fprintf(stderr, "sumcheck level %d, phase2 init start\n", sumcheck_id); + + auto &cur = C.circuit[sumcheck_id]; + total[0] = ~cur.bit_length_v[0] ? 1ULL << cur.bit_length_v[0] : 0; + total_size[0] = cur.size_v[0]; + total[1] = ~cur.bit_length_v[1] ? 
1ULL << cur.bit_length_v[1] : 0; + total_size[1] = cur.size_v[1]; + i8 fft_bl = cur.fft_bit_length; + i8 cnt_bl = cur.max_bl_v; + + r_v[sumcheck_id].resize(cur.max_bl_v); + + V_mult[0].resize(total[0]); + V_mult[1].resize(total[1]); + mult_array[0].resize(total[0]); + mult_array[1].resize(total[1]); + + if (cur.ty == layerType::DOT_PROD) { + beta_u.resize(1ULL << cnt_bl); + beta_gs.resize(1ULL << fft_bl); + } else beta_u.resize(1ULL << cur.max_bl_u); + + prove_timer.start(); + + add_term.clear(); + for (int b = 0; b < 2; ++b) { + for (u32 v = 0; v < total[b]; ++v) + mult_array[b][v].clear(); + } + + if (cur.ty == layerType::DOT_PROD) { + u32 fft_len = 1ULL << cur.fft_bit_length; + initBetaTable(beta_u, cnt_bl, r_u[sumcheck_id].begin() + fft_bl, F_ONE); + initBetaTable(beta_gs, fft_bl, r_u[sumcheck_id].begin(), F_ONE); + + for (u32 v = 0; v < total[1]; ++v) { + V_mult[1][v].clear(); + if (v >= cur.size_v[1]) continue; + for (u32 t = 0; t < fft_len; ++t) { + u32 idx_v = (v << fft_bl) | t; + V_mult[1][v] = V_mult[1][v] + val[sumcheck_id - 1][idx_v] * beta_gs[t]; + } + } + + for (auto &gate: cur.bin_gates) + mult_array[1][gate.v] = + mult_array[1][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u1; + } else { + initBetaTable(beta_u, cur.max_bl_u, r_u[sumcheck_id].begin(), F_ONE); + for (int b = 0; b < 2; ++b) { + auto dep = !b ? 0 : sumcheck_id - 1; + for (u32 v = 0; v < total[b]; ++v) { + V_mult[b][v] = v >= cur.size_v[b] ? F_ZERO : getCirValue(dep, cur.ori_id_v, v); + } + } + for (auto &gate: cur.uni_gates) { + auto V_u = !gate.lu ? V_u0 : V_u1; + add_term = add_term + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc]; + } + for (auto &gate: cur.bin_gates) { + bool idx = gate.getLayerIdV(sumcheck_id); + auto V_u = !gate.getLayerIdU(sumcheck_id) ? 
V_u0 : V_u1; + mult_array[idx][gate.v] = mult_array[idx][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc]; + } + } + + round = 0; + prove_timer.stop(); +} + +void prover::sumcheckLiuInit(const vector &s_u, const vector &s_v) { + sumcheck_id = 0; + total[1] = (1ULL << C.circuit[sumcheck_id].bit_length); + total_size[1] = C.circuit[sumcheck_id].size; + + r_u[0].resize(C.circuit[0].bit_length); + mult_array[1].resize(total[1]); + V_mult[1].resize(total[1]); + + i8 max_bl = 0; + for (int i = sumcheck_id + 1; i < C.size; ++i) + max_bl = max(max_bl, max(C.circuit[i].bit_length_u[0], C.circuit[i].bit_length_v[0])); + beta_g.resize(1ULL << max_bl); + + prove_timer.start(); + add_term.clear(); + + for (u32 g = 0; g < total[1]; ++g) { + mult_array[1][g].clear(); + V_mult[1][g] = (g < total_size[1]) ? val[0][g] : F_ZERO; + } + + for (u8 i = sumcheck_id + 1; i < C.size; ++i) { + i8 bit_length_i = C.circuit[i].bit_length_u[0]; + u32 size_i = C.circuit[i].size_u[0]; + if (~bit_length_i) { + initBetaTable(beta_g, bit_length_i, r_u[i].begin(), s_u[i - 1]); + for (u32 hu = 0; hu < size_i; ++hu) { + u32 u = C.circuit[i].ori_id_u[hu]; + mult_array[1][u] = mult_array[1][u] + beta_g[hu]; + } + } + + bit_length_i = C.circuit[i].bit_length_v[0]; + size_i = C.circuit[i].size_v[0]; + if (~bit_length_i) { + initBetaTable(beta_g, bit_length_i, r_v[i].begin(), s_v[i - 1]); + for (u32 hv = 0; hv < size_i; ++hv) { + u32 v = C.circuit[i].ori_id_v[hv]; + mult_array[1][v] = mult_array[1][v] + beta_g[hv]; + } + } + } + + round = 0; + prove_timer.stop(); +} + +quadratic_poly prover::sumcheckUpdate1(const F &previous_random) { + return sumcheckUpdate(previous_random, r_u[sumcheck_id]); +} + +quadratic_poly prover::sumcheckUpdate2(const F &previous_random) { + return sumcheckUpdate(previous_random, r_v[sumcheck_id]); +} + +quadratic_poly prover::sumcheckUpdate(const F &previous_random, vector &r_arr) { + prove_timer.start(); + + if (round) r_arr.at(round - 1) = previous_random; + 
++round; + quadratic_poly ret; + + add_term = add_term * (F_ONE - previous_random); + for (int b = 0; b < 2; ++b) + ret = ret + sumcheckUpdateEach(previous_random, b); + ret = ret + quadratic_poly(F_ZERO, -add_term, add_term); + + prove_timer.stop(); + proof_size += F_BYTE_SIZE * 3; + return ret; +} + +quadratic_poly prover::sumcheckLiuUpdate(const F &previous_random) { + prove_timer.start(); + ++round; + + auto ret = sumcheckUpdateEach(previous_random, true); + + prove_timer.stop(); + proof_size += F_BYTE_SIZE * 3; + return ret; +} + +quadratic_poly prover::sumcheckUpdateEach(const F &previous_random, bool idx) { + auto &tmp_mult = mult_array[idx]; + auto &tmp_v = V_mult[idx]; + + if (total[idx] == 1) { + tmp_v[0] = tmp_v[0].eval(previous_random); + tmp_mult[0] = tmp_mult[0].eval(previous_random); + add_term = add_term + tmp_v[0].b * tmp_mult[0].b; + } + + quadratic_poly ret; + for (u32 i = 0; i < (total[idx] >> 1); ++i) { + u32 g0 = i << 1, g1 = i << 1 | 1; + if (g0 >= total_size[idx]) { + tmp_v[i].clear(); + tmp_mult[i].clear(); + continue; + } + if (g1 >= total_size[idx]) { + tmp_v[g1].clear(); + tmp_mult[g1].clear(); + } + tmp_v[i] = interpolate(tmp_v[g0].eval(previous_random), tmp_v[g1].eval(previous_random)); + tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random)); + ret = ret + tmp_mult[i] * tmp_v[i]; + } + total[idx] >>= 1; + total_size[idx] = (total_size[idx] + 1) >> 1; + + return ret; +} + +/** + * This is to evaluate a multi-linear extension at a random point. + * + * @param the value of the array & random point & the size of the array & the size of the random point + * @return sum of `values`, or 0.0 if `values` is empty. 
+ */ +F prover::Vres(const vector::const_iterator &r, u32 output_size, u8 r_size) { + prove_timer.start(); + + vector output(output_size); + for (u32 i = 0; i < output_size; ++i) + output[i] = val[C.size - 1][i]; + u32 whole = 1ULL << r_size; + for (u8 i = 0; i < r_size; ++i) { + for (u32 j = 0; j < (whole >> 1); ++j) { + if (j > 0) + output[j].clear(); + if ((j << 1) < output_size) + output[j] = output[j << 1] * (F_ONE - r[i]); + if ((j << 1 | 1) < output_size) + output[j] = output[j] + output[j << 1 | 1] * (r[i]); + } + whole >>= 1; + } + F res = output[0]; + + prove_timer.stop(); + proof_size += F_BYTE_SIZE; + return res; +} + +void prover::sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1) { + prove_timer.start(); + r_u[sumcheck_id].at(round - 1) = previous_random; + V_u0 = claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[0]) ? V_mult[0][0].b : F_ZERO; + V_u1 = claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[1]) ? V_mult[1][0].b : F_ZERO; + prove_timer.stop(); + + mult_array[0].clear(); + mult_array[1].clear(); + V_mult[0].clear(); + V_mult[1].clear(); + proof_size += F_BYTE_SIZE * 2; +} + +void prover::sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1) { + prove_timer.start(); + r_v[sumcheck_id].at(round - 1) = previous_random; + claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[0]) ? V_mult[0][0].b : F_ZERO; + claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[1]) ? V_mult[1][0].b : F_ZERO; + prove_timer.stop(); + + mult_array[0].clear(); + mult_array[1].clear(); + V_mult[0].clear(); + V_mult[1].clear(); + proof_size += F_BYTE_SIZE * 2; +} + +void prover::sumcheckLiuFinalize(const F &previous_random, F &claim_1) { + prove_timer.start(); + r_u[sumcheck_id].at(round - 1) = previous_random; + claim_1 = total[1] ? 
V_mult[1][0].eval(previous_random) : V_mult[1][0].b; + prove_timer.stop(); + proof_size += F_BYTE_SIZE; + + mult_array[1].clear(); + V_mult[1].clear(); + beta_g.clear(); +} + +F prover::getCirValue(u8 layer_id, const vector &ori, u32 u) { + return !layer_id ? val[0][ori[u]] : val[layer_id][u]; +} + +hyrax_bls12_381::polyProver &prover::commitInput(const vector &gens) { + if (C.circuit[0].size != (1ULL << C.circuit[0].bit_length)) { + val[0].resize(1ULL << C.circuit[0].bit_length); + for (int i = C.circuit[0].size; i < val[0].size(); ++i) + val[0][i].clear(); + } + poly_p = std::make_unique(val[0], gens); + return *poly_p; +} \ No newline at end of file diff --git a/src/prover.hpp b/src/prover.hpp new file mode 100644 index 0000000..505e18c --- /dev/null +++ b/src/prover.hpp @@ -0,0 +1,81 @@ +// +// Created by 69029 on 3/9/2021. +// + +#ifndef ZKCNN_PROVER_HPP +#define ZKCNN_PROVER_HPP + +#include "global_var.hpp" +#include "circuit.h" +#include "polynomial.h" + +using std::unique_ptr; + +class neuralNetwork; +class singleConv; +class prover { +public: + void init(); + + void sumcheckInitAll(const vector::const_iterator &r_0_from_v); + void sumcheckInit(const F &alpha_0, const F &beta_0); + void sumcheckDotProdInitPhase1(); + void sumcheckInitPhase1(const F &relu_rou_0); + void sumcheckInitPhase2(); + + cubic_poly sumcheckDotProdUpdate1(const F &previous_random); + quadratic_poly sumcheckUpdate1(const F &previous_random); + quadratic_poly sumcheckUpdate2(const F &previous_random); + + F Vres(const vector::const_iterator &r, u32 output_size, u8 r_size); + + void sumcheckDotProdFinalize1(const F &previous_random, F &claim_1); + void sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1); + void sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1); + void sumcheckLiuFinalize(const F &previous_random, F &claim_1); + + void sumcheckLiuInit(const vector &s_u, const vector &s_v); + quadratic_poly sumcheckLiuUpdate(const F &previous_random); + + 
hyrax_bls12_381::polyProver &commitInput(const vector &gens); + + timer prove_timer; + double proveTime() const { return prove_timer.elapse_sec(); } + double proofSize() const { return (double) proof_size / 1024.0; } + double polyProverTime() const { return poly_p -> getPT(); } + double polyProofSize() const { return poly_p -> getPS(); } + + layeredCircuit C; + vector> val; // the output of each gate +private: + quadratic_poly sumcheckUpdateEach(const F &previous_random, bool idx); + quadratic_poly sumcheckUpdate(const F &previous_random, vector &r_arr); + F getCirValue(u8 layer_id, const vector &ori, u32 u); + + vector::iterator r_0, r_1; // current positions + vector> r_u, r_v; // next positions + + vector beta_g; + + F add_term; + vector mult_array[2]; + vector V_mult[2]; + + F V_u0, V_u1; + + F alpha, beta, relu_rou; + + u64 proof_size; + + u32 total[2], total_size[2]; + u8 round; // step within a sumcheck + u8 sumcheck_id; // the level + + unique_ptr poly_p; + + friend neuralNetwork; + friend singleConv; +}; + + +#endif //ZKCNN_PROVER_HPP diff --git a/src/utils.cpp b/src/utils.cpp new file mode 100644 index 0000000..e892eb3 --- /dev/null +++ b/src/utils.cpp @@ -0,0 +1,234 @@ +// +// Created by 69029 on 3/9/2021. +// + +#include +#include +#include +#include "utils.hpp" + +using std::cerr; +using std::endl; +using std::string; +using std::cin; + +int ceilPow2BitLengthSigned(double n) { + return (i8) ceil(log2(n)); +} + +int floorPow2BitLengthSigned(double n) { + return (i8) floor(log2(n)); +} + +i8 ceilPow2BitLength(u32 n) { + return n < 1e-9 ? 
-1 : (i8) ceil(log(n) / log(2.)); +} + +i8 floorPow2BitLength(u32 n) { +// cerr << n << ' ' << log(n) / log(2.)< &beta_f, vector &beta_s, const vector::const_iterator &r, const F &init, u32 first_half, u32 second_half) { + beta_f.at(0) = init; + beta_s.at(0) = F_ONE; + + for (u32 i = 0; i < first_half; ++i) { + for (u32 j = 0; j < (1ULL << i); ++j) { + auto tmp = beta_f.at(j) * r[i]; + beta_f.at(j | (1ULL << i)) = tmp; + beta_f.at(j) = beta_f[j] - tmp; + } + } + + for (u32 i = 0; i < second_half; ++i) { + for (u32 j = 0; j < (1ULL << i); ++j) { + auto tmp = beta_s[j] * r[(i + first_half)]; + beta_s[j | (1ULL << i)] = tmp; + beta_s[j] = beta_s[j] - tmp; + } + } +} + +void phiPowInit(vector &phi_mul, int n, bool isIFFT) { + u32 N = 1ULL << n; + F phi = getRootOfUnit(n); + if (isIFFT) F::inv(phi, phi); + phi_mul[0] = F_ONE; + for (u32 i = 1; i < N; ++i) phi_mul[i] = phi_mul[i - 1] * phi; +} + +void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT) { + vector phi_mul(1 << n); + phiPowInit(phi_mul, n, isIFFT); + + if (isIFFT) { +// cerr << "==" << endl; +// cerr << "gLength: " << n << endl; +// for (int i = 0; i < n - 1; ++i) { +// cerr << rx[i]; +// cerr << endl; +// } + phi_g[0] = phi_g[1] = scale; + for (int i = 2; i <= n; ++i) + for (u32 b = 0; b < (1ULL << (i - 1)); ++b) { + u32 l = b, r = b ^ (1ULL << (i - 1)); + int m = n - i; + F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m]; + phi_g[r] = phi_g[l] * (tmp1 - tmp2); + phi_g[l] = phi_g[l] * (tmp1 + tmp2); + } + } else { +// cerr << "==" << endl; +// cerr << "gLength: " << n << endl; +// for (int i = 0; i < n; ++i) { +// cerr << rx[i]; +// cerr << endl; +// } + phi_g[0] = scale; + for (int i = 1; i < n; ++i) + for (u32 b = 0; b < (1ULL << (i - 1)); ++b) { + u32 l = b, r = b ^ (1ULL << (i - 1)); + int m = n - i; + F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m]; + phi_g[r] = phi_g[l] * (tmp1 - tmp2); + phi_g[l] = phi_g[l] * (tmp1 + tmp2); + } + for (u32 b = 0; 
b < (1ULL << (n - 1)); ++b) { + u32 l = b; + F tmp1 = F_ONE - rx[0], tmp2 = rx[0] * phi_mul[b]; + phi_g[l] = phi_g[l] * (tmp1 + tmp2); + } + } +} + +void fft(vector &arr, int logn, bool flag) { +// cerr << "fft: " << endl; +// for (auto x: arr) cerr << x << ' '; +// cerr << endl; + static std::vector rev; + static std::vector w; + + u32 len = 1ULL << logn; + assert(arr.size() == len); + + rev.resize(len); + w.resize(len); + + rev[0] = 0; + for (u32 i = 1; i < len; ++i) + rev[i] = rev[i >> 1] >> 1 | (i & 1) << (logn - 1); + + w[0] = F_ONE; + w[1] = getRootOfUnit(logn); + if (flag) F::inv(w[1], w[1]); + for (u32 i = 2; i < len; ++i) w[i] = w[i - 1] * w[1]; + + for (u32 i = 0; i < len; ++i) + if (rev[i] < i) std::swap(arr[i], arr[rev[i]]); + + for (u32 i = 2; i <= len; i <<= 1) + for (u32 j = 0; j < len; j += i) + for (u32 k = 0; k < (i >> 1); ++k) { + auto u = arr[j + k]; + auto v = arr[j + k + (i >> 1)] * w[len / i * k]; + arr[j + k] = u + v; + arr[j + k + (i >> 1)] = u - v; + } + + if (flag) { + F ilen; + F::inv(ilen, len); + for (u32 i = 0; i < len; ++i) + arr[i] = arr[i] * ilen; + } +} + +void +initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1, + const F &alpha, const F &beta) { + u8 first_half = gLength >> 1, second_half = gLength - first_half; + u32 mask_fhalf = (1ULL << first_half) - 1; + + vector beta_f(1ULL << first_half), beta_s(1ULL << second_half); + if (!beta.isZero()) { + initHalfTable(beta_f, beta_s, r_1, beta, first_half, second_half); + for (u32 i = 0; i < (1ULL << gLength); ++i) + beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half]; + } else for (u32 i = 0; i < (1ULL << gLength); ++i) + beta_g[i].clear(); + + if (alpha.isZero()) return; + initHalfTable(beta_f, beta_s, r_0, alpha, first_half, second_half); + for (u32 i = 0; i < (1ULL << gLength); ++i) + beta_g[i] = beta_g[i] + beta_f[i & mask_fhalf] * beta_s[i >> first_half]; +} + + +void initBetaTable(vector &beta_g, u8 gLength, 
const vector::const_iterator &r, const F &init) { + if (gLength == -1) return; + int first_half = gLength >> 1, second_half = gLength - first_half; + u32 mask_fhalf = (1ULL << first_half) - 1; + vector beta_f(1ULL << first_half), beta_s(1ULL << second_half); + + if (!init.isZero()) { + initHalfTable(beta_f, beta_s, r, init, first_half, second_half); + for (u32 i = 0; i < (1ULL << gLength); ++i) + beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half]; + } else for (u32 i = 0; i < (1ULL << gLength); ++i) + beta_g[i].clear(); +} + +bool check(long x, long y, long nx, long ny) { + return 0 <= x && x < nx && 0 <= y && y < ny; +} +// +//F getData(u8 scale_bl) { +// double x; +// in >> x; +// long y = round(x * (1L << scale_bl)); +// return F(y); +//} + +void initLayer(layer &circuit, long size, layerType ty) { + circuit.size = circuit.zero_start_id = size; + circuit.bit_length = ceilPow2BitLength(size); + circuit.ty = ty; +} + +long sqr(long x) { + return x * x; +} + +double byte2KB(size_t x) { return x / 1024.0; } + +double byte2MB(size_t x) { return x / 1024.0 / 1024.0; } + +double byte2GB(size_t x) { return x / 1024.0 / 1024.0 / 1024.0; } + +long matIdx(long x, long y, long n) { + assert(y < n); + return x * n + y; +} + +long cubIdx(long x, long y, long z, long n, long m) { + assert(y < n && z < m); + return matIdx(matIdx(x, y, n), z, m); +} + +long tesIdx(long w, long x, long y, long z, long n, long m, long l) { + assert(x < n && y < m && z < l); + return matIdx(cubIdx(w, x, y, n, m), z, l); +} + +F getRootOfUnit(int n) { + F res = -F_ONE; + if (!n) return F_ONE; + while (--n) { + bool b = F::squareRoot(res, res); + assert(b); + } + return res; +} + + diff --git a/src/utils.hpp b/src/utils.hpp new file mode 100644 index 0000000..d5b5043 --- /dev/null +++ b/src/utils.hpp @@ -0,0 +1,49 @@ +// +// Created by 69029 on 3/9/2021. 
+// + +#ifndef ZKCNN_UTILS_HPP +#define ZKCNN_UTILS_HPP + +#include + +int ceilPow2BitLengthSigned(double n); +int floorPow2BitLengthSigned(double n); + +char ceilPow2BitLength(u32 n); +char floorPow2BitLength(u32 n); + + +void fft(vector &arr, int logn, bool flag); + +void +initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1, + const F &alpha, const F &beta); + +void initPhiTable(F *phi_g, const layer &cur_layer, const F *r_0, const F *r_1, F alpha, F beta); + +void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT); + +void initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r, const F &init); + +bool check(long x, long y, long nx, long ny); + +long matIdx(long x, long y, long n); + +long cubIdx(long x, long y, long z, long n, long m); + +long tesIdx(long w, long x, long y, long z, long n, long m, long l); + +void initLayer(layer &circuit, long size, layerType ty); + +long sqr(long x); + +double byte2KB(size_t x); + +double byte2MB(size_t x); + +double byte2GB(size_t x); + +F getRootOfUnit(int n); + +#endif //ZKCNN_UTILS_HPP diff --git a/src/verifier.cpp b/src/verifier.cpp new file mode 100644 index 0000000..4602acf --- /dev/null +++ b/src/verifier.cpp @@ -0,0 +1,373 @@ +// +// Created by 69029 on 3/9/2021. 
+// + +#include "verifier.hpp" +#include "global_var.hpp" +#include +#include +#include + +vector beta_v; +static vector beta_u, beta_gs; + +verifier::verifier(prover *pr, const layeredCircuit &cir): + p(pr), C(cir) { + final_claim_u0.resize(C.size + 2); + final_claim_v0.resize(C.size + 2); + + r_u.resize(C.size + 2); + r_v.resize(C.size + 2); + // make the prover ready + p->init(); +} + +F verifier::getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1) { + + auto test_value = bin_value[0] * (claim_u0 * claim_v0) + + bin_value[1] * (claim_u1 * claim_v1) + + bin_value[2] * (claim_u1 * claim_v0) + + uni_value[0] * claim_u0 + + uni_value[1] * claim_u1; + + return test_value; +} + +void verifier::betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou) { + i8 bl = C.circuit[depth].bit_length; + i8 fft_bl = C.circuit[depth].fft_bit_length; + i8 fft_blh = C.circuit[depth].fft_bit_length - 1; + i8 cnt_bl = bl - fft_bl, cnt_bl2 = C.circuit[depth].max_bl_u - fft_bl; + + switch (C.circuit[depth].ty) { + case layerType::FFT: + case layerType::IFFT: + beta_gs.resize(1ULL << fft_bl); + phiGInit(beta_gs, r_0, C.circuit[depth].scale, fft_bl, C.circuit[depth].ty == layerType::IFFT); + beta_u.resize(1ULL << C.circuit[depth].max_bl_u); + initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); + break; + case layerType::PADDING: + beta_g.resize(1ULL << bl); + beta_gs.resize(1ULL << fft_blh); + initBetaTable(beta_g, bl - fft_blh, r_u[depth + 2].begin() + fft_bl, r_v[depth + 2].begin(), alpha, beta); + initBetaTable(beta_gs, fft_blh, r_0, F_ONE); + for (u32 g = (1ULL << bl) - 1; g < (1ULL << bl); --g) + beta_g[g] = beta_g[g >> fft_blh] * + beta_gs[g & (1ULL << fft_blh) - 1]; + beta_u.resize(1ULL << C.circuit[depth].max_bl_u); + initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); + break; + case layerType::DOT_PROD: + 
beta_g.resize(1ULL << cnt_bl); + initBetaTable(beta_g, cnt_bl, r_u[depth + 2].begin() + fft_bl - 1, alpha); + + beta_u.resize(1ULL << cnt_bl2); + initBetaTable(beta_u, cnt_bl2, r_u[depth].begin() + fft_bl, F_ONE); + for (u32 i = 0; i < 1ULL << cnt_bl2; ++i) + for (u32 j = 0; j < fft_bl; ++j) + beta_u[i] = beta_u[i] * ((r_0[j] * r_u[depth][j]) + (F_ONE - r_0[j]) * (F_ONE - r_u[depth][j])); + break; + + default: + beta_g.resize(1ULL << bl); + initBetaTable(beta_g, C.circuit[depth].bit_length, r_0, r_1, alpha * C.circuit[depth].scale, + beta * C.circuit[depth].scale); + if (C.circuit[depth].zero_start_id < C.circuit[depth].size) + for (u32 g = C.circuit[depth].zero_start_id; g < 1ULL << C.circuit[depth].bit_length; ++g) + beta_g[g] = beta_g[g] * relu_rou; + beta_u.resize(1ULL << C.circuit[depth].max_bl_u); + initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); + } +} + +void verifier::betaInitPhase2(u8 depth) { + beta_v.resize(1ULL << C.circuit[depth].max_bl_v); + initBetaTable(beta_v, C.circuit[depth].max_bl_v, r_v[depth].begin(), F_ONE); +} + +void verifier::predicatePhase1(u8 layer_id) { + auto &cur_layer = C.circuit[layer_id]; + + uni_value[0].clear(); + uni_value[1].clear(); + if (cur_layer.ty == layerType::FFT || cur_layer.ty == layerType::IFFT) + for (u32 u = 0; u < 1ULL << cur_layer.max_bl_u; ++u) + uni_value[1] = uni_value[1] + beta_gs[u] * beta_u[u]; + else for (auto &gate: cur_layer.uni_gates) { + bool idx = gate.lu; + uni_value[idx] = uni_value[idx] + beta_g[gate.g] * beta_u[gate.u] * C.two_mul[gate.sc]; + } + bin_value[0] = bin_value[1] = bin_value[2] = F_ZERO; +} + +void verifier::predicatePhase2(u8 layer_id) { + uni_value[0] = uni_value[0] * beta_v[0]; + uni_value[1] = uni_value[1] * beta_v[0]; + + auto &cur_layer = C.circuit[layer_id]; + if (C.circuit[layer_id].ty == layerType::DOT_PROD) { + for (auto &gate: cur_layer.bin_gates) + bin_value[gate.l] = + bin_value[gate.l] + + beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v]; + } 
else for (auto &gate: cur_layer.bin_gates) + bin_value[gate.l] = bin_value[gate.l] + beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v] * C.two_mul[gate.sc]; +} + +bool verifier::verify() { + u8 logn = C.circuit[0].bit_length; + u64 n_sqrt = 1ULL << (logn - (logn >> 1)); + vector gens(n_sqrt); + for (auto &x: gens) { + Fr tmp; + tmp.setByCSPRNG(); + x = mcl::bn::getG1basePoint() * tmp; + } + + poly_v = std::make_unique(p -> commitInput(gens), gens); + return verifyInnerLayers() && verifyFirstLayer() && verifyInput(); +} + +bool verifier::verifyInnerLayers() { + total_timer.start(); + total_slow_timer.start(); + + F alpha = F_ONE, beta = F_ZERO, relu_rou, final_claim_u1, final_claim_v1; + r_u[C.size].resize(C.circuit[C.size - 1].bit_length); + for (i8 i = 0; i < C.circuit[C.size - 1].bit_length; ++i) + r_u[C.size][i].setByCSPRNG(); + vector::const_iterator r_0 = r_u[C.size].begin(); + vector::const_iterator r_1; + + total_timer.stop(); + total_slow_timer.stop(); + + auto previousSum = p->Vres(r_0, C.circuit[C.size - 1].size, C.circuit[C.size - 1].bit_length); + p -> sumcheckInitAll(r_0); + + for (u8 i = C.size - 1; i; --i) { + auto &cur = C.circuit[i]; + p->sumcheckInit(alpha, beta); + total_timer.start(); + total_slow_timer.start(); + + // phase 1 + r_u[i].resize(cur.max_bl_u); + for (int j = 0; j < cur.max_bl_u; ++j) r_u[i][j].setByCSPRNG(); + if (cur.zero_start_id < cur.size) + relu_rou.setByCSPRNG(); + else relu_rou = F_ONE; + + total_timer.stop(); + total_slow_timer.stop(); + if (cur.ty == layerType::DOT_PROD) + p->sumcheckDotProdInitPhase1(); + else p->sumcheckInitPhase1(relu_rou); + + F previousRandom = F_ZERO; + for (i8 j = 0; j < cur.max_bl_u; ++j) { + F cur_claim, nxt_claim; + if (cur.ty == layerType::DOT_PROD) { + cubic_poly poly = p->sumcheckDotProdUpdate1(previousRandom); + total_timer.start(); + total_slow_timer.start(); + cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE); + nxt_claim = poly.eval(r_u[i][j]); + } else { + quadratic_poly poly = 
p->sumcheckUpdate1(previousRandom); + total_timer.start(); + total_slow_timer.start(); + cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE); + nxt_claim = poly.eval(r_u[i][j]); + } + + if (cur_claim != previousSum) { + cerr << cur_claim << ' ' << previousSum << endl; + fprintf(stderr, "Verification fail, phase1, circuit %d, current bit %d\n", i, j); + return false; + } + previousRandom = r_u[i][j]; + previousSum = nxt_claim; + total_timer.stop(); + total_slow_timer.stop(); + } + + if (cur.ty == layerType::DOT_PROD) + p->sumcheckDotProdFinalize1(previousRandom, final_claim_u1); + else p->sumcheckFinalize1(previousRandom, final_claim_u0[i], final_claim_u1); + + total_slow_timer.start(); + betaInitPhase1(i, alpha, beta, r_0, r_1, relu_rou); + predicatePhase1(i); + + total_timer.start(); + if (cur.need_phase2) { + r_v[i].resize(cur.max_bl_v); + for (int j = 0; j < cur.max_bl_v; ++j) r_v[i][j].setByCSPRNG(); + + total_timer.stop(); + total_slow_timer.stop(); + + p->sumcheckInitPhase2(); + previousRandom = F_ZERO; + for (u32 j = 0; j < cur.max_bl_v; ++j) { + quadratic_poly poly = p->sumcheckUpdate2(previousRandom); + + total_timer.start(); + total_slow_timer.start(); + if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) { + fprintf(stderr, "Verification fail, phase2, circuit level %d, current bit %d, total is %d\n", i, j, + cur.max_bl_v); + return false; + } + + previousRandom = r_v[i][j]; + previousSum = poly.eval(previousRandom); + total_timer.stop(); + total_slow_timer.stop(); + } + p->sumcheckFinalize2(previousRandom, final_claim_v0[i], final_claim_v1); + + total_slow_timer.start(); + betaInitPhase2(i); + predicatePhase2(i); + total_timer.start(); + } + F test_value = getFinalValue(final_claim_u0[i], final_claim_u1, final_claim_v0[i], final_claim_v1); + + if (previousSum != test_value) { + std::cerr << test_value << ' ' << previousSum << std::endl; + fprintf(stderr, "Verification fail, semi final, circuit level %d\n", i); + return false; + } else fprintf(stderr, 
"Verification Pass, semi final, circuit level %d\n", i); + + if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) + previousSum = final_claim_u1; + else { + if (~cur.bit_length_u[1]) + alpha.setByCSPRNG(); + else alpha.clear(); + if ((~cur.bit_length_v[1]) || cur.ty == layerType::FFT) + beta.setByCSPRNG(); + else beta.clear(); + previousSum = alpha * final_claim_u1 + beta * final_claim_v1; + } + + r_0 = r_u[i].begin(); + r_1 = r_v[i].begin(); + + total_timer.stop(); + total_slow_timer.stop(); + beta_u.clear(); + beta_v.clear(); + } + return true; +} + +bool verifier::verifyFirstLayer() { + total_slow_timer.start(); + total_timer.start(); + + auto &cur = C.circuit[0]; + + vector sig_u(C.size - 1); + for (int i = 0; i < C.size - 1; ++i) sig_u[i].setByCSPRNG(); + vector sig_v(C.size - 1); + for (int i = 0; i < C.size - 1; ++i) sig_v[i].setByCSPRNG(); + r_u[0].resize(cur.bit_length); + for (int i = 0; i < cur.bit_length; ++i) r_u[0][i].setByCSPRNG(); + auto r_0 = r_u[0].begin(); + + F previousSum = F_ZERO; + for (int i = 1; i < C.size; ++i) { + if (~C.circuit[i].bit_length_u[0]) + previousSum = previousSum + sig_u[i - 1] * final_claim_u0[i]; + if (~C.circuit[i].bit_length_v[0]) + previousSum = previousSum + sig_v[i - 1] * final_claim_v0[i]; + } + total_timer.stop(); + total_slow_timer.stop(); + + p->sumcheckLiuInit(sig_u, sig_v); + F previousRandom = F_ZERO; + for (int j = 0; j < cur.bit_length; ++j) { + auto poly = p -> sumcheckLiuUpdate(previousRandom); + if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) { + fprintf(stderr, "Liu fail, circuit 0, current bit %d\n", j); + return false; + } + previousRandom = r_0[j]; + previousSum = poly.eval(previousRandom); + } + + F gr = F_ZERO; + p->sumcheckLiuFinalize(previousRandom, eval_in); + + beta_g.resize(1ULL << cur.bit_length); + + total_slow_timer.start(); + initBetaTable(beta_g, cur.bit_length, r_0, F_ONE); + for (int i = 1; i < C.size; ++i) { + if (~C.circuit[i].bit_length_u[0]) { + beta_u.resize(1ULL << 
C.circuit[i].bit_length_u[0]); + initBetaTable(beta_u, C.circuit[i].bit_length_u[0], r_u[i].begin(), sig_u[i - 1]); + for (u32 j = 0; j < C.circuit[i].size_u[0]; ++j) + gr = gr + beta_g[C.circuit[i].ori_id_u[j]] * beta_u[j]; + } + + if (~C.circuit[i].bit_length_v[0]) { + beta_v.resize(1ULL << C.circuit[i].bit_length_v[0]); + initBetaTable(beta_v, C.circuit[i].bit_length_v[0], r_v[i].begin(), sig_v[i - 1]); + for (u32 j = 0; j < C.circuit[i].size_v[0]; ++j) + gr = gr + beta_g[C.circuit[i].ori_id_v[j]] * beta_v[j]; + } + } + + beta_u.clear(); + beta_v.clear(); + + total_timer.start(); + if (eval_in * gr != previousSum) { + fprintf(stderr, "Liu fail, semi final, circuit 0.\n"); + return false; + } + + total_timer.stop(); + total_slow_timer.stop(); + output_tb[PT_OUT_ID] = to_string_wp(p->proveTime()); + output_tb[VT_OUT_ID] = to_string_wp(verifierTime()); + output_tb[PS_OUT_ID] = to_string_wp(p -> proofSize()); + + fprintf(stderr, "Verification pass\n"); + fprintf(stderr, "Prove Time %lf\n", p->proveTime()); + fprintf(stderr, "verify time %lf = %lf + %lf(slow)\n", verifierSlowTime(), verifierTime(), verifierSlowTime() - verifierTime()); + fprintf(stderr, "proof size = %lf kb\n", p -> proofSize()); + + beta_g.clear(); + beta_gs.clear(); + beta_u.clear(); + beta_v.clear(); + r_u.resize(1); + r_v.clear(); + + sig_u.clear(); + sig_v.clear(); + return true; +} + +bool verifier::verifyInput() { + if (!poly_v -> verify(r_u[0], eval_in)) { + fprintf(stderr, "Verification fail, final input check fail.\n"); + return false; + } + + fprintf(stderr, "poly pt = %.5f, vt = %.5f, ps = %.5f\n", p -> polyProverTime(), poly_v -> getVT(), p -> polyProofSize()); + output_tb[POLY_PT_OUT_ID] = to_string_wp(p -> polyProverTime()); + output_tb[POLY_VT_OUT_ID] = to_string_wp(poly_v -> getVT()); + output_tb[POLY_PS_OUT_ID] = to_string_wp(p -> polyProofSize()); + output_tb[TOT_PT_OUT_ID] = to_string_wp(p -> polyProverTime() + p->proveTime()); + output_tb[TOT_VT_OUT_ID] = to_string_wp(poly_v -> 
getVT() + verifierTime()); + output_tb[TOT_PS_OUT_ID] = to_string_wp(p -> polyProofSize() + p -> proofSize()); + return true; +} diff --git a/src/verifier.hpp b/src/verifier.hpp new file mode 100644 index 0000000..d9af504 --- /dev/null +++ b/src/verifier.hpp @@ -0,0 +1,47 @@ +// +// Created by 69029 on 3/9/2021. +// + +#ifndef ZKCNN_CONVVERIFIER_HPP +#define ZKCNN_CONVVERIFIER_HPP + +#include "prover.hpp" + +using std::unique_ptr; +class verifier { +public: + prover *p; + const layeredCircuit &C; + + verifier(prover *pr, const layeredCircuit &cir); + + bool verify(); + + timer total_timer, total_slow_timer; + double verifierTime() const { return total_timer.elapse_sec(); } + double verifierSlowTime() const { return total_slow_timer.elapse_sec(); } + +private: + vector> r_u, r_v; + vector final_claim_u0, final_claim_v0; + bool verifyInnerLayers(); + bool verifyFirstLayer(); + bool verifyInput(); + + vector beta_g; + void betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou); + void betaInitPhase2(u8 depth); + + F uni_value[2]; + F bin_value[3]; + void predicatePhase1(u8 layer_id); + void predicatePhase2(u8 layer_id); + + F getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1); + + F eval_in; + unique_ptr poly_v; +}; + + +#endif //ZKCNN_CONVVERIFIER_HPP