diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..9acb4be
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "3rd/hyrax-bls12-381"]
+ path = 3rd/hyrax-bls12-381
+ url = git@github.com:TAMUCrypto/hyrax-bls12-381.git
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..73f69e0
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/.name b/.idea/.name
new file mode 100644
index 0000000..3ba80a4
--- /dev/null
+++ b/.idea/.name
@@ -0,0 +1 @@
+zkCNN
\ No newline at end of file
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..b9fc2ad
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,182 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..f1c67df
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..1d88d5d
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..dfb26aa
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/zkCNN-quant.iml b/.idea/zkCNN-quant.iml
new file mode 100644
index 0000000..6d70257
--- /dev/null
+++ b/.idea/zkCNN-quant.iml
@@ -0,0 +1,2 @@
+
+
\ No newline at end of file
diff --git a/3rd/hyrax-bls12-381 b/3rd/hyrax-bls12-381
new file mode 160000
index 0000000..baedf71
--- /dev/null
+++ b/3rd/hyrax-bls12-381
@@ -0,0 +1 @@
+Subproject commit baedf71d215549e8b7147809c7fb807d315a8843
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..e49db18
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.10)
+project(zkCNN)
+set(CMAKE_CXX_STANDARD 14)
+
+link_directories(3rd/hyrax-bls12-381)
+
+include_directories(src)
+include_directories(3rd)
+include_directories(3rd/hyrax-bls12-381/3rd/mcl/include)
+
+add_subdirectory(src)
+add_subdirectory(3rd/hyrax-bls12-381)
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..12dc852
--- /dev/null
+++ b/README.md
@@ -0,0 +1,83 @@
+# zkCNN
+
+## Introduction
+
+This is a GKR-based zero-knowledge proof for CNN inference, covering several widely used networks such as LeNet5, VGG11 and VGG16.
+
+
+
+## Requirement
+
+- C++14
+- cmake >= 3.10
+- GMP library
+
+
+
+## Input Format
+
+The input has two parts: the data and the weights, both given in matrix form.
+
+### Data Part
+
+There are two cases supported in this repo.
+
+- **Single picture**
+
+ Then the picture is a vector reshaped from its original matrix by
+
+ 
+
+ where $ch$ is the number of channels, $h$ is the height, and $w$ is the width.
+
+
+
+- **Multiple pictures**
+
+ This solves the case when the user wants to infer multiple pictures with the same network. Then each picture is a vector reshaped from its original matrix by
+
+ 
+
+ where $n$ is the number of pictures, $ch$ is the number of channels, $h$ is the height, and $w$ is the width.
+
+### Weight Part
+
+This part is for weight in the neural network, which contains
+
+- convolution kernel of size $ch_{out} \times ch_{in} \times m \times m$
+
+ where $ch_{out}$ and $ch_{in}$ are the numbers of output and input channels, and $m$ is the side length of the kernel (here we only support square kernels).
+
+- convolution bias of size $ch_{out}$
+
+- fully-connected kernel of size $ch_{out} \times ch_{in}$
+
+
+- fully-connected bias of size $ch_{out}$
+
+
+All the input above are scanned one by one.
+
+## Experiment Script
+### Clone the repo
+To run the code, make sure you clone with
+``` bash
+git clone git@github.com:TAMUCrypto/zkCNN.git
+git submodule update --init --recursive
+```
+since the polynomial commitment is included as a submodule.
+
+### Run a demo of vgg11
+The following script runs the vgg11 model (please run it from the ``script/`` directory).
+``` bash
+./demo.sh
+```
+
+- The input data is in ``data/vgg11/``.
+- The experiment evaluation is ``output/single/demo-result.txt``.
+- The inference result is ``output/single/vgg11.cifar.relu-1-infer.csv``.
+
+## Polynomial Commitment
+
+Here we implement a hyrax polynomial commitment based on BLS12-381 elliptic curve. It is a submodule and someone who is interested can refer to this repo [hyrax-bls12-381](https://github.com/TAMUCrypto/hyrax-bls12-381).
+
diff --git a/data.tar.gz b/data.tar.gz
new file mode 100644
index 0000000..a62959b
Binary files /dev/null and b/data.tar.gz differ
diff --git a/script/build.sh b/script/build.sh
new file mode 100644
index 0000000..3d02834
--- /dev/null
+++ b/script/build.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+cd ..
+mkdir -p cmake-build-release
+cd cmake-build-release
+/usr/bin/cmake -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - Unix Makefiles" ..
+cd ..
+
+if [ ! -d "./data" ]
+then
+ tar -xzvf data.tar.gz
+fi
+cd script
\ No newline at end of file
diff --git a/script/demo.sh b/script/demo.sh
new file mode 100644
index 0000000..91ccc00
--- /dev/null
+++ b/script/demo.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+set -x
+
+./build.sh
+/usr/bin/cmake --build ../cmake-build-release --target demo_run -- -j 6
+
+run_file=../cmake-build-release/src/demo_run
+out_file=../output/single/demo-result.txt
+
+mkdir -p ../output/single
+mkdir -p ../log/single
+
+vgg11_i=../data/vgg11/vgg11.cifar.relu-1-images-weights-qint8.csv
+vgg11_c=../data/vgg11/vgg11.cifar.relu-1-scale-zeropoint-uint8.csv
+vgg11_o=../output/single/vgg11.cifar.relu-1-infer.csv
+vgg11_n=../data/vgg11/vgg11-config.csv
+
+${run_file} ${vgg11_i} ${vgg11_c} ${vgg11_o} ${vgg11_n} 1 > ${out_file}
\ No newline at end of file
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 0000000..a2f2cae
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,7 @@
+aux_source_directory(. conv_src)
+list(FILTER conv_src EXCLUDE REGEX "main*")
+
+add_library(cnn_lib ${conv_src})
+
+add_executable(demo_run main_demo.cpp)
+target_link_libraries(demo_run cnn_lib hyrax_lib mcl mclbn384_256)
\ No newline at end of file
diff --git a/src/circuit.cpp b/src/circuit.cpp
new file mode 100644
index 0000000..fa48894
--- /dev/null
+++ b/src/circuit.cpp
@@ -0,0 +1,100 @@
+#include "circuit.h"
+#include "utils.hpp"
+
+void layeredCircuit::initSubset() {
+ cerr << "begin subset init." << endl;
+ vector visited_uidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer
+ vector subset_uidx(circuit[0].size); // the subset index of the i-th layer, j-th gate
+ vector visited_vidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer
+ vector subset_vidx(circuit[0].size); // the subset index of the i-th layer, j-th gate
+
+ for (u8 i = 1; i < size; ++i) {
+ auto &cur = circuit[i], &lst = circuit[i - 1];
+ bool has_pre_layer_u = circuit[i].ty == layerType::FFT || circuit[i].ty == layerType::IFFT;
+ bool has_pre_layer_v = false;
+
+ for (auto &gate: cur.uni_gates) {
+ if (!gate.lu) {
+ if (visited_uidx[gate.u] != i) {
+ visited_uidx[gate.u] = i;
+ subset_uidx[gate.u] = cur.size_u[0];
+ cur.ori_id_u.push_back(gate.u);
+ ++cur.size_u[0];
+ }
+ gate.u = subset_uidx[gate.u];
+ }
+ has_pre_layer_u |= (gate.lu != 0);
+ }
+
+ for (auto &gate: cur.bin_gates) {
+ if (!gate.getLayerIdU(i)) {
+ if (visited_uidx[gate.u] != i) {
+ visited_uidx[gate.u] = i;
+ subset_uidx[gate.u] = cur.size_u[0];
+ cur.ori_id_u.push_back(gate.u);
+ ++cur.size_u[0];
+ }
+ gate.u = subset_uidx[gate.u];
+ }
+ if (!gate.getLayerIdV(i)) {
+ if (visited_vidx[gate.v] != i) {
+ visited_vidx[gate.v] = i;
+ subset_vidx[gate.v] = cur.size_v[0];
+ cur.ori_id_v.push_back(gate.v);
+ ++cur.size_v[0];
+ }
+ gate.v = subset_vidx[gate.v];
+ }
+ has_pre_layer_u |= (gate.getLayerIdU(i) != 0);
+ has_pre_layer_v |= (gate.getLayerIdV(i) != 0);
+ }
+
+ cur.bit_length_u[0] = ceilPow2BitLength(cur.size_u[0]);
+ cur.bit_length_v[0] = ceilPow2BitLength(cur.size_v[0]);
+
+ if (has_pre_layer_u) switch (cur.ty) {
+ case layerType::FFT:
+ cur.size_u[1] = 1ULL << cur.fft_bit_length - 1;
+ cur.bit_length_u[1] = cur.fft_bit_length - 1;
+ break;
+ case layerType::IFFT:
+ cur.size_u[1] = 1ULL << cur.fft_bit_length;
+ cur.bit_length_u[1] = cur.fft_bit_length;
+ break;
+ default:
+ cur.size_u[1] = lst.size ;
+ cur.bit_length_u[1] = lst.bit_length;
+ break;
+ } else {
+ cur.size_u[1] = 0;
+ cur.bit_length_u[1] = -1;
+ }
+
+ if (has_pre_layer_v) {
+ if (cur.ty == layerType::DOT_PROD) {
+ cur.size_v[1] = lst.size >> cur.fft_bit_length;
+ cur.bit_length_v[1] = lst.bit_length - cur.fft_bit_length;
+ } else {
+ cur.size_v[1] = lst.size;
+ cur.bit_length_v[1] = lst.bit_length;
+ }
+ } else {
+ cur.size_v[1] = 0;
+ cur.bit_length_v[1] = -1;
+ }
+ cur.updateSize();
+ }
+ cerr << "begin subset finish." << endl;
+}
+
+void layeredCircuit::init(u8 q_bit_size, u8 _layer_sz) {
+ two_mul.resize((q_bit_size + 1) << 1);
+ two_mul[0] = F_ONE;
+ two_mul[q_bit_size + 1] = -F_ONE;
+ for (int i = 1; i <= q_bit_size; ++i) {
+ two_mul[i] = two_mul[i - 1] + two_mul[i - 1];
+ two_mul[i + q_bit_size + 1] = -two_mul[i];
+ }
+ size = _layer_sz;
+ circuit.resize(size);
+}
diff --git a/src/circuit.h b/src/circuit.h
new file mode 100644
index 0000000..46d47e0
--- /dev/null
+++ b/src/circuit.h
@@ -0,0 +1,89 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "global_var.hpp"
+
+using std::cerr;
+using std::endl;
+using std::vector;
+
+struct uniGate {
+ u32 g, u;
+ u8 lu, sc;
+ uniGate(u32 _g, u32 _u, u8 _lu, u8 _sc) :
+ g(_g), u(_u), lu(_lu), sc(_sc) {
+// cerr << "uni: " << g << ' ' << u << ' ' << lu <<' ' << sc.real << endl;
+ }
+};
+
+struct binGate {
+ u32 g, u, v;
+ u8 sc, l;
+ binGate(u32 _g, u32 _u, u32 _v, u8 _sc, u8 _l):
+ g(_g), u(_u), v(_v), sc(_sc), l(_l) {
+// cerr << "bin: " << g << ' ' << u << ' ' << lu << ' ' << v << ' ' << lu << ' ' << sc.real << endl;
+ }
+ [[nodiscard]] u8 getLayerIdU(u8 layer_id) const { return !l ? 0 : layer_id - 1; }
+ [[nodiscard]] u8 getLayerIdV(u8 layer_id) const { return !(l & 1) ? 0 : layer_id - 1; }
+};
+
+enum class layerType {
+ INPUT, FFT, IFFT, ADD_BIAS, RELU, Sqr, OPT_AVG_POOL, MAX_POOL, AVG_POOL, DOT_PROD, PADDING, FCONN, NCONV, NCONV_MUL, NCONV_ADD
+};
+
+class layer {
+public:
+ layerType ty;
+ u32 size{}, size_u[2]{}, size_v[2]{};
+ i8 bit_length_u[2]{}, bit_length_v[2]{}, bit_length{};
+ i8 max_bl_u{}, max_bl_v{};
+
+ bool need_phase2;
+
+ // bit decomp related
+ u32 zero_start_id;
+
+ std::vector uni_gates;
+ std::vector bin_gates;
+
+ vector ori_id_u, ori_id_v;
+ i8 fft_bit_length;
+
+ // iFFT or avg pooling.
+ F scale;
+
+ layer() {
+ bit_length_u[0] = bit_length_v[0] = -1;
+ size_u[0] = size_v[0] = 0;
+ bit_length_u[1] = bit_length_v[1] = -1;
+ size_u[1] = size_v[1] = 0;
+ need_phase2 = false;
+ zero_start_id = 0;
+ fft_bit_length = -1;
+ scale = F_ONE;
+ }
+
+ void updateSize() {
+ max_bl_u = std::max(bit_length_u[0], bit_length_u[1]);
+ max_bl_v = 0;
+ if (!need_phase2) return;
+
+ max_bl_v = std::max(bit_length_v[0], bit_length_v[1]);
+ }
+};
+
+class layeredCircuit {
+public:
+ vector circuit;
+ u8 size;
+ vector two_mul;
+
+ void init(u8 q_bit_size, u8 _layer_sz);
+ void initSubset();
+};
+
diff --git a/src/global_var.hpp b/src/global_var.hpp
new file mode 100644
index 0000000..d81f8ea
--- /dev/null
+++ b/src/global_var.hpp
@@ -0,0 +1,58 @@
+//
+// Created by 69029 on 5/4/2021.
+//
+
+#include
+#include
+
+#ifndef ZKCNN_GLOBAL_VAR_HPP
+#define ZKCNN_GLOBAL_VAR_HPP
+
+// the output format
+#define MO_INFO_OUT_ID 0
+#define PSIZE_OUT_ID 1
+#define KSIZE_OUT_ID 2
+#define PCNT_OUT_ID 3
+#define CONV_TY_OUT_ID 4
+#define QS_OUT_ID 5
+#define WS_OUT_ID 6
+#define PT_OUT_ID 7
+#define VT_OUT_ID 8
+#define PS_OUT_ID 9
+#define POLY_PT_OUT_ID 10
+#define POLY_VT_OUT_ID 11
+#define POLY_PS_OUT_ID 12
+#define TOT_PT_OUT_ID 13
+#define TOT_VT_OUT_ID 14
+#define TOT_PS_OUT_ID 15
+
+using std::cerr;
+using std::endl;
+using std::vector;
+using std::string;
+using std::max;
+using std::min;
+using std::ifstream;
+using std::ofstream;
+using std::ostream;
+using std::pair;
+using std::make_pair;
+
+extern vector output_tb;
+
+#define F Fr
+#define G G1
+#define F_ONE (Fr::one())
+#define F_ZERO (Fr(0))
+
+#define F_BYTE_SIZE (Fr::getByteSize())
+
+template
+string to_string_wp(const T a_value, const int n = 4) {
+ std::ostringstream out;
+ out.precision(n);
+ out << std::fixed << a_value;
+ return out.str();
+}
+
+#endif //ZKCNN_GLOBAL_VAR_HPP
diff --git a/src/main_demo.cpp b/src/main_demo.cpp
new file mode 100644
index 0000000..3c78ba4
--- /dev/null
+++ b/src/main_demo.cpp
@@ -0,0 +1,45 @@
+//
+// Created by 69029 on 4/12/2021.
+//
+
+#include "circuit.h"
+#include "neuralNetwork.hpp"
+#include "verifier.hpp"
+#include "models.hpp"
+#include "global_var.hpp"
+
+// the arguments' format
+#define INPUT_FILE_ID 1 // the input filename
+#define CONFIG_FILE_ID 2 // the config filename
+#define OUTPUT_FILE_ID 3 // the input filename
+#define NETWORK_FILE_ID 4 // the configuration of vgg
+#define PIC_CNT 5 // the number of picture paralleled
+
+char QSIZE;
+char SCALE;
+vector output_tb(16, "");
+
+int main(int argc, char **argv) {
+ initPairing(mcl::BLS12_381);
+
+ char i_filename[500], c_filename[500], o_filename[500], n_filename[500];
+
+ strcpy(i_filename, argv[INPUT_FILE_ID]);
+ strcpy(c_filename, argv[CONFIG_FILE_ID]);
+ strcpy(o_filename, argv[OUTPUT_FILE_ID]);
+ strcpy(n_filename, argv[NETWORK_FILE_ID]);
+
+ int pic_cnt = atoi(argv[PIC_CNT]);
+
+ output_tb[MO_INFO_OUT_ID] ="vgg (relu)";
+ output_tb[PCNT_OUT_ID] = std::to_string(pic_cnt);
+
+ prover p;
+ vgg nn(32, 32, 3, pic_cnt, i_filename, c_filename, o_filename, n_filename);
+ nn.create(p, false);
+ verifier v(&p, p.C);
+ v.verify();
+
+ for (auto &s: output_tb) printf("%s, ", s.c_str());
+ puts("");
+}
\ No newline at end of file
diff --git a/src/models.cpp b/src/models.cpp
new file mode 100644
index 0000000..c9f80cc
--- /dev/null
+++ b/src/models.cpp
@@ -0,0 +1,365 @@
+//
+// Created by 69029 on 3/16/2021.
+//
+
+#include
+#include
+#include "models.hpp"
+#include "utils.hpp"
+#undef USE_VIRGO
+
+
+vgg::vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename,
+ const string &c_filename, const std::string &o_filename, const std::string &n_filename):
+ neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) {
+ assert(psize_x == psize_y);
+ conv_section.resize(5);
+
+ convType conv_ty = NAIVE_FAST;
+
+ ifstream config_in(n_filename);
+ string con;
+ i64 kernel_size = 3, ch_in = pic_channel, ch_out, new_nx = pic_size_x, new_ny = pic_size_y;
+
+ int idx = 0;
+ while (config_in >> con) {
+ if (con[0] != 'M' && con[0] != 'A') {
+ ch_out = stoi(con, nullptr, 10);
+ conv_section[idx].emplace_back(conv_ty, ch_out, ch_in, kernel_size);
+ ch_in = ch_out;
+ } else {
+ ++idx;
+ pool.emplace_back(con[0] == 'M' ? MAX : AVG, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+ }
+ }
+
+ assert(pic_size_x == 32);
+ full_conn.emplace_back(512, new_nx * new_ny * ch_in);
+ full_conn.emplace_back(512, 512);
+ full_conn.emplace_back(10, 512);
+}
+
+vgg16::vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename)
+ : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) {
+ assert(psize_x == psize_y);
+ conv_section.resize(5);
+
+ i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y;
+
+ conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size);
+ conv_section[0].emplace_back(conv_ty, start, start, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[1].emplace_back(conv_ty, start << 1, start, kernel_size);
+ conv_section[1].emplace_back(conv_ty, start << 1, start << 1, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size);
+ conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size);
+ conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3);
+ conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3);
+ conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3);
+ conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3);
+ conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3);
+
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ if (pic_size_x == 224) {
+ full_conn.emplace_back(4096, new_nx * new_ny * (start << 3));
+ full_conn.emplace_back(4096, 4096);
+ full_conn.emplace_back(1000, 4096);
+ } else {
+ assert(pic_size_x == 32);
+ full_conn.emplace_back(512, new_nx * new_ny * (start << 3));
+ full_conn.emplace_back(512, 512);
+ full_conn.emplace_back(10, 512);
+ }
+}
+
+vgg11::vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename)
+ : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) {
+ assert(psize_x == psize_y);
+ conv_section.resize(5);
+
+ i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y;
+
+ conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[1].emplace_back(conv_ty, start << 1, start, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size);
+ conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3);
+ conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3);
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3);
+ conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3);
+
+ pool.emplace_back(pool_ty, 2, 1);
+ new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1;
+ new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1;
+
+ if (pic_size_x == 224) {
+ full_conn.emplace_back(4096, new_nx * new_ny * (start << 3));
+ full_conn.emplace_back(4096, 4096);
+ full_conn.emplace_back(1000, 4096);
+ } else {
+ assert(pic_size_x == 32);
+ full_conn.emplace_back(512, new_nx * new_ny * (start << 3));
+ full_conn.emplace_back(512, 512);
+ full_conn.emplace_back(10, 512);
+ }
+}
+
+ccnn::ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty) :
+ neuralNetwork(psize_x, psize_y, pchannel, pparallel, "", "", "") {
+ conv_section.resize(1);
+
+ conv_section[0].emplace_back(NAIVE_FAST, 2, pchannel, 3, 0, 0);
+ pool.emplace_back(pool_ty, 2, 1);
+
+// conv_section[1].emplace_back(FFT, 64, 4, 3);
+// conv_section[1].emplace_back(NAIVE, 64, 64, 3);
+// pool.emplace_back(pool_ty, 2, 1);
+
+// conv_section[0].emplace_back(FFT, 2, pic_channel, 3);
+// conv_section[1].emplace_back(NAIVE, 1, 2, 3);
+// pool.emplace_back(pool_ty, 2, 1);
+}
+
+lenet::lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename)
+ : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) {
+ conv_section.emplace_back();
+ if (psize_x == 28 && psize_y == 28)
+ conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 2);
+ else conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 0);
+ pool.emplace_back(pool_ty, 2, 1);
+
+ conv_section.emplace_back();
+ conv_section[1].emplace_back(conv_ty, 16, 6, 5, 0, 0);
+ pool.emplace_back(pool_ty, 2, 1);
+
+ full_conn.emplace_back(120, 400);
+ full_conn.emplace_back(84, 120);
+ full_conn.emplace_back(10, 84);
+}
+
+lenetCifar::lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename)
+ : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) {
+ conv_section.resize(3);
+
+ conv_section[0].emplace_back(conv_ty, 6, pchannel, 5, 0, 0);
+ pool.emplace_back(pool_ty, 2, 1);
+
+ conv_section[1].emplace_back(conv_ty, 16, 6, 5, 0, 0);
+ pool.emplace_back(pool_ty, 2, 1);
+
+ conv_section[2].emplace_back(conv_ty, 120, 16, 5, 0, 0);
+
+ full_conn.emplace_back(84, 120);
+ full_conn.emplace_back(10, 84);
+}
+
+void singleConv::createConv(prover &p) {
+ initParamConv();
+ p.C.init(Q_BIT_SIZE, SIZE);
+
+ p.val.resize(SIZE);
+ val = p.val.begin();
+ two_mul = p.C.two_mul.begin();
+
+ i64 layer_id = 0;
+ inputLayer(p.C.circuit[layer_id++]);
+
+ new_nx_in = pic_size_x;
+ new_ny_in = pic_size_y;
+ pool_ty = NONE;
+ for (i64 i = 0; i < conv_section.size(); ++i) {
+ auto &sec = conv_section[i];
+ for (i64 j = 0; j < sec.size(); ++j) {
+ auto &conv = sec[j];
+ refreshConvParam(new_nx_in, new_ny_in, conv);
+
+ switch (conv.ty) {
+ case FFT:
+ paddingLayer(p.C.circuit[layer_id], layer_id, conv.weight_start_id);
+ fftLayer(p.C.circuit[layer_id], layer_id);
+ dotProdLayer(p.C.circuit[layer_id], layer_id);
+ ifftLayer(p.C.circuit[layer_id], layer_id);
+ break;
+ case NAIVE_FAST:
+ naiveConvLayerFast(p.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id);
+ break;
+ default:
+ naiveConvLayerMul(p.C.circuit[layer_id], layer_id, conv.weight_start_id);
+ naiveConvLayerAdd(p.C.circuit[layer_id], layer_id, conv.bias_start_id);
+ }
+ }
+ }
+ p.C.initSubset();
+// for (i64 i = 0; i < SIZE; ++i) {
+// cerr << i << "(" << p.C.circuit[i].zero_start_id << ", " << p.C.circuit[i].size << "):\t";
+// for (i64 j = 0; j < std::min(100u, p.C.circuit[i].size); ++j)
+// cerr << p.val[i][j] << ' ';
+// cerr << endl;
+// bool flag = false;
+// for (i64 j = 0; j < p.C.circuit[i].size; ++j)
+// if (p.val[i][j] != F_ZERO) flag = true;
+// if (flag) cerr << "not all zero: " << i << endl;
+// for (i64 j = p.C.circuit[i].zero_start_id; j < p.C.circuit[i].size; ++j)
+// if (p.val[i][j] != F_ZERO) { cerr << "WRONG! " << i << ' ' << j << ' ' << p.val[i][j] << endl; exit(EXIT_FAILURE); }
+// }
+ cerr << "finish creating circuit." << endl;
+}
+
+void singleConv::initParamConv() {
+ i64 conv_layer_cnt = 0;
+ total_in_size = 0;
+ total_para_size = total_relu_in_size = total_ave_in_size = total_max_in_size = 0;
+
+ // data
+ i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel;
+
+ new_nx_in = pic_size_x;
+ new_ny_in = pic_size_y;
+ for (i64 i = 0; i < conv_section.size(); ++i) {
+ auto &sec = conv_section[i];
+ for (i64 j = 0; j < sec.size(); ++j) {
+ refreshConvParam(new_nx_in, new_ny_in, sec[j]);
+ conv_layer_cnt += sec[j].ty == FFT ? FFT_SIZE - 1 : sec[j].ty == NAIVE ? NCONV_SIZE : NCONV_FAST_SIZE;
+ // conv_kernel
+ sec[j].weight_start_id = pos;
+ pos += sqr(m) * channel_in * channel_out;
+ total_para_size += sqr(m) * channel_in * channel_out;
+ sec[j].bias_start_id = -1;
+ }
+ }
+ total_in_size = pos;
+
+ SIZE = 1 + conv_layer_cnt;
+ cerr << "SIZE: " << SIZE << endl;
+}
+
+vector singleConv::getFFTAns(const vector &output) {
+ vector res;
+ res.resize(nx_out * ny_out * channel_out * pic_channel * pic_parallel);
+
+ i64 lst_fft_lenh = getFFTLen() >> 1;
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) {
+ i64 idx = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out);
+ i64 i = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lst_fft_lenh);
+ res[idx] = output[i];
+ }
+ return res;
+}
+
+double singleConv::calcRawFFT() {
+ auto in = val[0].begin();
+ auto conv = val[0].begin() + conv_section[0][0].weight_start_id;
+ auto bias = val[0].begin() + conv_section[0][0].bias_start_id;
+
+ timer tm;
+ int logn = ceilPow2BitLength(nx_padded_in * ny_padded_in) + 1;
+ vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO);
+ vector arr1(1 << logn, F_ZERO);
+ vector arr2(arr1.size(), F_ZERO);
+
+ assert(pic_parallel == 1 && pic_channel == 1);
+ // data matrix
+ i64 L = -padding;
+ i64 Rx = nx_in + padding, Ry = ny_in + padding;
+
+ tm.start();
+ for (i64 x = L; x < Rx; ++x)
+ for (i64 y = L; y < Ry; ++y)
+ if (check(x, y, nx_in, ny_in)) {
+ i64 g = matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in);
+ i64 u = matIdx(x, y, ny_in);
+ arr1[g] = in[u];
+ }
+
+ // kernel matrix
+ for (i64 x = 0; x < nx_padded_in; ++x)
+ for (i64 y = 0; y < ny_padded_in; ++y)
+ if (check(x, y, m, m)) {
+ i64 g = matIdx(x, y, ny_padded_in);
+ i64 u = matIdx(x, y, m);
+ arr2[g] = conv[u];
+ }
+
+ fft(arr1, logn, false);
+ fft(arr2, logn, false);
+ for (i64 i = 0; i < arr1.size(); ++i)
+ arr1[i] = arr1[i] * arr2[i];
+ fft(arr1, logn, true);
+ reverse(arr1.begin(), arr1.end());
+
+ tm.stop();
+ return tm.elapse_sec();
+}
+
+double singleConv::calcRawNaive() {
+ auto in = val[0].begin();
+ auto conv = val[0].begin() + conv_section[0][0].weight_start_id;
+ auto bias = val[0].begin() + conv_section[0][0].bias_start_id;
+
+ timer tm;
+ vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO);
+ tm.start();
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 i = 0; i < channel_out; ++i)
+ for (i64 j = 0; j < channel_in; ++j)
+ for (i64 x = -padding; x + m <= nx_in + padding; x += (1 << log_stride))
+ for (i64 y = -padding; y + m <= ny_in + padding; y += (1 << log_stride)) {
+ i64 idx = tesIdx(p, i, (x + padding) >> log_stride, (y + padding) >> log_stride, channel_out, nx_out, ny_out);
+ if (j == 0) res[idx] = res[idx] + bias[i];
+ for (i64 tx = x; tx < x + m; ++tx)
+ for (i64 ty = y; ty < y + m; ++ty)
+ if (check(tx, ty, nx_in, ny_in)) {
+ i64 u = tesIdx(p, j, tx, ty, channel_in, nx_in, ny_in);
+ i64 v = tesIdx(i, j, tx - x, ty - y, channel_in, m, m);
+ res.at(idx) = res.at(idx) + in[u] * conv[v];
+ }
+ }
+ tm.stop();
+ return tm.elapse_sec();
+}
diff --git a/src/models.hpp b/src/models.hpp
new file mode 100644
index 0000000..3a71a9c
--- /dev/null
+++ b/src/models.hpp
@@ -0,0 +1,66 @@
+//
+// Created by 69029 on 3/16/2021.
+//
+
+#ifndef ZKCNN_VGG_HPP
+#define ZKCNN_VGG_HPP
+
+#include "neuralNetwork.hpp"
+
+class vgg: public neuralNetwork {
+
+public:
+ explicit vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const std::string &i_filename, const string &c_filename, const std::string &o_filename, const std::string &n_filename);
+
+};
+
+class vgg16: public neuralNetwork {
+
+public:
+ explicit vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename);
+
+};
+
+class vgg11: public neuralNetwork {
+
+public:
+ explicit vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename);
+
+};
+
+class lenet: public neuralNetwork {
+public:
+ explicit lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename);
+};
+
+class lenetCifar: public neuralNetwork {
+public:
+ explicit lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, convType conv_ty, poolType pool_ty,
+ const std::string &i_filename, const string &c_filename, const std::string &o_filename);
+};
+
+class ccnn: public neuralNetwork {
+public:
+ explicit ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty);
+};
+
+class singleConv: public neuralNetwork {
+public:
+ explicit singleConv(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, i64 kernel_size, i64 channel_out,
+ i64 log_stride, i64 padding, convType conv_ty);
+
+ void createConv(prover &p);
+
+ void initParamConv();
+
+ vector getFFTAns(const vector &output);
+
+ double calcRawFFT();
+
+ double calcRawNaive();
+};
+
+#endif //ZKCNN_VGG_HPP
diff --git a/src/neuralNetwork.cpp b/src/neuralNetwork.cpp
new file mode 100644
index 0000000..7d77cdc
--- /dev/null
+++ b/src/neuralNetwork.cpp
@@ -0,0 +1,1016 @@
+//
+// Created by 69029 on 3/16/2021.
+//
+
+#include "neuralNetwork.hpp"
+#include "utils.hpp"
+#include "global_var.hpp"
+#include
+#include
+#include
+#include
+
+using std::cerr;
+using std::endl;
+using std::max;
+using std::ifstream;
+using std::ofstream;
+
+// File-scope streams shared by all neuralNetwork member functions:
+// `in` supplies the picture data and layer weights, `conf` the configuration,
+// and `out` receives the inference output (opened only when a path is given).
+ifstream in;
+ifstream conf;
+ofstream out;
+
+// Base constructor: records picture dimensions / channel count / batch
+// parallelism, initializes the per-layer-type circuit depth constants
+// (e.g. an FFT conv costs 5 circuit layers, a naive conv 2), and opens the
+// global input/config/output streams. Missing input or config files are
+// reported to stderr but do not abort construction.
+neuralNetwork::neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename,
+ const string &c_filename, const string &o_filename) :
+ pic_size_x(psize_x), pic_size_y(psize_y), pic_channel(pchannel), pic_parallel(pparallel),
+ SIZE(0), NCONV_FAST_SIZE(1), NCONV_SIZE(2), FFT_SIZE(5),
+ AVE_POOL_SIZE(1), FC_SIZE(1), RELU_SIZE(1), act_ty(RELU_ACT) {
+
+ in.open(i_filename);
+ if (!in.is_open())
+ fprintf(stderr, "Can't find the input file!!!\n");
+ conf.open(c_filename);
+ if (!conf.is_open())
+ fprintf(stderr, "Can't find the config file!!!\n");
+
+ if (!o_filename.empty()) out.open(o_filename);
+}
+
+// Programmatic VGG-like topology builder (square input, no weight files):
+// sec_size sections of two conv layers each, where section i uses
+// start_channel << i output channels, each section followed by a 2x2 pool
+// that halves the spatial size; then fc_size fully-connected layers (4096
+// wide, final layer 1000 outputs).
+neuralNetwork::neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size,
+ i64 start_channel, convType conv_ty, poolType pool_ty)
+ : neuralNetwork(psize, psize, pchannel, pparallel, "", "", "") {
+ pool_bl = 2;
+ pool_stride_bl = pool_bl >> 1;
+ conv_section.resize(sec_size);
+
+ i64 start = start_channel;
+ for (i64 i = 0; i < sec_size; ++i) {
+ // First conv of a section takes the previous section's channel count
+ // (or the raw picture channels for section 0) as its input channels.
+ conv_section[i].emplace_back(conv_ty, start << i, i ? (start << (i - 1)) : pic_channel, kernel_size);
+ conv_section[i].emplace_back(conv_ty, start << i, start << i, kernel_size);
+ pool.emplace_back(pool_ty, 2, 1);
+ }
+
+ // Spatial size after all sections: each pool shifts right by pool_stride_bl.
+ i64 new_nx = (pic_size_x >> pool_stride_bl * conv_section.size());
+ i64 new_ny = (pic_size_y >> pool_stride_bl * conv_section.size());
+ for (i64 i = 0; i < fc_size; ++i)
+ full_conn.emplace_back(i == fc_size - 1 ? 1000 : 4096, i ? 4096 : new_nx * new_ny * (start << (sec_size - 1)));
+}
+
+// Builds the full layered circuit for the prover: one input layer, then for
+// each conv section the conv layers (FFT pipeline or naive variants) plus
+// activation/pooling, then the fully-connected layers. Also evaluates every
+// layer's values as it goes (via the calc* helpers) so the witness is ready.
+// When only_compute is set, stops before the (expensive) subset
+// initialization used for proving.
+void neuralNetwork::create(prover &pr, bool only_compute) {
+ assert(pool.size() >= conv_section.size() - 1);
+
+ initParam();
+ pr.C.init(Q_BIT_SIZE, SIZE);
+
+ pr.val.resize(SIZE);
+ val = pr.val.begin();
+ two_mul = pr.C.two_mul.begin();
+
+ i64 layer_id = 0;
+ inputLayer(pr.C.circuit[layer_id++]);
+
+ new_nx_in = pic_size_x;
+ new_ny_in = pic_size_y;
+ for (i64 i = 0; i < conv_section.size(); ++i) {
+ auto &sec = conv_section[i];
+ for (i64 j = 0; j < sec.size(); ++j) {
+ auto &conv = sec[j];
+ refreshConvParam(new_nx_in, new_ny_in, conv);
+ // Pooling only applies after the last conv of a section.
+ pool_ty = i < pool.size() && j == sec.size() - 1 ? pool[i].ty : NONE;
+ x_bit = x_next_bit;
+ switch (conv.ty) {
+ case FFT:
+ // FFT convolution: pad -> FFT -> pointwise product -> IFFT -> bias.
+ paddingLayer(pr.C.circuit[layer_id], layer_id, conv.weight_start_id);
+ fftLayer(pr.C.circuit[layer_id], layer_id);
+ dotProdLayer(pr.C.circuit[layer_id], layer_id);
+ ifftLayer(pr.C.circuit[layer_id], layer_id);
+ addBiasLayer(pr.C.circuit[layer_id], layer_id, conv.bias_start_id);
+ break;
+ case NAIVE_FAST:
+ naiveConvLayerFast(pr.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id);
+ break;
+ default:
+ naiveConvLayerMul(pr.C.circuit[layer_id], layer_id, conv.weight_start_id);
+ naiveConvLayerAdd(pr.C.circuit[layer_id], layer_id, conv.bias_start_id);
+ }
+
+ // update the scale bit
+ x_next_bit = getNextBit(layer_id - 1);
+ T = x_bit + w_bit - x_next_bit;
+ Q_MAX = Q + T;
+ // Max pooling subsumes the ReLU (see initParam); otherwise add it here.
+ if (pool_ty != MAX)
+ reluActConvLayer(pr.C.circuit[layer_id], layer_id);
+ }
+
+ if (i >= pool.size()) continue;
+ calcSizeAfterPool(pool[i]);
+ switch (pool[i].ty) {
+ case AVG: avgPoolingLayer(pr.C.circuit[layer_id], layer_id); break;
+ case MAX: maxPoolingLayer(pr.C, layer_id, pool[i].dcmp_start_id, pool[i].max_start_id,
+ pool[i].max_dcmp_start_id); break;
+ }
+ }
+
+ pool_ty = NONE;
+ for (int i = 0; i < full_conn.size(); ++i) {
+ auto &fc = full_conn[i];
+ refreshFCParam(fc);
+ x_bit = x_next_bit;
+ fullyConnLayer(pr.C.circuit[layer_id], layer_id, fc.weight_start_id, fc.bias_start_id);
+ // No activation after the last fully-connected layer.
+ if (i == full_conn.size() - 1) break;
+
+ // update the scale bit
+ x_next_bit = getNextBit(layer_id - 1);
+ T = x_bit + w_bit - x_next_bit;
+ Q_MAX = Q + T;
+ reluActFconLayer(pr.C.circuit[layer_id], layer_id);
+ }
+
+ assert(SIZE == layer_id);
+
+ // The auxiliary witness (bit decompositions for relu/avg/max) was appended
+ // to layer 0 while building, so re-init the input layer with the final size.
+ total_in_size += total_max_in_size + total_ave_in_size + total_relu_in_size;
+ initLayer(pr.C.circuit[0], total_in_size, layerType::INPUT);
+ assert(total_in_size == pr.val[0].size());
+
+ printInfer(pr);
+// printLayerValues(pr);
+
+ if (only_compute) return;
+ pr.C.initSubset();
+ cerr << "finish creating circuit." << endl;
+}
+
+// Layer 0: identity gates that relay every witness value (data, weights, and
+// later-appended decomposition bits) into the circuit.
+void neuralNetwork::inputLayer(layer &circuit) {
+ initLayer(circuit, total_in_size, layerType::INPUT);
+
+ for (i64 i = 0; i < total_in_size; ++i)
+ circuit.uni_gates.emplace_back(i, 0, 0, 0);
+
+ calcInputLayer(circuit);
+ printLayerInfo(circuit, 0);
+}
+
+// FFT pipeline step 1: lay the (padded) data and the conv kernels side by
+// side into half-FFT-length slots. Data entries are written with mirrored
+// spatial indices (Rx - x - 1, Ry - y - 1) while kernels use direct indices,
+// so the later pointwise product realizes a correlation via convolution.
+void
+neuralNetwork::paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id) {
+ i64 lenh = getFFTLen() >> 1;
+ i64 size = lenh * channel_in * (pic_parallel + channel_out);
+ initLayer(circuit, size, layerType::PADDING);
+ circuit.fft_bit_length = getFFTBitLen();
+
+ // data matrix
+ i64 L = -padding;
+ i64 Rx = nx_in + padding, Ry = ny_in + padding;
+
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = L; x < Rx; ++x)
+ for (i64 y = L; y < Ry; ++y)
+ if (check(x, y, nx_in, ny_in)) {
+ i64 g = cubIdx(p, ci, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_in, lenh);
+ i64 u = tesIdx(p, ci, x, y, channel_in, nx_in, ny_in);
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0);
+ }
+
+ // kernel matrix (placed after all pic_parallel * channel_in data blocks)
+ i64 first = pic_parallel * channel_in * lenh;
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = 0; x < nx_padded_in; ++x)
+ for (i64 y = 0; y < ny_padded_in; ++y)
+ if (check(x, y, m, m)) {
+ i64 g = first + cubIdx(co, ci, matIdx(x, y, ny_padded_in), channel_in, lenh) ;
+ i64 u = first_conv_id + tesIdx(co, ci, x, y, channel_in, m, m);
+ circuit.uni_gates.emplace_back(g, u, 0, 0);
+ }
+
+ readConvWeight(first_conv_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// FFT pipeline step 2: forward FFT over every data/kernel slot (output is
+// full FFT length, twice the padding layer's slot size).
+void neuralNetwork::fftLayer(layer &circuit, i64 &layer_id) {
+ i64 size = getFFTLen() * channel_in * (pic_parallel + channel_out);
+ initLayer(circuit, size, layerType::FFT);
+ circuit.fft_bit_length = getFFTBitLen();
+
+ calcFFTLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// FFT pipeline step 3: frequency-domain pointwise products. Each bin gate
+// here stands for a whole FFT-length vector of products, summed over input
+// channels (see calcDotProdLayer, which expands the fft_bit_length slots).
+void neuralNetwork::dotProdLayer(layer &circuit, i64 &layer_id) {
+ i64 len = getFFTLen();
+ i64 size = len * channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::DOT_PROD);
+ circuit.need_phase2 = true;
+ circuit.fft_bit_length = getFFTBitLen();
+
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci) {
+ i64 g = matIdx(p, co, channel_out);
+ i64 u = matIdx(p, ci, channel_in);
+ // Kernel slots start after the pic_parallel data blocks.
+ i64 v = matIdx(pic_parallel + co, ci, channel_in);
+ circuit.bin_gates.emplace_back(g, u, v, 0, 1);
+ }
+
+ calcDotProdLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// FFT pipeline step 4: inverse FFT back to the spatial domain. The layer
+// scale is set to 1 / 2^fft_bit_length, the normalization of the inverse
+// transform.
+void neuralNetwork::ifftLayer(layer &circuit, i64 &layer_id) {
+ i64 len = getFFTLen(), lenh = len >> 1;
+ i64 size = lenh * channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::IFFT);
+ circuit.fft_bit_length = getFFTBitLen();
+ F::inv(circuit.scale, F(1ULL << circuit.fft_bit_length));
+
+ calcFFTLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// FFT pipeline step 5: select the valid correlation outputs (stride
+// positions, mirrored indexing matching paddingLayer) from the IFFT result
+// and add the per-output-channel bias.
+void neuralNetwork::addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id) {
+ i64 len = getFFTLen();
+ i64 size = nx_out * ny_out * channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::ADD_BIAS);
+
+ i64 lenh = len >> 1;
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) {
+ i64 u = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lenh);
+ i64 g = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out);
+ circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0);
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0);
+ }
+
+ readBias(first_bias_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Direct convolution in a single circuit layer: each output gate sums
+// data*weight products over the kernel window plus (optionally) the bias.
+// first_bias_id == -1 (checked via `~first_bias_id`) means "no bias".
+void neuralNetwork::naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id) {
+ i64 size = nx_out * ny_out * channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::NCONV);
+ circuit.need_phase2 = true;
+
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+ i64 mat_in_size = nx_in * ny_in;
+ i64 m_sqr = sqr(m);
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) {
+ i64 g = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out);
+ // Bias added once per output gate (on the first input channel).
+ if (ci == 0 && ~first_bias_id) circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0);
+ for (i64 tx = x; tx < x + m; ++tx)
+ for (i64 ty = y; ty < y + m; ++ty)
+ if (check(tx, ty, nx_in, ny_in)) {
+ i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in);
+ i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m);
+ circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1));
+ }
+ }
+
+ readConvWeight(first_conv_id);
+ if (~first_bias_id) readBias(first_bias_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Two-layer naive convolution, part 1: one gate per data*weight product
+// (no summation yet). The layer size is the running gate counter g, so
+// initLayer is deliberately called after gate generation.
+void neuralNetwork::naiveConvLayerMul(layer &circuit, i64 &layer_id, i64 first_conv_id) {
+ i64 mat_out_size = nx_out * ny_out;
+ i64 mat_in_size = nx_in * ny_in;
+ i64 m_sqr = sqr(m);
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+
+ i64 g = 0;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride))
+ for (i64 tx = x; tx < x + m; ++tx)
+ for (i64 ty = y; ty < y + m; ++ty)
+ if (check(tx, ty, nx_in, ny_in)) {
+ i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in);
+ i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m);
+ circuit.bin_gates.emplace_back(g++, u, v, 0, 2 * (u8) (layer_id > 1));
+ }
+
+ initLayer(circuit, g, layerType::NCONV_MUL);
+ circuit.need_phase2 = true;
+ readConvWeight(first_conv_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Two-layer naive convolution, part 2: sum the per-product gates of the
+// previous NCONV_MUL layer into one gate per output position, consuming them
+// in the same generation order (running counter u), plus optional bias.
+void neuralNetwork::naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id) {
+ i64 size = nx_out * ny_out * channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::NCONV_ADD);
+
+ i64 mat_in_size = nx_in * ny_in;
+ i64 m_sqr = sqr(m);
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+
+ i64 u = 0;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) {
+ i64 g = tesIdx(p, co, ((x - L) >> log_stride),( (y - L) >> log_stride), channel_out, nx_out, ny_out);
+ i64 cnt = 0;
+ if (ci == 0 && ~first_bias_id) {
+ circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0);
+ ++cnt;
+ }
+ for (i64 tx = x; tx < x + m; ++tx)
+ for (i64 ty = y; ty < y + m; ++ty)
+ if (check(tx, ty, nx_in, ny_in)) {
+ circuit.uni_gates.emplace_back(g, u++, layer_id - 1, 0);
+ ++cnt;
+ }
+ }
+
+ if (~first_bias_id) readBias(first_bias_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// ReLU after a convolution, proven via bit decomposition. The witness gains
+// Q_MAX bits per value (sign bit + magnitude bits, appended to layer 0).
+// Gate layout: [0, block_len) ReLU outputs rebuilt from the top Q-1
+// magnitude bits gated by the sign; [block_len, 2*block_len) zero-checks
+// that the full decomposition reproduces the conv output; from 2*block_len
+// on, booleanity checks (b*b - b == 0) for every decomposition bit.
+void neuralNetwork::reluActConvLayer(layer &circuit, i64 &layer_id) {
+ i64 mat_out_size = nx_out * ny_out;
+ i64 size = 1L * mat_out_size * channel_out * (2 + Q_MAX) * pic_parallel;
+ i64 block_len = mat_out_size * channel_out * pic_parallel;
+
+ i64 dcmp_cnt = block_len * Q_MAX;
+ i64 first_dcmp_id = val[0].size();
+ val[0].resize(val[0].size() + dcmp_cnt);
+ total_relu_in_size += dcmp_cnt;
+
+ initLayer(circuit, size, layerType::RELU);
+ circuit.need_phase2 = true;
+
+ circuit.zero_start_id = block_len;
+
+ for (i64 g = 0; g < block_len; ++g) {
+ i64 sign_u = first_dcmp_id + g * Q_MAX;
+ for (i64 s = 1; s < Q; ++s) {
+ i64 v = sign_u + s;
+ circuit.uni_gates.emplace_back(g, v, 0, Q - 1 - s);
+ // Subtract the positive reconstruction when the sign bit is set.
+ circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0);
+ }
+ }
+
+ // NOTE(review): len/lenh below are computed but never used in this function.
+ i64 len = getFFTLen();
+ i64 lenh = len >> 1;
+ i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++ co)
+ for (i64 x = L; x + m <= Rx; x += (1 << log_stride))
+ for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) {
+ i64 u = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out);
+ i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX;
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1);
+ circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1));
+ prepareSignBit(layer_id - 1, u, sign_v);
+ for (i64 s = 1; s < Q_MAX; ++s) {
+ i64 v = sign_v + s;
+ circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1);
+ prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1);
+ }
+ }
+
+ // Booleanity: for each bit b, emit b*b - b (must evaluate to zero).
+ for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) {
+ i64 u = first_dcmp_id + g - (block_len << 1);
+ circuit.bin_gates.emplace_back(g, u, u, 0, 0);
+ circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1);
+ }
+
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// ReLU after a fully-connected layer: same decomposition scheme as
+// reluActConvLayer (outputs, zero-checks, booleanity checks) but over a flat
+// channel_out * pic_parallel block instead of a spatial tensor.
+void neuralNetwork::reluActFconLayer(layer &circuit, i64 &layer_id) {
+ i64 block_len = channel_out * pic_parallel;
+ i64 size = block_len * (2 + Q_MAX);
+ initLayer(circuit, size, layerType::RELU);
+ circuit.zero_start_id = block_len;
+ circuit.need_phase2 = true;
+
+ i64 dcmp_cnt = block_len * Q_MAX;
+ i64 first_dcmp_id = val[0].size();
+ val[0].resize(val[0].size() + dcmp_cnt);
+ total_relu_in_size += dcmp_cnt;
+
+ for (i64 g = 0; g < block_len; ++g) {
+ i64 sign_u = first_dcmp_id + g * Q_MAX;
+ for (i64 s = 1; s < Q; ++s) {
+ i64 v = sign_u + s;
+ circuit.uni_gates.emplace_back(g, v, 0, (Q - s - 1));
+ circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0);
+ }
+ }
+
+ // Zero-check: output minus its signed bit reconstruction must vanish.
+ for (i64 u = 0; u < block_len; ++u) {
+ i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX;
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1);
+ circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1));
+ prepareSignBit(layer_id - 1, u, sign_v);
+
+ for (i64 s = 1; s < Q_MAX; ++s) {
+ i64 v = sign_v + s;
+ circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1);
+ prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1);
+ }
+ }
+
+ // Booleanity checks for every decomposition bit.
+ for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) {
+ i64 u = first_dcmp_id + g - (block_len << 1);
+ circuit.bin_gates.emplace_back(g, u, u, 0, 0);
+ circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1);
+ }
+
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Average pooling: each output gate sums its pool_sz x pool_sz window and the
+// layer scale divides by pool_sz^2 (circuit.scale = 1/pool_sz^2). The
+// division remainder is handled by dpool_bl witness bits per output, each
+// also given a booleanity-check gate past zero_start_id.
+void neuralNetwork::avgPoolingLayer(layer &circuit, i64 &layer_id) {
+ i64 mat_out_size = nx_out * ny_out;
+ i64 zero_start_id = new_nx_in * new_ny_in * channel_out * pic_parallel;
+ i64 size = zero_start_id + getPoolDecmpSize();
+ u8 dpool_bl = pool_bl << 1;
+ i64 pool_sz_sqr = sqr(pool_sz);
+ initLayer(circuit, size, layerType::AVG_POOL);
+ F::inv(circuit.scale, pool_sz_sqr);
+ circuit.zero_start_id = zero_start_id;
+ circuit.need_phase2 = true;
+
+ i64 first_gate_id = val[0].size();
+ val[0].resize(val[0].size() + zero_start_id * dpool_bl);
+ total_ave_in_size += zero_start_id * dpool_bl;
+
+ // [0 .. zero_start_id]
+ // [zero_start_id .. zero_start_id + (g = 0..channel_out * mat_new_size) * dpool_bl + rm_i .. channel_out * mat_new_size * (1 + dpool_bl)]
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride)
+ for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) {
+ i64 g = tesIdx(p, co, (x >> pool_stride_bl), (y >> pool_stride_bl), channel_out, new_nx_in, new_ny_in);
+ F data = F_ZERO;
+ for (i64 tx = x; tx < x + pool_sz; ++tx)
+ for (i64 ty = y; ty < y + pool_sz; ++ty) {
+ i64 u = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out);
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0);
+ data = data + val[layer_id - 1][u];
+ }
+
+ // Subtract the remainder bits so the scaled sum is an exact quotient.
+ for (i64 rm_i = 0; rm_i < dpool_bl; ++rm_i) {
+ i64 idx = matIdx(g, rm_i, dpool_bl), u = first_gate_id + idx, g_bit = zero_start_id + idx;
+ circuit.uni_gates.emplace_back(g, u, 0, dpool_bl - rm_i + Q_BIT_SIZE);
+ prepareFieldBit(F(data), u, dpool_bl - rm_i - 1);
+
+ // check bit
+ circuit.bin_gates.emplace_back(g_bit, u, u, 0, 0);
+ circuit.uni_gates.emplace_back(g_bit, u, 0, Q_BIT_SIZE + 1);
+ }
+ }
+
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Max pooling across several circuit layers. The witness gains, per output:
+// the claimed max, its bit decomposition, and bit decompositions of
+// (max - element) for every window element (all appended to layer 0; the
+// incoming first_* parameters are overwritten with the actual offsets).
+// Layer 0 proves max >= every element (differences are bit-decomposed, hence
+// non-negative) and that the max equals its own bit reconstruction; the
+// following layers multiply all (max - element) differences together in a
+// log-depth product tree — the product being zero proves max is attained by
+// some element — and finally emit the pooled tensor plus booleanity checks.
+void
+neuralNetwork::maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id,
+ i64 first_max_dcmp_id) {
+ i64 mat_out_size = nx_out * ny_out;
+ i64 tot_out_size = mat_out_size * channel_out * pic_parallel;
+ i64 mat_new_size = new_nx_in * new_ny_in;
+ i64 tot_new_size = mat_new_size * channel_out * pic_parallel;
+ i64 pool_sz_sqr = sqr(pool_sz);
+
+ i64 dcmp_cnt = getPoolDecmpSize();
+ first_dcmp_id = val[0].size();
+ val[0].resize(val[0].size() + dcmp_cnt);
+ total_max_in_size += dcmp_cnt;
+
+ i64 max_cnt = tot_new_size;
+ first_max_id = val[0].size();
+ val[0].resize(val[0].size() + max_cnt);
+ total_max_in_size += max_cnt;
+
+ i64 max_dcmp_cnt = tot_new_size * (Q_MAX - 1);
+ first_max_dcmp_id = val[0].size();
+ val[0].resize(val[0].size() + max_dcmp_cnt);
+ total_max_in_size += max_dcmp_cnt;
+
+ // 0: max - everyone & max - (max bits) == 0
+ // [0..tot_new_size * sqr(pool_sz)][tot_new_size * sqr(pool_sz)..tot_new_size * sqr(pool_sz) + tot_new_size]
+ i64 size_0 = tot_new_size * pool_sz_sqr + tot_new_size;
+ layer &circuit = C.circuit[layer_id];
+ initLayer(circuit, size_0, layerType::MAX_POOL);
+ circuit.zero_start_id = tot_new_size * pool_sz_sqr;
+ i64 fft_len = getFFTLen(), fft_lenh = fft_len >> 1;
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co) {
+ for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride)
+ for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) {
+ i64 i_max = tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in);
+ i64 u_max = first_max_id + i_max;
+ for (i64 tx = x; tx < x + pool_sz; ++tx)
+ for (i64 ty = y; ty < y + pool_sz; ++ty) {
+ i64 g = cubIdx(tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in), tx - x, ty - y, pool_sz, pool_sz);
+ i64 u_g = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out);
+ circuit.uni_gates.emplace_back(g, u_max, 0, 0);
+ circuit.uni_gates.emplace_back(g, u_g, layer_id - 1, Q_BIT_SIZE + 1);
+ prepareMax(layer_id - 1, u_g, u_max);
+ }
+ }
+ }
+
+ // Zero-check: each claimed max equals its bit reconstruction.
+ for (i64 i_new = 0; i_new < tot_new_size; ++i_new) {
+ i64 g_new = circuit.zero_start_id + i_new;
+ i64 u_new = first_max_id + i_new;
+ circuit.uni_gates.emplace_back(g_new, u_new, 0, Q_BIT_SIZE + 1);
+ for (i64 i_new_bit = 0; i_new_bit < Q_MAX - 1; ++i_new_bit) {
+ i64 u_new_bit = first_max_dcmp_id + matIdx(i_new, i_new_bit, Q_MAX - 1);
+ circuit.uni_gates.emplace_back(g_new, u_new_bit, 0, Q_MAX - 2 - i_new_bit);
+ prepareDecmpBit(0, u_new, u_new_bit, Q_MAX - 2 - i_new_bit);
+ }
+ }
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+
+ // 1: (max - someone)^2 & max - everyone - ((max - everyone) bits) == 0
+ // [0..tot_new_size * (sqr(pool_sz) + 1 >> 1)][tot_new_size * (sqr(pool_sz) + 1 >> 1)..tot_new_size * (sqr(pool_sz) + 1 >> 1) + tot_new_size * sqr(pool_sz)]
+ // 2: (max - someone)^4
+ // ?: (max - someone)^(2^?)
+ // [0..(((tot_out_size + 1) / 2 + 1) / 2...+ 1) / 2]
+ // f: new tensor & (max - someone)^(pool_sz^2 + 1) & all (include minus and max) bits check
+ // [0..tot_new_size]
+ // [tot_new_size..tot_new_size * 2]
+ // [tot_new_size * 2..tot_new_size * (Q + 1)]
+ // [tot_new_size * (Q + 1)
+ // ..tot_new_size * (Q + 1) + (g = 0..tot_out_size) * (Q - 1) + bit_i
+ // ..tot_new_size * (Q + 1) + tot_new_size * (pool_sz^2) * (Q - 1)]
+ // contain_max_ly: the product-tree level where the odd leftover factor is
+ // paired with the max itself.
+ i64 contain_max_ly = 1, ksize = pool_sz_sqr;
+ while (!(ksize & 1)) { ksize >>= 1; ++contain_max_ly; }
+ ksize = pool_sz_sqr;
+
+ for (int i = 1; i < pool_layer_cnt; ++i) {
+ layer &circuit = C.circuit[layer_id];
+ i64 size = tot_new_size * ( ((ksize + 1 )>> 1) + (i64) (i == 1) * ksize ) +
+ (i64) (i == pool_layer_cnt - 1) * tot_new_size * Q_MAX +
+ (i64) (i == pool_layer_cnt - 1) * tot_new_size * pool_sz_sqr * (Q_MAX - 1);
+ initLayer(circuit, size, layerType::MAX_POOL);
+ circuit.need_phase2 = true;
+
+ // new tensor: on the final level, rebuild the pooled outputs from the
+ // top Q-1 bits of each max decomposition.
+ i64 before_mul = 0;
+ if (i == pool_layer_cnt - 1) {
+ before_mul = tot_new_size;
+ for (i64 g = 0; g < tot_new_size; ++g)
+ for (i64 j = 0; j < Q - 1; ++j) {
+ i64 u = first_max_dcmp_id + matIdx(g, j, Q_MAX - 1);
+ circuit.uni_gates.emplace_back(g, u, 0, Q - 2 - j);
+ }
+ }
+
+ // multiplications of subtraction (pairwise product-tree level)
+ for (i64 cnt = 0; cnt < tot_new_size; ++cnt) {
+ i64 v_max = first_max_id + cnt;
+ for (i64 j = 0; (j << 1) < ksize; ++j) {
+ i64 idx = matIdx(cnt, j, (ksize + 1) >> 1);
+ i64 g = before_mul + idx;
+ i64 u = matIdx(cnt, (j << 1), ksize);
+ if ((j << 1 | 1) < ksize) {
+ i64 v = matIdx(cnt, (j << 1 | 1), ksize);
+ circuit.bin_gates.emplace_back(g, u, v, 0, layer_id > 1);
+ } else if (i == contain_max_ly)
+ circuit.bin_gates.emplace_back(g, u, v_max, 0, 2 * (u8) (layer_id > 1));
+ else
+ circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0);
+ }
+ }
+
+ if (i == 1) {
+ // Zero-checks: each (max - element) equals its bit reconstruction.
+ i64 minus_cnt = tot_new_size * ksize;
+ i64 minus_new_cnt = tot_new_size * ((ksize + 1) >> 1);
+ circuit.zero_start_id = minus_new_cnt;
+ for (i64 v = 0; v < minus_cnt; ++v) {
+ i64 g = minus_new_cnt + v;
+ circuit.uni_gates.emplace_back(g, v, layer_id - 1, Q_BIT_SIZE + 1);
+ for (i64 bit_j = 0; bit_j < Q_MAX - 1; ++bit_j) {
+ i64 u = first_dcmp_id + matIdx(v, bit_j, Q_MAX - 1);
+ circuit.uni_gates.emplace_back(g, u, 0, Q_MAX - 2 - bit_j);
+ prepareDecmpBit(layer_id - 1, v, u, Q_MAX - 2 - bit_j);
+ }
+ }
+ } else if (i == pool_layer_cnt - 1) {
+ // Booleanity checks for all difference-decomposition bits.
+ i64 minus_cnt = tot_new_size * pool_sz_sqr;
+ circuit.zero_start_id = before_mul;
+ for (i64 j = 0; j < minus_cnt; ++j) {
+ i64 g = before_mul + tot_new_size + j;
+ i64 u = first_dcmp_id + j;
+ circuit.bin_gates.emplace_back(g, u, u, 0, 0);
+ circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1);
+ }
+ }
+ ksize = (ksize + 1) >> 1;
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+ }
+
+}
+
+// Fully-connected layer: output[p][co] = bias[co] + sum_ci in[p][ci] *
+// weight[co][ci], as one uni gate (bias) plus channel_in bin gates per output.
+void neuralNetwork::fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id) {
+ i64 size = channel_out * pic_parallel;
+ initLayer(circuit, size, layerType::FCONN);
+ circuit.need_phase2 = true;
+
+ for (i64 p = 0; p < pic_parallel; ++p)
+ for (i64 co = 0; co < channel_out; ++co) {
+ i64 g = matIdx(p, co, channel_out);
+ circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0);
+ for (i64 ci = 0; ci < channel_in; ++ci) {
+ i64 u = matIdx(p, ci, channel_in);
+ i64 v = first_fc_id + matIdx(co, ci, channel_in);
+ circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1));
+ }
+ }
+
+ readFconWeight(first_fc_id);
+ readBias(first_bias_id);
+ calcNormalLayer(circuit, layer_id);
+ printLayerInfo(circuit, layer_id++);
+}
+
+// Loads one conv kernel's geometry into the member working state: input /
+// padded / output spatial sizes, channels, stride, and the circuit-layer
+// count this conv type consumes.
+void
+neuralNetwork::refreshConvParam(i64 new_nx, i64 new_ny, const convKernel &conv) {
+ nx_in = new_nx;
+ ny_in = new_ny;
+ padding = conv.padding;
+ nx_padded_in = nx_in + (conv.padding * 2);
+ ny_padded_in = ny_in + (conv.padding * 2);
+
+ m = conv.size;
+ channel_in = conv.channel_in;
+ channel_out = conv.channel_out;
+ log_stride = conv.stride_bl;
+
+ // Standard conv output size: ((padded - kernel) / stride) + 1.
+ nx_out = ((nx_padded_in - m) >> log_stride) + 1;
+ ny_out = ((ny_padded_in - m) >> log_stride) + 1;
+
+ new_nx_in = nx_out;
+ new_ny_in = ny_out;
+ conv_layer_cnt = conv.ty == FFT ? FFT_SIZE : conv.ty == NAIVE ? NCONV_SIZE : NCONV_FAST_SIZE;
+}
+
+// Loads a fully-connected kernel's dimensions; spatial sizes collapse to 1x1.
+void neuralNetwork::refreshFCParam(const fconKernel &fc) {
+ nx_in = nx_out = m = 1;
+ ny_in = ny_out = 1;
+ channel_in = fc.channel_in;
+ channel_out = fc.channel_out;
+}
+
+// FFT length for the current conv: 2^getFFTBitLen().
+i64 neuralNetwork::getFFTLen() const {
+ return 1L << getFFTBitLen();
+}
+
+// FFT bit length: ceil(log2(padded area)) + 1 — the extra bit doubles the
+// transform size so the circular convolution covers the linear one.
+i8 neuralNetwork::getFFTBitLen() const {
+ return ceilPow2BitLength( (u32)nx_padded_in * ny_padded_in ) + 1;
+}
+
+// input: [data]
+// [[conv_kernel || relu_conv_bit_decmp]{sec.size()}[max_pool]{if maxPool}[pool_bit_decmp]]{conv_section.size()}
+// [fc_kernel || relu_fc_bit_decmp]
+// Walks the declared architecture once (without building gates) to assign
+// each kernel/bias its start offset in the layer-0 witness, accumulate
+// total_in_size / total_para_size, and compute SIZE, the total number of
+// circuit layers.
+void neuralNetwork::initParam() {
+ act_layer_cnt = RELU_SIZE;
+ i64 total_conv_layer_cnt = 0, total_pool_layer_cnt = 0;
+ total_in_size = 0;
+ total_para_size = 0;
+ total_relu_in_size = 0;
+ total_ave_in_size = 0;
+ total_max_in_size = 0;
+
+ // data
+ i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel;
+
+ new_nx_in = pic_size_x;
+ new_ny_in = pic_size_y;
+ for (i64 i = 0; i < conv_section.size(); ++i) {
+ auto &sec = conv_section[i];
+ for (i64 j = 0; j < sec.size(); ++j) {
+ refreshConvParam(new_nx_in, new_ny_in, sec[j]);
+ // conv_kernel
+ sec[j].weight_start_id = pos;
+ u32 para_size = sqr(m) * channel_in * channel_out;
+ pos += para_size;
+ total_para_size += para_size;
+ fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size);
+
+ sec[j].bias_start_id = pos;
+ pos += channel_out;
+ total_para_size += channel_out;
+ fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size);
+ }
+
+ total_conv_layer_cnt += sec.size() * (conv_layer_cnt + act_layer_cnt);
+
+ if (i >= pool.size()) continue;
+ calcSizeAfterPool(pool[i]);
+ total_pool_layer_cnt += pool_layer_cnt;
+ // Max pooling replaces the section's last ReLU (mirrors create()).
+ if (pool[i].ty == MAX)
+ if (act_ty == RELU_ACT) total_conv_layer_cnt -= act_layer_cnt;
+ }
+
+ for (int i = 0; i < full_conn.size(); ++i) {
+ auto &fc = full_conn[i];
+ refreshFCParam(fc);
+ // fc_kernel
+ fc.weight_start_id = pos;
+ u32 para_size = channel_out * channel_in;
+ pos += para_size;
+ total_para_size += para_size;
+ fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size);
+ fc.bias_start_id = pos;
+ pos += channel_out;
+ total_para_size += channel_out;
+ fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size);
+ if (i == full_conn.size() - 1) break;
+ }
+ total_in_size = pos;
+
+ // +1 for the input layer; the last FC layer has no following ReLU.
+ SIZE = 1 + total_conv_layer_cnt + total_pool_layer_cnt + (FC_SIZE + RELU_SIZE) * full_conn.size();
+ if (!full_conn.empty()) SIZE -= RELU_SIZE;
+ cerr << "SIZE: " << SIZE << endl;
+}
+
+// Per-layer debug trace; the whole body is intentionally commented out to
+// silence the output, but kept so it can be re-enabled while debugging.
+void neuralNetwork::printLayerInfo(const layer &circuit, i64 layer_id) {
+// fprintf(stderr, "+ %2lld " , layer_id);
+// switch (circuit.ty) {
+// case layerType::INPUT: fprintf(stderr, "inputLayer "); break;
+// case layerType::PADDING: fprintf(stderr, "paddingLayer "); break;
+// case layerType::FFT: fprintf(stderr, "fftLayer "); break;
+// case layerType::DOT_PROD: fprintf(stderr, "dotProdLayer "); break;
+// case layerType::IFFT: fprintf(stderr, "ifftLayer "); break;
+// case layerType::ADD_BIAS: fprintf(stderr, "addBiasLayer "); break;
+// case layerType::RELU: fprintf(stderr, "reluActLayer "); break;
+// case layerType::Sqr: fprintf(stderr, "squareActLayer "); break;
+// case layerType::OPT_AVG_POOL: fprintf(stderr, "avgOptPoolingLayer "); break;
+// case layerType::AVG_POOL: fprintf(stderr, "avgPoolingLayer "); break;
+// case layerType::MAX_POOL: fprintf(stderr, "maxPoolingLayer "); break;
+// case layerType::FCONN: fprintf(stderr, "fullyConnLayer "); break;
+// case layerType::NCONV: fprintf(stderr, "naiveConvFast "); break;
+// case layerType::NCONV_MUL: fprintf(stderr, "naiveConvMul "); break;
+// case layerType::NCONV_ADD: fprintf(stderr, "naiveConvAdd "); break;
+// (no default: all layerType values are handled above)
+// }
+// fprintf(stderr, "%11u (2^%2d)\n", circuit.size, (int) circuit.bit_length);
+}
+
+// Reports the layer-0 witness breakdown (raw data vs. relu/avg/max
+// decomposition bits) to stderr and records the size in the output table.
+void neuralNetwork::printWitnessInfo(const layer &circuit) const {
+ assert(circuit.size == total_in_size);
+ u32 total_data_in_size = total_in_size - total_relu_in_size - total_ave_in_size - total_max_in_size;
+ fprintf(stderr,"%u (2^%2d) = %u (%.2f%% data) + %lld (%.2f%% relu) + %lld (%.2f%% ave) + %lld (%.2f%% max), ",
+ circuit.size, circuit.bit_length, total_data_in_size, 100.0 * total_data_in_size / (double) total_in_size,
+ total_relu_in_size, 100.0 * total_relu_in_size / (double) total_in_size,
+ total_ave_in_size, 100.0 * total_ave_in_size / (double) total_in_size,
+ total_max_in_size, 100.0 * total_max_in_size / (double) total_in_size);
+ output_tb[WS_OUT_ID] = std::to_string(circuit.size) + "(2^" + std::to_string(ceilPow2BitLength(circuit.size)) + ")";
+}
+
+// Number of witness bits the current pooling layer needs:
+// AVG — 2*pool_bl remainder bits per pooled output;
+// MAX — (Q_MAX - 1) bits for each (max - element) difference.
+i64 neuralNetwork::getPoolDecmpSize() const {
+ switch (pool_ty) {
+ case AVG: return new_nx_in * new_ny_in * (pool_bl << 1) * channel_out * pic_parallel;
+ case MAX: return new_nx_in * new_ny_in * sqr(pool_sz) * channel_out * pic_parallel * (Q_MAX - 1);
+ default:
+ assert(false);
+ // Fix: the original flowed off the end of this value-returning
+ // function when NDEBUG disables the assert — undefined behavior.
+ return 0;
+ }
+}
+
+// Caches a pooling kernel's geometry and the circuit-layer count it needs
+// (max pooling uses a log-depth product tree, avg pooling one layer), and
+// updates the post-pool spatial size.
+void neuralNetwork::calcSizeAfterPool(const poolKernel &p) {
+ pool_sz = p.size;
+ pool_bl = ceilPow2BitLength(pool_sz);
+ pool_stride_bl = p.stride_bl;
+ pool_stride = 1 << p.stride_bl;
+ pool_layer_cnt = p.ty == MAX ? 1 + ceilPow2BitLength(sqr(p.size) + 1) : AVE_POOL_SIZE;
+ new_nx_in = ((nx_out - pool_sz) >> pool_stride_bl) + 1;
+ new_ny_in = ((ny_out - pool_sz) >> pool_stride_bl) + 1;
+}
+
+// Reads one picture from `in`, chooses the input quantization shift
+// x_next_bit so the data range (mx - mn) fits in Q-1 bits, and fills the
+// quantized values into layer 0 — the same picture replicated pic_parallel
+// times. Remaining slots (weights etc., filled later) are cleared.
+// NOTE(review): unqualified min() is used here but only `using std::max;` is
+// visible above — presumably std::min comes in via one of the (garbled)
+// includes; confirm against the original file.
+void neuralNetwork::calcInputLayer(layer &circuit) {
+ val[0].resize(circuit.size);
+
+ assert(val[0].size() == total_in_size);
+ auto val_0 = val[0].begin();
+
+ double num, mx = -10000, mn = 10000;
+ vector input_dat;
+ for (i64 ci = 0; ci < pic_channel; ++ci)
+ for (i64 x = 0; x < pic_size_x; ++x)
+ for (i64 y = 0; y < pic_size_y; ++y) {
+ in >> num;
+ input_dat.push_back(num);
+ mx = max(mx, num);
+ mn = min(mn, num);
+ }
+
+ // (mx - mn) * 2^i <= 2^Q - 1
+ // quant_shr = i
+ x_next_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2));
+ if ((int) ((mx - mn) * exp2(x_next_bit)) > (1 << (Q - 1)) - 1) --x_next_bit;
+
+ for (i64 p = 0; p < pic_parallel; ++p) {
+ i64 i = 0;
+ for (i64 ci = 0; ci < pic_channel; ++ci)
+ for (i64 x = 0; x < pic_size_x; ++x)
+ for (i64 y = 0; y < pic_size_y; ++y)
+ *val_0++ = F((i64)(input_dat[i++] * exp2(x_next_bit)));
+ }
+ for (; val_0 < val[0].begin() + circuit.size; ++val_0) val_0 -> clear();
+}
+
+
+// Reads the current conv kernel's weights from `in`, picks the weight
+// quantization shift w_bit so the weight range fits in Q-1 bits, and writes
+// the quantized values into layer 0 starting at first_conv_id.
+void neuralNetwork::readConvWeight(i64 first_conv_id) {
+ auto val_0 = val[0].begin() + first_conv_id;
+
+ double num, mx = -10000, mn = 10000;
+ vector input_dat;
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci)
+ for (i64 x = 0; x < m; ++x)
+ for (i64 y = 0; y < m; ++y) {
+ in >> num;
+ input_dat.push_back(num);
+ mx = max(mx, num);
+ mn = min(mn, num);
+ }
+
+ // (mx - mn) * 2^i <= 2^Q - 1
+ // quant_shr = i
+ w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2));
+ if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit;
+
+ for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit)));
+
+}
+
+// Reads per-output-channel biases from `in` and quantizes them with the
+// combined shift w_bit + x_bit, matching the scale of a weight*input product.
+// NOTE(review): mx/mn are tracked but never used here — the bias reuses the
+// shifts chosen by the weight/input readers.
+void neuralNetwork::readBias(i64 first_bias_id) {
+ auto val_0 = val[0].begin() + first_bias_id;
+
+ double num, mx = -10000, mn = 10000;
+ vector input_dat;
+ for (i64 co = 0; co < channel_out; ++co) {
+ in >> num;
+ input_dat.push_back(num);
+ mx = max(mx, num);
+ mn = min(mn, num);
+ }
+
+ for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit + x_bit)));
+
+}
+
+// Reads a fully-connected layer's weight matrix from `in` and quantizes it,
+// choosing w_bit the same way as readConvWeight.
+void neuralNetwork::readFconWeight(i64 first_fc_id) {
+ double num, mx = -10000, mn = 10000;
+ auto val_0 = val[0].begin() + first_fc_id;
+
+ vector input_dat;
+ for (i64 co = 0; co < channel_out; ++co)
+ for (i64 ci = 0; ci < channel_in; ++ci) {
+ in >> num;
+ input_dat.push_back(num);
+ mx = max(mx, num);
+ mn = min(mn, num);
+ }
+
+ // (mx - mn) * 2^i <= 2^Q - 1
+ // quant_shr = i
+ w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2));
+ if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit;
+
+ for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit)));
+}
+
+// Extracts bit `bit_shift` of |val[layer_id][idx]| and stores it as a witness
+// bit at val[0][dcmp_id] (bit decomposition for range/comparison checks).
+void neuralNetwork::prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift) {
+ auto data = abs(val[layer_id].at(idx).getInt64());
+ val[0].at(dcmp_id) = (data >> bit_shift) & 1;
+}
+
+// Same as prepareDecmpBit but for a field value passed directly.
+void neuralNetwork::prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift) {
+ auto tmp = abs(data.getInt64());
+ val[0].at(dcmp_id) = (tmp >> bit_shift) & 1;
+}
+
+// Stores the sign of val[layer_id][idx] as a witness bit (1 = negative).
+void neuralNetwork::prepareSignBit(i64 layer_id, i64 idx, i64 dcmp_id) {
+ val[0].at(dcmp_id) = val[layer_id].at(idx).isNegative() ? F_ONE : F_ZERO;
+}
+
+// Updates the running max-pool witness at val[0][max_id] with the ReLU-ed
+// value of val[layer_id][idx] (negatives are clamped to zero first).
+void neuralNetwork::prepareMax(i64 layer_id, i64 idx, i64 max_id) {
+ auto data = val[layer_id].at(idx).isNegative() ? F_ZERO : val[layer_id].at(idx);
+ if (data > val[0].at(max_id)) val[0].at(max_id) = data;
+}
+
+/**
+ * Evaluates an ordinary arithmetic layer of the circuit: clears the output,
+ * accumulates every unary and binary gate's contribution, then applies the
+ * layer-wide scale factor to each wire.
+ */
+void neuralNetwork::calcNormalLayer(const layer &circuit, i64 layer_id) {
+    val[layer_id].resize(circuit.size);
+    for (auto &x: val[layer_id]) x.clear();
+
+    // Unary gates: out[g] += in[lu][u] * 2^sc.
+    for (auto &gate: circuit.uni_gates) {
+        val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[gate.lu].at(gate.u) * two_mul[gate.sc];
+    }
+
+    // Binary gates: out[g] += in_u * in_v * 2^sc.
+    for (auto &gate: circuit.bin_gates) {
+        u8 bin_lu = gate.getLayerIdU(layer_id), bin_lv = gate.getLayerIdV(layer_id);
+        val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[bin_lu].at(gate.u) * val[bin_lv][gate.v] * two_mul[gate.sc];
+    }
+
+    // (removed two unused locals mx_val / mn_val from the original)
+    for (i64 g = 0; g < circuit.size; ++g)
+        val[layer_id].at(g) = val[layer_id].at(g) * circuit.scale;
+}
+
+/**
+ * Evaluates a DOT_PROD layer: every binary gate multiplies two fft-length
+ * sub-vectors of the previous layer point-wise and accumulates into the
+ * fft-length output slot of gate g.
+ */
+void neuralNetwork::calcDotProdLayer(const layer &circuit, i64 layer_id) {
+    val[layer_id].resize(circuit.size);
+    for (int i = 0; i < circuit.size; ++i) val[layer_id][i].clear();
+
+    char fft_bit = circuit.fft_bit_length;
+    u32 fft_len = 1 << fft_bit;
+    u8 l = layer_id - 1;
+    for (auto &gate: circuit.bin_gates)
+        for (u32 s = 0; s < fft_len; ++s)   // u32 to match fft_len (was signed int)
+            val[layer_id][gate.g << fft_bit | s] = val[layer_id][gate.g << fft_bit | s] +
+                    val[l][gate.u << fft_bit | s] * val[l][gate.v << fft_bit | s];
+}
+
+/**
+ * Evaluates an FFT or IFFT layer chunk by chunk. FFT maps half-length input
+ * chunks (stride fft_lenh) to full-length outputs (stride fft_len); IFFT is
+ * the reverse: full-length inputs down to half-length outputs.
+ */
+void neuralNetwork::calcFFTLayer(const layer &circuit, i64 layer_id) {
+    i64 fft_len = 1ULL << circuit.fft_bit_length;
+    i64 fft_lenh = fft_len >> 1;
+    val[layer_id].resize(circuit.size);
+    std::vector<F> arr(fft_len, F_ZERO);   // template argument restored; the diff had lost `<F>`
+    if (circuit.ty == layerType::FFT) for (i64 c = 0, d = 0; d < circuit.size; c += fft_lenh, d += fft_len) {
+        for (i64 j = c; j < c + fft_lenh; ++j) arr[j - c] = val[layer_id - 1].at(j);
+        for (i64 j = fft_lenh; j < fft_len; ++j) arr[j].clear();   // zero-pad the upper half
+        fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT);
+        for (i64 j = d; j < d + fft_len; ++j) val[layer_id].at(j) = arr[j - d];
+    } else for (u32 c = 0, d = 0; c < circuit.size; c += fft_lenh, d += fft_len) {
+        for (i64 j = d; j < d + fft_len; ++j) arr[j - d] = val[layer_id - 1].at(j);
+        fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT);
+        for (i64 j = c; j < c + fft_lenh; ++j) val[layer_id].at(j) = arr[j - c];
+    }
+}
+
+// Derives the quantization shift for the next layer from the value range of
+// layer `layer_id`: the largest positive and largest negated-negative values
+// bound the range, which is rescaled back to the real domain before choosing
+// the number of fractional bits that keeps values within 2^(Q-1) - 1.
+int neuralNetwork::getNextBit(int layer_id) {
+ F mx = F_ZERO, mn = F_ZERO;
+ for (const auto &x: val[layer_id]) {
+ if (!x.isNegative()) mx = max(mx, x);
+ else mn = max(mn, -x);
+ }
+ // mx + mn is the total integer range covered by the layer's values.
+ i64 x = (mx + mn).getInt64();
+ double real_scale = x / exp2(x_bit + w_bit);
+ int res = (int) log2( ((1 << (Q - 1)) - 1) / real_scale );
+ return res;
+}
+
+// Debug dump: for each layer prints up to 200 non-zero wire values, then
+// checks the invariant that every wire past zero_start_id is zero, aborting
+// the process if it is violated.
+void neuralNetwork::printLayerValues(prover &pr) {
+ for (i64 i = 0; i < SIZE; ++i) {
+// if (pr.C.circuit[i].ty == layerType::FCONN || pr.C.circuit[i].ty == layerType::ADD_BIAS || i && i < SIZE - 1 && pr.C.circuit[i + 1].ty == layerType::PADDING) {
+ cerr << i << "(" << pr.C.circuit[i].zero_start_id << ", " << pr.C.circuit[i].size << "):\t";
+ for (i64 j = 0; j < std::min(200u, pr.C.circuit[i].size); ++j)
+ if (!pr.val[i][j].isZero()) cerr << pr.val[i][j] << ' ';
+ cerr << endl;
+ // Sanity check: the tail of every layer must be identically zero.
+ for (i64 j = pr.C.circuit[i].zero_start_id; j < pr.C.circuit[i].size; ++j)
+ if (pr.val[i].at(j) != F_ZERO) {
+ cerr << "WRONG! " << i << ' ' << j << ' ' << (-pr.val[i][j] * F_ONE) << endl;
+ exit(EXIT_FAILURE);
+ }
+ }
+}
+
+// Writes the inference result (argmax class per picture) to `out`, one line
+// per picture in the batch, then closes the stream and reports witness stats.
+void neuralNetwork::printInfer(prover &pr) {
+ // output the inference result with the size of (pic_parallel x n_class)
+ if (out.is_open()) {
+ int n_class = full_conn.back().channel_out;
+ for (int p = 0; p < pic_parallel; ++p) {
+ int k = -1;              // best class so far, -1 if every score is negative
+ F v;                     // best (non-negative) score so far
+ for (int c = 0; c < n_class; ++c) {
+ auto tmp = val[SIZE - 1].at(matIdx(p, c, n_class));
+ // Only non-negative scores compete for the argmax.
+ if (!tmp.isNegative() && (k == -1 || v < tmp)) {
+ k = c;
+ v = tmp;
+ }
+ }
+ out << k << endl;
+
+ // output one-hot
+// for (int c = 0; c < n_class; ++c) out << (k == c) << ' ';
+// out << endl;
+ }
+ }
+ out.close();
+ printWitnessInfo(pr.C.circuit[0]);
+}
\ No newline at end of file
diff --git a/src/neuralNetwork.hpp b/src/neuralNetwork.hpp
new file mode 100644
index 0000000..acfcc99
--- /dev/null
+++ b/src/neuralNetwork.hpp
@@ -0,0 +1,169 @@
+//
+// Created by 69029 on 3/16/2021.
+//
+
+#ifndef ZKCNN_NEURALNETWORK_HPP
+#define ZKCNN_NEURALNETWORK_HPP
+
+#include
+#include
+#include "circuit.h"
+#include "prover.hpp"
+
+using std::vector;
+using std::tuple;
+using std::pair;
+
+// How a convolution is proven: via FFT, a naive circuit, or a faster naive layout.
+enum convType {
+ FFT, NAIVE, NAIVE_FAST
+};
+
+// Description of one convolution: geometry plus the offsets of its weights
+// and bias inside the witness (input layer).
+struct convKernel {
+ convType ty;
+ i64 channel_out, channel_in, size, stride_bl, padding, weight_start_id, bias_start_id;
+ convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size, i64 _log_stride, i64 _padding) :
+ ty(_ty), channel_out(_channel_out), channel_in(_channel_in), size(_size), stride_bl(_log_stride), padding(_padding){}
+
+ // Convenience overload: stride 1 (log 0) and "same" padding of size/2.
+ convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size):
+ convKernel(_ty, _channel_out, _channel_in, _size, 0, _size >> 1) {}
+};
+
+// Description of one fully-connected layer, with witness offsets for its
+// weight matrix and bias vector.
+struct fconKernel {
+ i64 channel_out, channel_in, weight_start_id, bias_start_id;
+ fconKernel(i64 _channel_out, i64 _channel_in):
+ channel_out(_channel_out), channel_in(_channel_in) {}
+};
+
+enum poolType {
+ AVG, MAX, NONE
+};
+
+enum actType {
+ RELU_ACT
+};
+
+// Description of one pooling window; the *_start_id fields locate the
+// decomposition-bit and max witnesses for MAX pooling.
+struct poolKernel {
+ poolType ty;
+ i64 size, stride_bl, dcmp_start_id, max_start_id, max_dcmp_start_id;
+ poolKernel(poolType _ty, i64 _size, i64 _log_stride):
+ ty(_ty), size(_size), stride_bl(_log_stride) {}
+};
+
+
+/**
+ * Builds a layered arithmetic circuit for a quantized CNN, evaluates it on an
+ * input picture batch, and hands the witness to the GKR prover.
+ * NOTE(review): several member template arguments below were lost to
+ * angle-bracket stripping in the diff and have been reconstructed from their
+ * uses in neuralNetwork.cpp — confirm against the original header.
+ */
+class neuralNetwork {
+public:
+    // Construct from an input picture file, a model/config file and an output file.
+    explicit neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename,
+                           const string &c_filename, const string &o_filename);
+
+    // Construct a synthetic network from size parameters (benchmarking path).
+    neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size,
+                  i64 start_channel, convType conv_ty, poolType pool_ty);
+
+    // Build the circuit, compute all layer values, and (unless only_compute)
+    // run the prover.
+    void create(prover &pr, bool only_compute);
+
+protected:
+
+    void initParam();
+
+    // Quantization shift for the next layer, derived from this layer's range.
+    int getNextBit(int layer_id);
+
+    void refreshConvParam(i64 new_nx, i64 new_ny, const convKernel &conv);
+
+    void calcSizeAfterPool(const poolKernel &p);
+
+    void refreshFCParam(const fconKernel &fc);
+
+    [[nodiscard]] i64 getFFTLen() const;
+
+    [[nodiscard]] i8 getFFTBitLen() const;
+
+    [[nodiscard]] i64 getPoolDecmpSize() const;
+
+    // Witness-preparation helpers (bit decomposition, sign bits, max pooling).
+    void prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift);
+
+    void prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift);
+
+    void prepareSignBit(i64 layer_id, i64 idx, i64 dcmp_id);
+
+    void prepareMax(i64 layer_id, i64 idx, i64 max_id);
+
+    // Layer evaluation (forward pass over the circuit).
+    void calcInputLayer(layer &circuit);
+
+    void calcNormalLayer(const layer &circuit, i64 layer_id);
+
+    void calcDotProdLayer(const layer &circuit, i64 layer_id);
+
+    void calcFFTLayer(const layer &circuit, i64 layer_id);
+
+    vector<vector<convKernel>> conv_section;   // conv layers grouped into sections
+    vector<poolKernel> pool;                   // pooling after each conv section
+    poolType pool_ty;
+    i64 pool_bl, pool_sz;
+    i64 pool_stride_bl, pool_stride;
+    i64 pool_layer_cnt, act_layer_cnt, conv_layer_cnt;
+    actType act_ty;
+
+    vector<fconKernel> full_conn;              // fully-connected tail of the net
+
+    i64 pic_size_x, pic_size_y, pic_channel, pic_parallel;
+    i64 SIZE;
+    const i64 NCONV_FAST_SIZE, NCONV_SIZE, FFT_SIZE, AVE_POOL_SIZE, FC_SIZE, RELU_SIZE;
+    i64 T;
+    const i64 Q = 9;                           // quantization bit budget
+    i64 Q_MAX;
+    const i64 Q_BIT_SIZE = 220;
+
+    i64 nx_in, nx_out, ny_in, ny_out, m, channel_in, channel_out, log_stride, padding;
+    i64 new_nx_in, new_ny_in;
+    i64 nx_padded_in, ny_padded_in;
+    i64 total_in_size, total_para_size, total_relu_in_size, total_ave_in_size, total_max_in_size;
+    int x_bit, w_bit, x_next_bit;              // current input/weight/next-layer shifts
+
+    vector<vector<F>>::iterator val;           // per-layer wire values (aliases prover storage)
+    vector<F>::iterator two_mul;               // powers of two used as gate scales
+
+    // Circuit-construction helpers, one per layer kind.
+    void inputLayer(layer &circuit);
+
+    void paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id);
+
+    void fftLayer(layer &circuit, i64 &layer_id);
+
+    void dotProdLayer(layer &circuit, i64 &layer_id);
+
+    void ifftLayer(layer &circuit, i64 &layer_id);
+
+    void addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id);
+
+    void naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id);
+
+    void naiveConvLayerMul(layer &circuit, i64 &layer_id, i64 first_conv_id);
+
+    void naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id);
+
+    void reluActConvLayer(layer &circuit, i64 &layer_id);
+
+    void reluActFconLayer(layer &circuit, i64 &layer_id);
+
+    void avgPoolingLayer(layer &circuit, i64 &layer_id);
+
+    void maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id, i64 first_max_dcmp_id);
+
+    void fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id);
+
+    static void printLayerInfo(const layer &circuit, i64 layer_id);
+
+    // Model/input readers: quantize doubles into the witness.
+    void readBias(i64 first_bias_id);
+
+    void readConvWeight(i64 first_conv_id);
+
+    void readFconWeight(i64 first_fc_id);
+
+    void printWitnessInfo(const layer &circuit) const;
+
+    void printLayerValues(prover &pr);
+
+    void printInfer(prover &pr);
+};
+
+
+#endif //ZKCNN_NEURALNETWORK_HPP
diff --git a/src/polynomial.cpp b/src/polynomial.cpp
new file mode 100644
index 0000000..5b0fe3f
--- /dev/null
+++ b/src/polynomial.cpp
@@ -0,0 +1,130 @@
+#include
+#include "polynomial.h"
+
+// Default: all six coefficients zero.
+quintuple_poly::quintuple_poly() { clear(); }
+
+quintuple_poly::quintuple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee, const F &ff)
+        : a(aa), b(bb), c(cc), d(dd), e(ee), f(ff) {}
+
+// Coefficient-wise sum of two degree-5 polynomials.
+quintuple_poly quintuple_poly::operator + (const quintuple_poly &x) const {
+    return {a + x.a, b + x.b, c + x.c, d + x.d, e + x.e, f + x.f};
+}
+
+// Horner evaluation of ax^5 + bx^4 + cx^3 + dx^2 + ex + f.
+F quintuple_poly::eval(const F &x) const {
+    F acc = a;
+    acc = acc * x + b;
+    acc = acc * x + c;
+    acc = acc * x + d;
+    acc = acc * x + e;
+    return acc * x + f;
+}
+
+// Reset every coefficient to zero.
+void quintuple_poly::clear() {
+    a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); f.clear();
+}
+
+// Default: all five coefficients zero.
+quadruple_poly::quadruple_poly() { clear(); }
+
+quadruple_poly::quadruple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee)
+        : a(aa), b(bb), c(cc), d(dd), e(ee) {}
+
+// Coefficient-wise sum of two degree-4 polynomials.
+quadruple_poly quadruple_poly::operator + (const quadruple_poly &x) const {
+    return {a + x.a, b + x.b, c + x.c, d + x.d, e + x.e};
+}
+
+// Horner evaluation of ax^4 + bx^3 + cx^2 + dx + e.
+F quadruple_poly::eval(const F &x) const {
+    F acc = a;
+    acc = acc * x + b;
+    acc = acc * x + c;
+    acc = acc * x + d;
+    return acc * x + e;
+}
+
+// Reset every coefficient to zero.
+void quadruple_poly::clear() {
+    a.clear(); b.clear(); c.clear(); d.clear(); e.clear();
+}
+
+// Default: all four coefficients zero.
+cubic_poly::cubic_poly() { a.clear(); b.clear(); c.clear(); d.clear(); }
+
+cubic_poly::cubic_poly(const F &aa, const F &bb, const F &cc, const F &dd)
+        : a(aa), b(bb), c(cc), d(dd) {}
+
+// Coefficient-wise sum of two cubics.
+cubic_poly cubic_poly::operator + (const cubic_poly &x) const {
+    return {a + x.a, b + x.b, c + x.c, d + x.d};
+}
+
+// Horner evaluation of ax^3 + bx^2 + cx + d.
+F cubic_poly::eval(const F &x) const {
+    F acc = a * x + b;
+    acc = acc * x + c;
+    return acc * x + d;
+}
+
+// Default: all three coefficients zero.
+quadratic_poly::quadratic_poly() { a.clear(); b.clear(); c.clear(); }
+
+quadratic_poly::quadratic_poly(const F &aa, const F &bb, const F &cc)
+        : a(aa), b(bb), c(cc) {}
+
+// Coefficient-wise sum of two quadratics.
+quadratic_poly quadratic_poly::operator + (const quadratic_poly &x) const {
+    return {a + x.a, b + x.b, c + x.c};
+}
+
+// Quadratic plus linear: only the x and constant coefficients change.
+quadratic_poly quadratic_poly::operator+(const linear_poly &x) const {
+    return {a, b + x.a, c + x.b};
+}
+
+// (ax^2 + bx + c)(x.a * t + x.b) expanded into a cubic in t.
+cubic_poly quadratic_poly::operator * (const linear_poly &x) const {
+    F c3 = a * x.a;
+    F c2 = a * x.b + b * x.a;
+    F c1 = b * x.b + c * x.a;
+    F c0 = c * x.b;
+    return {c3, c2, c1, c0};
+}
+
+// Scale every coefficient of a cubic by a field constant.
+cubic_poly cubic_poly::operator * (const F &x) const {
+    return {a * x, b * x, c * x, d * x};
+}
+
+// Reset every coefficient to zero.
+void cubic_poly::clear() {
+    a.clear(); b.clear(); c.clear(); d.clear();
+}
+
+// Scale every coefficient of a quadratic by a field constant.
+quadratic_poly quadratic_poly::operator*(const F &x) const {
+    return {a * x, b * x, c * x};
+}
+
+// Horner evaluation of ax^2 + bx + c.
+F quadratic_poly::eval(const F &x) const {
+    F acc = a * x + b;
+    return acc * x + c;
+}
+
+// Reset every coefficient to zero.
+void quadratic_poly::clear() {
+    a.clear(); b.clear(); c.clear();
+}
+
+// Default: both coefficients zero.
+linear_poly::linear_poly() { a.clear(); b.clear(); }
+
+linear_poly::linear_poly(const F &aa, const F &bb) : a(aa), b(bb) {}
+
+// Implicit lift of a field constant: slope zero, intercept x.
+linear_poly::linear_poly(const F &x) : b(x) { a.clear(); }
+
+// Coefficient-wise sum of two linear polynomials.
+linear_poly linear_poly::operator + (const linear_poly &x) const {
+    return {a + x.a, b + x.b};
+}
+
+// (a*t + b)(x.a * t + x.b) expanded into a quadratic in t.
+quadratic_poly linear_poly::operator * (const linear_poly &x) const {
+    F c2 = a * x.a;
+    F c1 = a * x.b + b * x.a;
+    F c0 = b * x.b;
+    return {c2, c1, c0};
+}
+
+// Scale both coefficients by a field constant.
+linear_poly linear_poly::operator*(const F &x) const {
+    return {a * x, b * x};
+}
+
+// Evaluate a*x + b.
+F linear_poly::eval(const F &x) const {
+    return a * x + b;
+}
+
+// Reset both coefficients to zero.
+void linear_poly::clear() {
+    a.clear(); b.clear();
+}
diff --git a/src/polynomial.h b/src/polynomial.h
new file mode 100644
index 0000000..cc147f2
--- /dev/null
+++ b/src/polynomial.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include
+#include
+#include "global_var.hpp"
+
+class linear_poly;
+
+//ax^3 + bx^2 + cx + d
+class cubic_poly {
+public:
+ F a, b, c, d;
+ cubic_poly();
+ cubic_poly(const F &, const F &, const F &, const F &);
+ cubic_poly operator + (const cubic_poly &) const;
+ cubic_poly operator * (const F &) const;
+ F eval(const F &) const;
+ void clear();
+};
+
+//ax^2 + bx + c
+class quadratic_poly {
+public:
+ F a, b, c;
+ quadratic_poly();
+ quadratic_poly(const F &, const F &, const F &);
+ quadratic_poly operator + (const quadratic_poly &) const;
+ quadratic_poly operator + (const linear_poly &) const;
+ cubic_poly operator * (const linear_poly &) const;
+ quadratic_poly operator * (const F &) const;
+ F eval(const F &) const;
+ void clear();
+};
+
+
+//ax + b
+class linear_poly {
+public:
+ F a, b;
+ linear_poly();
+ linear_poly(const F &, const F &);
+ // Intentionally non-explicit: prover code relies on implicit F -> linear_poly
+ // conversion (e.g. assigning an eval() result back into V_mult entries).
+ linear_poly(const F &);
+ linear_poly operator + (const linear_poly &) const;
+ quadratic_poly operator * (const linear_poly &) const;
+ linear_poly operator * (const F &) const;
+ F eval(const F &) const;
+ void clear();
+};
+
+
+
+//ax^4 + bx^3 + cx^2 + dx + e
+class quadruple_poly {
+public:
+ F a, b, c, d, e;
+ quadruple_poly();
+ quadruple_poly(const F &, const F &, const F &, const F &, const F &);
+ quadruple_poly operator + (const quadruple_poly &) const;
+ F eval(const F &) const;
+ void clear();
+};
+
+//ax^5 + bx^4 + cx^3 + dx^2 + ex + f
+class quintuple_poly {
+public:
+ F a, b, c, d, e, f;
+ quintuple_poly();
+ quintuple_poly(const F &, const F &, const F &, const F &, const F &, const F &);
+ quintuple_poly operator + (const quintuple_poly &) const;
+ F eval(const F &) const;
+ void clear();
+};
diff --git a/src/prover.cpp b/src/prover.cpp
new file mode 100644
index 0000000..0c99714
--- /dev/null
+++ b/src/prover.cpp
@@ -0,0 +1,511 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#include "prover.hpp"
+#include
+#include
+
+// Scratch beta-tables shared by the sumcheck phases in this translation unit.
+static vector<F> beta_gs, beta_u;   // NOTE(review): `<F>` restored — the diff had stripped the template argument
+
+using std::unique_ptr;
+
+// Degree-1 polynomial through (0, zero_v) and (1, one_v).
+linear_poly interpolate(const F &zero_v, const F &one_v) {
+ return {one_v - zero_v, zero_v};
+}
+
+// Reset proof-size bookkeeping and allocate per-layer randomness storage
+// (one extra slot so layer C.size can hold the output-layer point).
+void prover::init() {
+ proof_size = 0;
+ r_u.resize(C.size + 1);
+ r_v.resize(C.size + 1);
+}
+
+/**
+ * Start the whole GKR proof: record the verifier's random evaluation point
+ * for the output layer as the initial sumcheck randomness.
+ *
+ * @param r_0_from_v iterator to the random point chosen by the verifier
+ */
+void prover::sumcheckInitAll(const vector::const_iterator &r_0_from_v) {
+ sumcheck_id = C.size;
+ i8 last_bl = C.circuit[sumcheck_id - 1].bit_length;
+ r_u[sumcheck_id].resize(last_bl);
+
+ prove_timer.start();
+ for (int i = 0; i < last_bl; ++i) r_u[sumcheck_id][i] = r_0_from_v[i];
+ prove_timer.stop();
+}
+
+/**
+ * Initialize before the sumcheck of a single layer: store the combination
+ * coefficients, point r_0/r_1 at the previous round's randomness, and step
+ * down to the next layer.
+ *
+ * @param alpha_0 / beta_0 the random combination coefficients that merge the
+ *        two reduction points of the previous layer into one claim
+ */
+void prover::sumcheckInit(const F &alpha_0, const F &beta_0) {
+    prove_timer.start();
+    // (removed unused local reference `cur` from the original)
+    alpha = alpha_0;
+    beta = beta_0;
+    r_0 = r_u[sumcheck_id].begin();
+    r_1 = r_v[sumcheck_id].begin();
+    --sumcheck_id;
+    prove_timer.stop();
+}
+
+/**
+ * Initialize phase 1 of the sumcheck for an inner-product (DOT_PROD) layer:
+ * build the beta table over the fft-length suffix of u, load the previous
+ * layer's values, and pre-accumulate each gate's beta_g-weighted partner.
+ */
+void prover::sumcheckDotProdInitPhase1() {
+    fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id);
+
+    auto &cur = C.circuit[sumcheck_id];
+    i8 fft_bl = cur.fft_bit_length;
+    // (removed unused local `cnt_bl` from the original)
+    total[0] = 1ULL << fft_bl;
+    total[1] = 1ULL << cur.bit_length_u[1];
+    total_size[1] = cur.size_u[1];
+    u32 fft_len = total[0];
+
+    r_u[sumcheck_id].resize(cur.max_bl_u);
+    V_mult[0].resize(total[1]);
+    V_mult[1].resize(total[1]);
+    mult_array[1].resize(total[0]);
+    beta_gs.resize(1ULL << fft_bl);
+
+    prove_timer.start();
+
+    initBetaTable(beta_gs, fft_bl, r_0, F_ONE);
+
+    // mult_array[1] carries the beta table over the fft-sized index t.
+    for (u32 t = 0; t < fft_len; ++t)
+        mult_array[1][t] = beta_gs[t];
+    // V_mult[1]: raw previous-layer values (zero-padded past size_u[1]);
+    // V_mult[0]: accumulator for beta_g-weighted partner values.
+    for (u32 u = 0; u < total[1]; ++u) {
+        V_mult[0][u].clear();
+        if (u >= cur.size_u[1]) V_mult[1][u].clear();
+        else V_mult[1][u] = val[sumcheck_id - 1][u];
+    }
+
+    for (auto &gate: cur.bin_gates)
+        for (u32 t = 0; t < fft_len; ++t) {
+            u32 idx_u = gate.u << fft_bl | t;
+            u32 idx_v = gate.v << fft_bl | t;
+            V_mult[0][idx_u] = V_mult[0][idx_u] + beta_g[gate.g] * val[sumcheck_id - 1][idx_v];
+        }
+
+    round = 0;
+    prove_timer.stop();
+}
+
+/**
+ * One-variable reduction step for the inner-product layer's sumcheck.
+ * Halves both the beta side (mult_array) and the two value tables, and
+ * returns the cubic polynomial the prover sends for this round.
+ *
+ * @param previous_random the verifier's challenge from the previous round
+ * @return the round's cubic polynomial
+ */
+cubic_poly prover::sumcheckDotProdUpdate1(const F &previous_random) {
+ prove_timer.start();
+
+ if (round) r_u[sumcheck_id].at(round - 1) = previous_random;
+ ++round;
+
+ auto &tmp_mult = mult_array[1];
+ auto &tmp_v0 = V_mult[0], &tmp_v1 = V_mult[1];
+
+ // Fold the beta table first; once it collapses to one entry it is fixed.
+ if (total[0] == 1)
+ tmp_mult[0] = tmp_mult[0].eval(previous_random);
+ else for (u32 i = 0; i < (total[0] >> 1); ++i) {
+ u32 g0 = i << 1, g1 = i << 1 | 1;
+ tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random));
+ }
+ total[0] >>= 1;
+
+ cubic_poly ret;
+ for (u32 i = 0; i < (total[1] >> 1); ++i) {
+ u32 g0 = i << 1, g1 = i << 1 | 1;
+ if (g0 >= total_size[1]) {
+ tmp_v0[i].clear();
+ tmp_v1[i].clear();
+ continue;
+ }
+ // Odd sibling past the live range contributes zero.
+ if (g1 >= total_size[1]) {
+ tmp_v0[g1].clear();
+ tmp_v1[g1].clear();
+ }
+ tmp_v0[i] = interpolate(tmp_v0[g0].eval(previous_random), tmp_v0[g1].eval(previous_random));
+ tmp_v1[i] = interpolate(tmp_v1[g0].eval(previous_random), tmp_v1[g1].eval(previous_random));
+ // `i & total[0] - 1` masks i to the remaining beta-table size
+ // (precedence: total[0] - 1 binds before &).
+ if (total[0]) ret = ret + tmp_mult[i & total[0] - 1] * tmp_v1[i] * tmp_v0[i];
+ else ret = ret + tmp_mult[0] * tmp_v1[i] * tmp_v0[i];
+ }
+ proof_size += F_BYTE_SIZE * (3 + (!ret.a.isZero()));
+
+ total[1] >>= 1;
+ total_size[1] = (total_size[1] + 1) >> 1;
+
+ prove_timer.stop();
+ return ret;
+}
+
+/**
+ * Finish phase 1 of the inner-product sumcheck: fix the last challenge,
+ * report the claimed value of V, and cache V_u1 for phase 2.
+ */
+void prover::sumcheckDotProdFinalize1(const F &previous_random, F &claim_1) {
+    prove_timer.start();
+    r_u[sumcheck_id].at(round - 1) = previous_random;
+    claim_1 = V_mult[1][0].eval(previous_random);
+    // Reuse claim_1 instead of evaluating V_mult[1][0] a second time.
+    V_u1 = claim_1 * mult_array[1][0].eval(previous_random);
+    prove_timer.stop();
+    proof_size += F_BYTE_SIZE * 1;
+}
+
+// Initialize phase 1 (the u-variables) of a generic layer's sumcheck.
+// Builds beta_g over the output gates and fills mult_array / V_mult with the
+// gate-weighted contributions; FFT/IFFT and PADDING layers get special beta
+// constructions.
+void prover::sumcheckInitPhase1(const F &relu_rou_0) {
+ fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id);
+
+ auto &cur = C.circuit[sumcheck_id];
+ // bit_length == -1 (all ones) marks an absent operand class; ~x tests that.
+ total[0] = ~cur.bit_length_u[0] ? 1ULL << cur.bit_length_u[0] : 0;
+ total_size[0] = cur.size_u[0];
+ total[1] = ~cur.bit_length_u[1] ? 1ULL << cur.bit_length_u[1] : 0;
+ total_size[1] = cur.size_u[1];
+
+ r_u[sumcheck_id].resize(cur.max_bl_u);
+ V_mult[0].resize(total[0]);
+ V_mult[1].resize(total[1]);
+ mult_array[0].resize(total[0]);
+ mult_array[1].resize(total[1]);
+ beta_g.resize(1ULL << cur.bit_length);
+ if (cur.ty == layerType::PADDING) beta_gs.resize(1ULL << cur.fft_bit_length);
+ if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT)
+ beta_gs.resize(total[1]);
+
+ prove_timer.start();
+
+ relu_rou = relu_rou_0;
+ add_term.clear();
+ for (int b = 0; b < 2; ++b)
+ for (u32 u = 0; u < total[b]; ++u)
+ mult_array[b][u].clear();
+
+ if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) {
+ // FFT layers: beta over the chunk counter, phi over the transform index.
+ i8 fft_bl = cur.fft_bit_length;
+ i8 fft_blh = cur.fft_bit_length - 1;
+ i8 cnt_bl = cur.ty == layerType::FFT ? cur.bit_length - fft_bl : cur.bit_length - fft_blh;
+ u32 cnt_len = cur.size >> (cur.ty == layerType::FFT ? fft_bl : fft_blh);
+ if (cur.ty == layerType::FFT)
+ initBetaTable(beta_g, cnt_bl, r_0 + fft_bl, r_1, alpha, beta);
+ else initBetaTable(beta_g, cnt_bl, r_0 + fft_blh, alpha);
+ for (u32 u = 0, l = sumcheck_id - 1; u < total[1]; ++u) {
+ V_mult[1][u].clear();
+ if (u >= cur.size_u[1]) continue;
+ for (u32 g = 0; g < cnt_len; ++g) {
+ u32 idx = g << cur.max_bl_u | u;
+ V_mult[1][u] = V_mult[1][u] + val[l][idx] * beta_g[g];
+ }
+ }
+
+ beta_gs.resize(total[1]);
+ phiGInit(beta_gs, r_0, cur.scale, fft_bl, cur.ty == layerType::IFFT);
+ for (u32 u = 0; u < total[1] ; ++u) {
+ mult_array[1][u] = beta_gs[u];
+ }
+ } else {
+ // b == 0 reads the input layer through ori_id_u; b == 1 the previous layer.
+ for (int b = 0; b < 2; ++b) {
+ auto dep = !b ? 0 : sumcheck_id - 1;
+ for (u32 u = 0; u < total[b]; ++u) {
+ if (u >= cur.size_u[b])
+ V_mult[b][u].clear();
+ else V_mult[b][u] = getCirValue(dep, cur.ori_id_u, u);
+ }
+ }
+
+ if (cur.ty == layerType::PADDING) {
+ i8 fft_blh = cur.fft_bit_length - 1;
+ u32 fft_lenh = 1ULL << fft_blh;
+ initBetaTable(beta_gs, fft_blh, r_0, F_ONE);
+ for (long g = (1L << cur.bit_length) - 1; g >= 0; --g)
+ beta_g[g] = beta_g[g >> fft_blh] * beta_gs[g & fft_lenh - 1];
+ } else initBetaTable(beta_g, cur.bit_length, r_0, r_1, alpha * cur.scale, beta * cur.scale);
+ // Fold the relu randomness into the zero-tail of the layer.
+ if (cur.zero_start_id < cur.size)
+ for (u32 g = cur.zero_start_id; g < 1ULL << cur.bit_length; ++g) beta_g[g] = beta_g[g] * relu_rou;
+
+ for (auto &gate: cur.uni_gates) {
+ bool idx = gate.lu != 0;
+ mult_array[idx][gate.u] = mult_array[idx][gate.u] + beta_g[gate.g] * C.two_mul[gate.sc];
+ }
+
+ for (auto &gate: cur.bin_gates) {
+ bool idx = gate.getLayerIdU(sumcheck_id) != 0;
+ auto val_lv = getCirValue(gate.getLayerIdV(sumcheck_id), cur.ori_id_v, gate.v);
+ mult_array[idx][gate.u] = mult_array[idx][gate.u] + val_lv * beta_g[gate.g] * C.two_mul[gate.sc];
+ }
+ }
+
+ round = 0;
+ prove_timer.stop();
+ fprintf(stderr, "sumcheck level %d, phase1 init finished\n", sumcheck_id);
+}
+
+// Initialize phase 2 (the v-variables) of a layer's sumcheck: beta_u over
+// the already-fixed u point, and gate contributions weighted by the claimed
+// V_u values from phase 1.
+void prover::sumcheckInitPhase2() {
+ fprintf(stderr, "sumcheck level %d, phase2 init start\n", sumcheck_id);
+
+ auto &cur = C.circuit[sumcheck_id];
+ total[0] = ~cur.bit_length_v[0] ? 1ULL << cur.bit_length_v[0] : 0;
+ total_size[0] = cur.size_v[0];
+ total[1] = ~cur.bit_length_v[1] ? 1ULL << cur.bit_length_v[1] : 0;
+ total_size[1] = cur.size_v[1];
+ i8 fft_bl = cur.fft_bit_length;
+ i8 cnt_bl = cur.max_bl_v;
+
+ r_v[sumcheck_id].resize(cur.max_bl_v);
+
+ V_mult[0].resize(total[0]);
+ V_mult[1].resize(total[1]);
+ mult_array[0].resize(total[0]);
+ mult_array[1].resize(total[1]);
+
+ if (cur.ty == layerType::DOT_PROD) {
+ beta_u.resize(1ULL << cnt_bl);
+ beta_gs.resize(1ULL << fft_bl);
+ } else beta_u.resize(1ULL << cur.max_bl_u);
+
+ prove_timer.start();
+
+ add_term.clear();
+ for (int b = 0; b < 2; ++b) {
+ for (u32 v = 0; v < total[b]; ++v)
+ mult_array[b][v].clear();
+ }
+
+ if (cur.ty == layerType::DOT_PROD) {
+ // Split u's randomness: fft-length prefix weights the value fold,
+ // the remainder feeds beta_u over the gate index.
+ u32 fft_len = 1ULL << cur.fft_bit_length;
+ initBetaTable(beta_u, cnt_bl, r_u[sumcheck_id].begin() + fft_bl, F_ONE);
+ initBetaTable(beta_gs, fft_bl, r_u[sumcheck_id].begin(), F_ONE);
+
+ for (u32 v = 0; v < total[1]; ++v) {
+ V_mult[1][v].clear();
+ if (v >= cur.size_v[1]) continue;
+ for (u32 t = 0; t < fft_len; ++t) {
+ u32 idx_v = (v << fft_bl) | t;
+ V_mult[1][v] = V_mult[1][v] + val[sumcheck_id - 1][idx_v] * beta_gs[t];
+ }
+ }
+
+ for (auto &gate: cur.bin_gates)
+ mult_array[1][gate.v] =
+ mult_array[1][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u1;
+ } else {
+ initBetaTable(beta_u, cur.max_bl_u, r_u[sumcheck_id].begin(), F_ONE);
+ for (int b = 0; b < 2; ++b) {
+ auto dep = !b ? 0 : sumcheck_id - 1;
+ for (u32 v = 0; v < total[b]; ++v) {
+ V_mult[b][v] = v >= cur.size_v[b] ? F_ZERO : getCirValue(dep, cur.ori_id_v, v);
+ }
+ }
+ // Unary gates have no v-dependence: they fold into the constant term.
+ for (auto &gate: cur.uni_gates) {
+ auto V_u = !gate.lu ? V_u0 : V_u1;
+ add_term = add_term + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc];
+ }
+ for (auto &gate: cur.bin_gates) {
+ bool idx = gate.getLayerIdV(sumcheck_id);
+ auto V_u = !gate.getLayerIdU(sumcheck_id) ? V_u0 : V_u1;
+ mult_array[idx][gate.v] = mult_array[idx][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc];
+ }
+ }
+
+ round = 0;
+ prove_timer.stop();
+}
+
+// Initialize the final (Liu) sumcheck that reduces every layer's claims on
+// the input layer to a single claim, combined with the coefficients s_u/s_v.
+void prover::sumcheckLiuInit(const vector &s_u, const vector &s_v) {
+ sumcheck_id = 0;
+ total[1] = (1ULL << C.circuit[sumcheck_id].bit_length);
+ total_size[1] = C.circuit[sumcheck_id].size;
+
+ r_u[0].resize(C.circuit[0].bit_length);
+ mult_array[1].resize(total[1]);
+ V_mult[1].resize(total[1]);
+
+ // beta_g must be large enough for the widest input-layer reference.
+ i8 max_bl = 0;
+ for (int i = sumcheck_id + 1; i < C.size; ++i)
+ max_bl = max(max_bl, max(C.circuit[i].bit_length_u[0], C.circuit[i].bit_length_v[0]));
+ beta_g.resize(1ULL << max_bl);
+
+ prove_timer.start();
+ add_term.clear();
+
+ for (u32 g = 0; g < total[1]; ++g) {
+ mult_array[1][g].clear();
+ V_mult[1][g] = (g < total_size[1]) ? val[0][g] : F_ZERO;
+ }
+
+ // Accumulate each layer's u- and v-side input references, weighted by
+ // the verifier's combination coefficients.
+ for (u8 i = sumcheck_id + 1; i < C.size; ++i) {
+ i8 bit_length_i = C.circuit[i].bit_length_u[0];
+ u32 size_i = C.circuit[i].size_u[0];
+ if (~bit_length_i) {
+ initBetaTable(beta_g, bit_length_i, r_u[i].begin(), s_u[i - 1]);
+ for (u32 hu = 0; hu < size_i; ++hu) {
+ u32 u = C.circuit[i].ori_id_u[hu];
+ mult_array[1][u] = mult_array[1][u] + beta_g[hu];
+ }
+ }
+
+ bit_length_i = C.circuit[i].bit_length_v[0];
+ size_i = C.circuit[i].size_v[0];
+ if (~bit_length_i) {
+ initBetaTable(beta_g, bit_length_i, r_v[i].begin(), s_v[i - 1]);
+ for (u32 hv = 0; hv < size_i; ++hv) {
+ u32 v = C.circuit[i].ori_id_v[hv];
+ mult_array[1][v] = mult_array[1][v] + beta_g[hv];
+ }
+ }
+ }
+
+ round = 0;
+ prove_timer.stop();
+}
+
+// Phase-1 round: fold the challenge into r_u.
+quadratic_poly prover::sumcheckUpdate1(const F &previous_random) {
+ return sumcheckUpdate(previous_random, r_u[sumcheck_id]);
+}
+
+// Phase-2 round: fold the challenge into r_v.
+quadratic_poly prover::sumcheckUpdate2(const F &previous_random) {
+ return sumcheckUpdate(previous_random, r_v[sumcheck_id]);
+}
+
+// Shared one-round sumcheck step: records the previous challenge, folds both
+// operand classes, and adds the constant add_term (scaled so it stays the
+// correct contribution after this variable is bound).
+quadratic_poly prover::sumcheckUpdate(const F &previous_random, vector &r_arr) {
+ prove_timer.start();
+
+ if (round) r_arr.at(round - 1) = previous_random;
+ ++round;
+ quadratic_poly ret;
+
+ add_term = add_term * (F_ONE - previous_random);
+ for (int b = 0; b < 2; ++b)
+ ret = ret + sumcheckUpdateEach(previous_random, b);
+ ret = ret + quadratic_poly(F_ZERO, -add_term, add_term);
+
+ prove_timer.stop();
+ proof_size += F_BYTE_SIZE * 3;
+ return ret;
+}
+
+// One round of the final (Liu) sumcheck; only operand class 1 exists there.
+quadratic_poly prover::sumcheckLiuUpdate(const F &previous_random) {
+ prove_timer.start();
+ ++round;
+
+ auto ret = sumcheckUpdateEach(previous_random, true);
+
+ prove_timer.stop();
+ proof_size += F_BYTE_SIZE * 3;
+ return ret;
+}
+
+// Halves one operand class's tables: binds the previous challenge in every
+// pair of entries, re-interpolates, and accumulates the round polynomial
+// sum_i mult[i] * V[i].
+quadratic_poly prover::sumcheckUpdateEach(const F &previous_random, bool idx) {
+ auto &tmp_mult = mult_array[idx];
+ auto &tmp_v = V_mult[idx];
+
+ // Last variable: collapse to constants and move them into add_term.
+ if (total[idx] == 1) {
+ tmp_v[0] = tmp_v[0].eval(previous_random);
+ tmp_mult[0] = tmp_mult[0].eval(previous_random);
+ add_term = add_term + tmp_v[0].b * tmp_mult[0].b;
+ }
+
+ quadratic_poly ret;
+ for (u32 i = 0; i < (total[idx] >> 1); ++i) {
+ u32 g0 = i << 1, g1 = i << 1 | 1;
+ if (g0 >= total_size[idx]) {
+ tmp_v[i].clear();
+ tmp_mult[i].clear();
+ continue;
+ }
+ // Odd sibling past the live range contributes zero.
+ if (g1 >= total_size[idx]) {
+ tmp_v[g1].clear();
+ tmp_mult[g1].clear();
+ }
+ tmp_v[i] = interpolate(tmp_v[g0].eval(previous_random), tmp_v[g1].eval(previous_random));
+ tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random));
+ ret = ret + tmp_mult[i] * tmp_v[i];
+ }
+ total[idx] >>= 1;
+ total_size[idx] = (total_size[idx] + 1) >> 1;
+
+ return ret;
+}
+
+/**
+ * Evaluates the multi-linear extension of the output layer at a random point
+ * by successively binding each variable of r.
+ *
+ * @param r the random point; output_size the number of output wires;
+ *        r_size the number of variables of r
+ * @return the MLE of the output layer evaluated at r
+ */
+F prover::Vres(const vector::const_iterator &r, u32 output_size, u8 r_size) {
+ prove_timer.start();
+
+ vector output(output_size);
+ for (u32 i = 0; i < output_size; ++i)
+ output[i] = val[C.size - 1][i];
+ u32 whole = 1ULL << r_size;
+ // Bind one variable per pass: output[j] <- (1-r_i)*output[2j] + r_i*output[2j+1].
+ for (u8 i = 0; i < r_size; ++i) {
+ for (u32 j = 0; j < (whole >> 1); ++j) {
+ if (j > 0)
+ output[j].clear();
+ if ((j << 1) < output_size)
+ output[j] = output[j << 1] * (F_ONE - r[i]);
+ if ((j << 1 | 1) < output_size)
+ output[j] = output[j] + output[j << 1 | 1] * (r[i]);
+ }
+ whole >>= 1;
+ }
+ F res = output[0];
+
+ prove_timer.stop();
+ proof_size += F_BYTE_SIZE;
+ return res;
+}
+
+// Finish phase 1: bind the last challenge and emit the two claimed values
+// V_u0 (input-layer operand) and V_u1 (previous-layer operand). A collapsed
+// table holds its constant in .b; an absent class (bit_length == -1) is zero.
+void prover::sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1) {
+ prove_timer.start();
+ r_u[sumcheck_id].at(round - 1) = previous_random;
+ V_u0 = claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[0]) ? V_mult[0][0].b : F_ZERO;
+ V_u1 = claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[1]) ? V_mult[1][0].b : F_ZERO;
+ prove_timer.stop();
+
+ mult_array[0].clear();
+ mult_array[1].clear();
+ V_mult[0].clear();
+ V_mult[1].clear();
+ proof_size += F_BYTE_SIZE * 2;
+}
+
+// Finish phase 2: same shape as sumcheckFinalize1 but over the v-variables.
+void prover::sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1) {
+ prove_timer.start();
+ r_v[sumcheck_id].at(round - 1) = previous_random;
+ claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[0]) ? V_mult[0][0].b : F_ZERO;
+ claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[1]) ? V_mult[1][0].b : F_ZERO;
+ prove_timer.stop();
+
+ mult_array[0].clear();
+ mult_array[1].clear();
+ V_mult[0].clear();
+ V_mult[1].clear();
+ proof_size += F_BYTE_SIZE * 2;
+}
+
+// Finish the final (Liu) sumcheck and release the scratch tables.
+void prover::sumcheckLiuFinalize(const F &previous_random, F &claim_1) {
+ prove_timer.start();
+ r_u[sumcheck_id].at(round - 1) = previous_random;
+ claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : V_mult[1][0].b;
+ prove_timer.stop();
+ proof_size += F_BYTE_SIZE;
+
+ mult_array[1].clear();
+ V_mult[1].clear();
+ beta_g.clear();
+}
+
+// Layer 0 (the input layer) is indexed indirectly through `ori`; any other
+// layer is indexed directly by u.
+F prover::getCirValue(u8 layer_id, const vector &ori, u32 u) {
+ return !layer_id ? val[0][ori[u]] : val[layer_id][u];
+}
+
+// Pads the input layer with zeros up to a power-of-two length and commits to
+// it with the Hyrax polynomial commitment.
+// NOTE(review): the parameter type reads `const vector &gens` — a template
+// argument (likely the Hyrax generator group element) appears to have been
+// stripped from this diff; confirm against the original source.
+hyrax_bls12_381::polyProver &prover::commitInput(const vector &gens) {
+ if (C.circuit[0].size != (1ULL << C.circuit[0].bit_length)) {
+ val[0].resize(1ULL << C.circuit[0].bit_length);
+ for (int i = C.circuit[0].size; i < val[0].size(); ++i)
+ val[0][i].clear();
+ }
+ poly_p = std::make_unique(val[0], gens);
+ return *poly_p;
+}
\ No newline at end of file
diff --git a/src/prover.hpp b/src/prover.hpp
new file mode 100644
index 0000000..505e18c
--- /dev/null
+++ b/src/prover.hpp
@@ -0,0 +1,81 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#ifndef ZKCNN_PROVER_HPP
+#define ZKCNN_PROVER_HPP
+
+#include "global_var.hpp"
+#include "circuit.h"
+#include "polynomial.h"
+
+using std::unique_ptr;
+
+class neuralNetwork;
+class singleConv;
+// GKR-style prover for the layered CNN circuit. Drives one sumcheck per
+// layer (phase 1 over u, optionally phase 2 over v), a final combined
+// sumcheck for the input layer, and a Hyrax commitment to the input.
+class prover {
+public:
+ void init();
+
+ // Per-proof setup: seed with the verifier's output-layer randomness.
+ void sumcheckInitAll(const vector::const_iterator &r_0_from_v);
+ // Per-layer setup with the random combination coefficients alpha/beta.
+ void sumcheckInit(const F &alpha_0, const F &beta_0);
+ void sumcheckDotProdInitPhase1();
+ void sumcheckInitPhase1(const F &relu_rou_0);
+ void sumcheckInitPhase2();
+
+ // One sumcheck round: fold in the previous challenge, return the next
+ // round polynomial (cubic for dot-product layers, quadratic otherwise).
+ cubic_poly sumcheckDotProdUpdate1(const F &previous_random);
+ quadratic_poly sumcheckUpdate1(const F &previous_random);
+ quadratic_poly sumcheckUpdate2(const F &previous_random);
+
+ // Evaluate the (padded) circuit output at the random point r.
+ F Vres(const vector::const_iterator &r, u32 output_size, u8 r_size);
+
+ // End-of-phase finalizers: emit the claimed evaluations at the last challenge.
+ void sumcheckDotProdFinalize1(const F &previous_random, F &claim_1);
+ void sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1);
+ void sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1);
+ void sumcheckLiuFinalize(const F &previous_random, F &claim_1);
+
+ // Combined input-layer sumcheck (Liu et al. style), driven by the
+ // verifier's per-layer combination coefficients s_u / s_v.
+ void sumcheckLiuInit(const vector &s_u, const vector &s_v);
+ quadratic_poly sumcheckLiuUpdate(const F &previous_random);
+
+ // Commit to the (padded) input layer; see prover.cpp.
+ hyrax_bls12_381::polyProver &commitInput(const vector &gens);
+
+ timer prove_timer;
+ double proveTime() const { return prove_timer.elapse_sec(); }
+ double proofSize() const { return (double) proof_size / 1024.0; } // KB
+ double polyProverTime() const { return poly_p -> getPT(); }
+ double polyProofSize() const { return poly_p -> getPS(); }
+
+ layeredCircuit C;
+ vector> val; // the output of each gate
+private:
+ quadratic_poly sumcheckUpdateEach(const F &previous_random, bool idx);
+ quadratic_poly sumcheckUpdate(const F &previous_random, vector &r_arr);
+ // Wire value lookup; input layer goes through the original-id map `ori`.
+ F getCirValue(u8 layer_id, const vector &ori, u32 u);
+
+ vector::iterator r_0, r_1; // current positions
+ vector> r_u, r_v; // next positions
+
+ vector beta_g; // beta (equality) table over the gate index
+
+ F add_term;
+ vector mult_array[2]; // per-slot multiplier bookkeeping tables
+ vector V_mult[2]; // per-slot folded wire-value tables
+
+ F V_u0, V_u1; // phase-1 claims carried into phase 2
+
+ F alpha, beta, relu_rou;
+
+ u64 proof_size; // bytes accumulated so far
+
+ u32 total[2], total_size[2];
+ u8 round; // step within a sumcheck
+ u8 sumcheck_id; // the level
+
+ unique_ptr poly_p; // Hyrax commitment to the input layer
+
+ friend neuralNetwork;
+ friend singleConv;
+};
+
+
+#endif //ZKCNN_PROVER_HPP
diff --git a/src/utils.cpp b/src/utils.cpp
new file mode 100644
index 0000000..e892eb3
--- /dev/null
+++ b/src/utils.cpp
@@ -0,0 +1,234 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#include
+#include
+#include
+#include "utils.hpp"
+
+using std::cerr;
+using std::endl;
+using std::string;
+using std::cin;
+
+// Smallest b with 2^b >= n, for a (positive) double n.
+int ceilPow2BitLengthSigned(double n) {
+ const double bits = log2(n);
+ return (i8) ceil(bits);
+}
+
+// Largest b with 2^b <= n, for a (positive) double n.
+int floorPow2BitLengthSigned(double n) {
+ const double bits = log2(n);
+ return (i8) floor(bits);
+}
+
+// Smallest b with 2^b >= n; returns the sentinel -1 for n == 0
+// (log is undefined there).
+i8 ceilPow2BitLength(u32 n) {
+ if (n < 1e-9) // n is unsigned, so this is exactly n == 0
+ return -1;
+ return (i8) ceil(log(n) / log(2.));
+}
+
+// Largest b with 2^b <= n.
+// NOTE(review): the next line is garbled in this patch — the body of
+// floorPow2BitLength and the signature of initHalfTable were fused when
+// template angle brackets were stripped during extraction; restore both
+// definitions from the upstream source before applying.
+i8 floorPow2BitLength(u32 n) {
+// cerr << n << ' ' << log(n) / log(2.)< &beta_f, vector &beta_s, const vector::const_iterator &r, const F &init, u32 first_half, u32 second_half) {
+ // Build the two halves of the beta (equality-polynomial) table:
+ // beta_f over r[0 .. first_half) scaled by `init`, and
+ // beta_s over r[first_half .. first_half + second_half).
+ beta_f.at(0) = init;
+ beta_s.at(0) = F_ONE;
+
+ // Standard doubling scheme: at step i each of the 2^i entries splits into
+ // an r[i] branch (bit set) and a (1 - r[i]) branch (bit clear).
+ for (u32 i = 0; i < first_half; ++i) {
+ for (u32 j = 0; j < (1ULL << i); ++j) {
+ auto tmp = beta_f.at(j) * r[i];
+ beta_f.at(j | (1ULL << i)) = tmp;
+ beta_f.at(j) = beta_f[j] - tmp;
+ }
+ }
+
+ for (u32 i = 0; i < second_half; ++i) {
+ for (u32 j = 0; j < (1ULL << i); ++j) {
+ auto tmp = beta_s[j] * r[(i + first_half)];
+ beta_s[j | (1ULL << i)] = tmp;
+ beta_s[j] = beta_s[j] - tmp;
+ }
+ }
+}
+
+// Fill phi_mul[i] = phi^i for i in [0, 2^n), where phi is the 2^n-th root of
+// unity — or its inverse when isIFFT is set.
+void phiPowInit(vector &phi_mul, int n, bool isIFFT) {
+ u32 N = 1ULL << n;
+ F phi = getRootOfUnit(n);
+ if (isIFFT) F::inv(phi, phi);
+ phi_mul[0] = F_ONE;
+ for (u32 i = 1; i < N; ++i) phi_mul[i] = phi_mul[i - 1] * phi;
+}
+
+// Build the evaluation table phi_g of the (I)FFT wiring predicate at the
+// random point rx, scaled by `scale`. Works bit-by-bit from the top bit
+// down, splitting each entry into the (1 - rx[m] - rx[m]*phi^..) and
+// (1 - rx[m] + rx[m]*phi^..) branches; the IFFT direction uses inverse
+// root powers (via phiPowInit) and one fewer rx coordinate.
+// NOTE(review): the exact predicate identity is taken on faith from the
+// zkCNN construction — confirm against the paper before modifying.
+void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT) {
+ vector phi_mul(1 << n);
+ phiPowInit(phi_mul, n, isIFFT);
+
+ if (isIFFT) {
+// cerr << "==" << endl;
+// cerr << "gLength: " << n << endl;
+// for (int i = 0; i < n - 1; ++i) {
+// cerr << rx[i];
+// cerr << endl;
+// }
+ phi_g[0] = phi_g[1] = scale;
+ // i walks the bits from high to low; m = n - i is the remaining shift.
+ for (int i = 2; i <= n; ++i)
+ for (u32 b = 0; b < (1ULL << (i - 1)); ++b) {
+ u32 l = b, r = b ^ (1ULL << (i - 1));
+ int m = n - i;
+ F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m];
+ phi_g[r] = phi_g[l] * (tmp1 - tmp2);
+ phi_g[l] = phi_g[l] * (tmp1 + tmp2);
+ }
+ } else {
+// cerr << "==" << endl;
+// cerr << "gLength: " << n << endl;
+// for (int i = 0; i < n; ++i) {
+// cerr << rx[i];
+// cerr << endl;
+// }
+ phi_g[0] = scale;
+ for (int i = 1; i < n; ++i)
+ for (u32 b = 0; b < (1ULL << (i - 1)); ++b) {
+ u32 l = b, r = b ^ (1ULL << (i - 1));
+ int m = n - i;
+ F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m];
+ phi_g[r] = phi_g[l] * (tmp1 - tmp2);
+ phi_g[l] = phi_g[l] * (tmp1 + tmp2);
+ }
+ // Final pass folds in rx[0]; only the "plus" branch is kept.
+ for (u32 b = 0; b < (1ULL << (n - 1)); ++b) {
+ u32 l = b;
+ F tmp1 = F_ONE - rx[0], tmp2 = rx[0] * phi_mul[b];
+ phi_g[l] = phi_g[l] * (tmp1 + tmp2);
+ }
+ }
+}
+
+// In-place iterative radix-2 FFT over the field F on 2^logn points.
+// flag == true runs the inverse transform: inverse root of unity and a
+// final division by the length. The order of the butterfly passes and the
+// bit-reversal permutation is significant — do not reorder.
+void fft(vector &arr, int logn, bool flag) {
+// cerr << "fft: " << endl;
+// for (auto x: arr) cerr << x << ' ';
+// cerr << endl;
+ // Scratch tables are static so repeated calls reuse the allocations.
+ static std::vector rev;
+ static std::vector w;
+
+ u32 len = 1ULL << logn;
+ assert(arr.size() == len);
+
+ rev.resize(len);
+ w.resize(len);
+
+ // Bit-reversal permutation indices.
+ rev[0] = 0;
+ for (u32 i = 1; i < len; ++i)
+ rev[i] = rev[i >> 1] >> 1 | (i & 1) << (logn - 1);
+
+ // w[k] = omega^k, with omega the 2^logn-th root of unity (or its inverse).
+ w[0] = F_ONE;
+ w[1] = getRootOfUnit(logn);
+ if (flag) F::inv(w[1], w[1]);
+ for (u32 i = 2; i < len; ++i) w[i] = w[i - 1] * w[1];
+
+ for (u32 i = 0; i < len; ++i)
+ if (rev[i] < i) std::swap(arr[i], arr[rev[i]]);
+
+ // Butterfly passes of width i = 2, 4, ..., len.
+ for (u32 i = 2; i <= len; i <<= 1)
+ for (u32 j = 0; j < len; j += i)
+ for (u32 k = 0; k < (i >> 1); ++k) {
+ auto u = arr[j + k];
+ auto v = arr[j + k + (i >> 1)] * w[len / i * k];
+ arr[j + k] = u + v;
+ arr[j + k + (i >> 1)] = u - v;
+ }
+
+ // Inverse transform: scale by 1/len.
+ if (flag) {
+ F ilen;
+ F::inv(ilen, len);
+ for (u32 i = 0; i < len; ++i)
+ arr[i] = arr[i] * ilen;
+ }
+}
+
+// Two-point beta table: beta_g[i] = alpha * beta(i, r_0) + beta * beta(i, r_1)
+// over gLength bits, computed as an outer product of two half tables to keep
+// the work at O(2^(g/2)) per table plus one linear combine pass.
+void
+initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1,
+ const F &alpha, const F &beta) {
+ u8 first_half = gLength >> 1, second_half = gLength - first_half;
+ u32 mask_fhalf = (1ULL << first_half) - 1;
+
+ vector beta_f(1ULL << first_half), beta_s(1ULL << second_half);
+ // Skip a term entirely when its coefficient is zero.
+ if (!beta.isZero()) {
+ initHalfTable(beta_f, beta_s, r_1, beta, first_half, second_half);
+ for (u32 i = 0; i < (1ULL << gLength); ++i)
+ beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half];
+ } else for (u32 i = 0; i < (1ULL << gLength); ++i)
+ beta_g[i].clear();
+
+ if (alpha.isZero()) return;
+ initHalfTable(beta_f, beta_s, r_0, alpha, first_half, second_half);
+ for (u32 i = 0; i < (1ULL << gLength); ++i)
+ beta_g[i] = beta_g[i] + beta_f[i & mask_fhalf] * beta_s[i >> first_half];
+}
+
+
+// Single-point beta table: beta_g[i] = init * beta(i, r) over gLength bits.
+// A bit length of -1 means "no such operand"; it is stored in an i8 upstream
+// and arrives here as 0xFF after conversion to u8.
+void initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r, const F &init) {
+ // BUGFIX: gLength is unsigned, so the old test `gLength == -1` promoted
+ // both operands to int and was ALWAYS false — the -1 sentinel slipped
+ // through and the shifts below were taken with gLength == 255. Compare
+ // against the wrapped sentinel value instead.
+ if (gLength == (u8) -1) return;
+ int first_half = gLength >> 1, second_half = gLength - first_half;
+ u32 mask_fhalf = (1ULL << first_half) - 1;
+ vector beta_f(1ULL << first_half), beta_s(1ULL << second_half);
+
+ if (!init.isZero()) {
+ initHalfTable(beta_f, beta_s, r, init, first_half, second_half);
+ for (u32 i = 0; i < (1ULL << gLength); ++i)
+ beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half];
+ } else for (u32 i = 0; i < (1ULL << gLength); ++i)
+ beta_g[i].clear();
+}
+
+// True iff the point (x, y) lies inside the nx-by-ny grid.
+bool check(long x, long y, long nx, long ny) {
+ const bool x_ok = 0 <= x && x < nx;
+ const bool y_ok = 0 <= y && y < ny;
+ return x_ok && y_ok;
+}
+//
+//F getData(u8 scale_bl) {
+// double x;
+// in >> x;
+// long y = round(x * (1L << scale_bl));
+// return F(y);
+//}
+
+// Initialize a circuit layer's bookkeeping: logical size, the index past
+// which values are zero padding (initially equal to size), the power-of-two
+// bit length (-1 for an empty layer), and the layer type.
+void initLayer(layer &circuit, long size, layerType ty) {
+ circuit.size = circuit.zero_start_id = size;
+ circuit.bit_length = ceilPow2BitLength(size);
+ circuit.ty = ty;
+}
+
+// Square of x.
+long sqr(long x) { return x * x; }
+
+// Byte-count conversions (1 KB = 1024 bytes). Divisors are exact powers of
+// two, so the results are bit-identical to chained division by 1024.
+double byte2KB(size_t x) { return x / 1024.0; }
+
+double byte2MB(size_t x) { return x / (1024.0 * 1024.0); }
+
+double byte2GB(size_t x) { return x / (1024.0 * 1024.0 * 1024.0); }
+
+// Row-major index of (x, y) in a matrix with row width n.
+long matIdx(long x, long y, long n) {
+ assert(y < n);
+ return y + x * n;
+}
+
+// Row-major index of (x, y, z) in a cube with inner dims n x m.
+long cubIdx(long x, long y, long z, long n, long m) {
+ assert(y < n && z < m);
+ const long plane = matIdx(x, y, n);
+ return matIdx(plane, z, m);
+}
+
+// Row-major index of (w, x, y, z) in a 4-tensor with inner dims n x m x l.
+long tesIdx(long w, long x, long y, long z, long n, long m, long l) {
+ assert(x < n && y < m && z < l);
+ const long cube = cubIdx(w, x, y, n, m);
+ return matIdx(cube, z, l);
+}
+
+// 2^n-th primitive root of unity in F, built by repeatedly taking square
+// roots of -1: n == 0 returns 1, n == 1 returns -1, n == 2 a 4th root, etc.
+// Asserts that squareRoot succeeds, i.e. the field's two-adicity is >= n.
+F getRootOfUnit(int n) {
+ F res = -F_ONE;
+ if (!n) return F_ONE;
+ while (--n) {
+ bool b = F::squareRoot(res, res);
+ assert(b);
+ }
+ return res;
+}
+
+
diff --git a/src/utils.hpp b/src/utils.hpp
new file mode 100644
index 0000000..d5b5043
--- /dev/null
+++ b/src/utils.hpp
@@ -0,0 +1,49 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#ifndef ZKCNN_UTILS_HPP
+#define ZKCNN_UTILS_HPP
+
+#include
+
+int ceilPow2BitLengthSigned(double n);
+int floorPow2BitLengthSigned(double n);
+
+char ceilPow2BitLength(u32 n);
+char floorPow2BitLength(u32 n);
+
+
+void fft(vector &arr, int logn, bool flag);
+
+void
+initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1,
+ const F &alpha, const F &beta);
+
+void initPhiTable(F *phi_g, const layer &cur_layer, const F *r_0, const F *r_1, F alpha, F beta);
+
+void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT);
+
+void initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r, const F &init);
+
+bool check(long x, long y, long nx, long ny);
+
+long matIdx(long x, long y, long n);
+
+long cubIdx(long x, long y, long z, long n, long m);
+
+long tesIdx(long w, long x, long y, long z, long n, long m, long l);
+
+void initLayer(layer &circuit, long size, layerType ty);
+
+long sqr(long x);
+
+double byte2KB(size_t x);
+
+double byte2MB(size_t x);
+
+double byte2GB(size_t x);
+
+F getRootOfUnit(int n);
+
+#endif //ZKCNN_UTILS_HPP
diff --git a/src/verifier.cpp b/src/verifier.cpp
new file mode 100644
index 0000000..4602acf
--- /dev/null
+++ b/src/verifier.cpp
@@ -0,0 +1,373 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#include "verifier.hpp"
+#include "global_var.hpp"
+#include
+#include
+#include
+
+vector beta_v;
+static vector beta_u, beta_gs;
+
+// Bind the verifier to its prover and circuit, pre-size the per-layer claim
+// and randomness tables (two spare slots beyond the top layer are used by
+// the protocol), and let the prover set itself up.
+verifier::verifier(prover *pr, const layeredCircuit &cir):
+ p(pr), C(cir) {
+ final_claim_u0.resize(C.size + 2);
+ final_claim_v0.resize(C.size + 2);
+
+ r_u.resize(C.size + 2);
+ r_v.resize(C.size + 2);
+ // make the prover ready
+ p->init();
+}
+
+// Combine the prover's four claimed evaluations with the precomputed wiring
+// predicate values: three binary-gate terms plus two unary-gate terms.
+F verifier::getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1) {
+ F acc = bin_value[0] * (claim_u0 * claim_v0);
+ acc = acc + bin_value[1] * (claim_u1 * claim_v1);
+ acc = acc + bin_value[2] * (claim_u1 * claim_v0);
+ acc = acc + uni_value[0] * claim_u0;
+ acc = acc + uni_value[1] * claim_u1;
+ return acc;
+}
+
+// Build the phase-1 predicate tables (beta_g / beta_u / beta_gs) for layer
+// `depth`, specialized per layer type. The FFT/IFFT and PADDING cases use
+// the (I)FFT wiring table phiGInit; DOT_PROD folds the fft-sized prefix of
+// the u-randomness directly into beta_u.
+void verifier::betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou) {
+ i8 bl = C.circuit[depth].bit_length;
+ i8 fft_bl = C.circuit[depth].fft_bit_length;
+ i8 fft_blh = C.circuit[depth].fft_bit_length - 1; // half-size FFT length
+ i8 cnt_bl = bl - fft_bl, cnt_bl2 = C.circuit[depth].max_bl_u - fft_bl;
+
+ switch (C.circuit[depth].ty) {
+ case layerType::FFT:
+ case layerType::IFFT:
+ beta_gs.resize(1ULL << fft_bl);
+ phiGInit(beta_gs, r_0, C.circuit[depth].scale, fft_bl, C.circuit[depth].ty == layerType::IFFT);
+ beta_u.resize(1ULL << C.circuit[depth].max_bl_u);
+ initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE);
+ break;
+ case layerType::PADDING:
+ // beta_g is assembled as (coarse table over the non-fft bits) x
+ // (fine table over the low fft_blh bits); the downward loop lets
+ // beta_g be expanded in place.
+ beta_g.resize(1ULL << bl);
+ beta_gs.resize(1ULL << fft_blh);
+ initBetaTable(beta_g, bl - fft_blh, r_u[depth + 2].begin() + fft_bl, r_v[depth + 2].begin(), alpha, beta);
+ initBetaTable(beta_gs, fft_blh, r_0, F_ONE);
+ for (u32 g = (1ULL << bl) - 1; g < (1ULL << bl); --g)
+ beta_g[g] = beta_g[g >> fft_blh] *
+ beta_gs[g & (1ULL << fft_blh) - 1];
+ beta_u.resize(1ULL << C.circuit[depth].max_bl_u);
+ initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE);
+ break;
+ case layerType::DOT_PROD:
+ beta_g.resize(1ULL << cnt_bl);
+ initBetaTable(beta_g, cnt_bl, r_u[depth + 2].begin() + fft_bl - 1, alpha);
+
+ // Fold the equality over the low fft_bl bits of (r_0, r_u) into beta_u.
+ beta_u.resize(1ULL << cnt_bl2);
+ initBetaTable(beta_u, cnt_bl2, r_u[depth].begin() + fft_bl, F_ONE);
+ for (u32 i = 0; i < 1ULL << cnt_bl2; ++i)
+ for (u32 j = 0; j < fft_bl; ++j)
+ beta_u[i] = beta_u[i] * ((r_0[j] * r_u[depth][j]) + (F_ONE - r_0[j]) * (F_ONE - r_u[depth][j]));
+ break;
+
+ default:
+ beta_g.resize(1ULL << bl);
+ initBetaTable(beta_g, C.circuit[depth].bit_length, r_0, r_1, alpha * C.circuit[depth].scale,
+ beta * C.circuit[depth].scale);
+ // ReLU-style layers scale the padding region by the random relu_rou.
+ if (C.circuit[depth].zero_start_id < C.circuit[depth].size)
+ for (u32 g = C.circuit[depth].zero_start_id; g < 1ULL << C.circuit[depth].bit_length; ++g)
+ beta_g[g] = beta_g[g] * relu_rou;
+ beta_u.resize(1ULL << C.circuit[depth].max_bl_u);
+ initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE);
+ }
+}
+
+// Build the phase-2 beta table over the v-randomness of layer `depth`.
+void verifier::betaInitPhase2(u8 depth) {
+ const auto bl_v = C.circuit[depth].max_bl_v;
+ beta_v.resize(1ULL << bl_v);
+ initBetaTable(beta_v, bl_v, r_v[depth].begin(), F_ONE);
+}
+
+// Evaluate the unary-gate predicate contributions after phase 1.
+// FFT/IFFT layers have a single implicit unary wiring (inner product of the
+// phi table and beta_u); other layers sum over their explicit unary gates,
+// scaled by the power-of-two constant two_mul[sc]. Binary terms are reset
+// here and filled in by predicatePhase2.
+void verifier::predicatePhase1(u8 layer_id) {
+ auto &cur_layer = C.circuit[layer_id];
+
+ uni_value[0].clear();
+ uni_value[1].clear();
+ if (cur_layer.ty == layerType::FFT || cur_layer.ty == layerType::IFFT)
+ for (u32 u = 0; u < 1ULL << cur_layer.max_bl_u; ++u)
+ uni_value[1] = uni_value[1] + beta_gs[u] * beta_u[u];
+ else for (auto &gate: cur_layer.uni_gates) {
+ bool idx = gate.lu; // which operand slot the gate reads
+ uni_value[idx] = uni_value[idx] + beta_g[gate.g] * beta_u[gate.u] * C.two_mul[gate.sc];
+ }
+ bin_value[0] = bin_value[1] = bin_value[2] = F_ZERO;
+}
+
+// Evaluate the binary-gate predicate contributions after phase 2, and fold
+// beta_v[0] into the unary terms (their v-part is the all-zero point).
+// NOTE(review): the DOT_PROD branch omits the two_mul[gate.sc] factor that
+// the general branch applies — presumably dot-product gates carry no scale;
+// confirm against the circuit builder.
+void verifier::predicatePhase2(u8 layer_id) {
+ uni_value[0] = uni_value[0] * beta_v[0];
+ uni_value[1] = uni_value[1] * beta_v[0];
+
+ auto &cur_layer = C.circuit[layer_id];
+ if (C.circuit[layer_id].ty == layerType::DOT_PROD) {
+ for (auto &gate: cur_layer.bin_gates)
+ bin_value[gate.l] =
+ bin_value[gate.l] +
+ beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v];
+ } else for (auto &gate: cur_layer.bin_gates)
+ bin_value[gate.l] = bin_value[gate.l] + beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v] * C.two_mul[gate.sc];
+}
+
+// Top-level verification: sample fresh commitment generators (sized to the
+// square root of the padded input length, as Hyrax requires), have the
+// prover commit to the input, then run the three protocol stages in order.
+bool verifier::verify() {
+ u8 logn = C.circuit[0].bit_length;
+ u64 n_sqrt = 1ULL << (logn - (logn >> 1)); // ceil half of the bits
+ vector gens(n_sqrt);
+ for (auto &x: gens) {
+ Fr tmp;
+ tmp.setByCSPRNG();
+ x = mcl::bn::getG1basePoint() * tmp;
+ }
+
+ poly_v = std::make_unique(p -> commitInput(gens), gens);
+ return verifyInnerLayers() && verifyFirstLayer() && verifyInput();
+}
+
+// Run the per-layer GKR sumchecks from the output layer down to layer 1.
+// For each layer: phase 1 over the u variables (cubic rounds for DOT_PROD,
+// quadratic otherwise), optionally phase 2 over the v variables, then check
+// the prover's final claims against the locally evaluated wiring predicate.
+// The two timers separate the verifier's cheap work (total_timer) from the
+// work that includes predicate-table construction (total_slow_timer).
+// Statement order is protocol-critical: challenges must be sampled before
+// the corresponding prover update is requested.
+bool verifier::verifyInnerLayers() {
+ total_timer.start();
+ total_slow_timer.start();
+
+ F alpha = F_ONE, beta = F_ZERO, relu_rou, final_claim_u1, final_claim_v1;
+ // Random point at which the circuit output is evaluated.
+ r_u[C.size].resize(C.circuit[C.size - 1].bit_length);
+ for (i8 i = 0; i < C.circuit[C.size - 1].bit_length; ++i)
+ r_u[C.size][i].setByCSPRNG();
+ vector::const_iterator r_0 = r_u[C.size].begin();
+ vector::const_iterator r_1;
+
+ total_timer.stop();
+ total_slow_timer.stop();
+
+ auto previousSum = p->Vres(r_0, C.circuit[C.size - 1].size, C.circuit[C.size - 1].bit_length);
+ p -> sumcheckInitAll(r_0);
+
+ for (u8 i = C.size - 1; i; --i) {
+ auto &cur = C.circuit[i];
+ p->sumcheckInit(alpha, beta);
+ total_timer.start();
+ total_slow_timer.start();
+
+ // phase 1
+ r_u[i].resize(cur.max_bl_u);
+ for (int j = 0; j < cur.max_bl_u; ++j) r_u[i][j].setByCSPRNG();
+ // Layers with zero padding (ReLU-style) get an extra random factor.
+ if (cur.zero_start_id < cur.size)
+ relu_rou.setByCSPRNG();
+ else relu_rou = F_ONE;
+
+ total_timer.stop();
+ total_slow_timer.stop();
+ if (cur.ty == layerType::DOT_PROD)
+ p->sumcheckDotProdInitPhase1();
+ else p->sumcheckInitPhase1(relu_rou);
+
+ F previousRandom = F_ZERO;
+ for (i8 j = 0; j < cur.max_bl_u; ++j) {
+ F cur_claim, nxt_claim;
+ // Each round: poly(0) + poly(1) must equal the running claim, and
+ // the claim for the next round is poly at this round's challenge.
+ if (cur.ty == layerType::DOT_PROD) {
+ cubic_poly poly = p->sumcheckDotProdUpdate1(previousRandom);
+ total_timer.start();
+ total_slow_timer.start();
+ cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE);
+ nxt_claim = poly.eval(r_u[i][j]);
+ } else {
+ quadratic_poly poly = p->sumcheckUpdate1(previousRandom);
+ total_timer.start();
+ total_slow_timer.start();
+ cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE);
+ nxt_claim = poly.eval(r_u[i][j]);
+ }
+
+ if (cur_claim != previousSum) {
+ cerr << cur_claim << ' ' << previousSum << endl;
+ fprintf(stderr, "Verification fail, phase1, circuit %d, current bit %d\n", i, j);
+ return false;
+ }
+ previousRandom = r_u[i][j];
+ previousSum = nxt_claim;
+ total_timer.stop();
+ total_slow_timer.stop();
+ }
+
+ if (cur.ty == layerType::DOT_PROD)
+ p->sumcheckDotProdFinalize1(previousRandom, final_claim_u1);
+ else p->sumcheckFinalize1(previousRandom, final_claim_u0[i], final_claim_u1);
+
+ // Build the predicate tables at the sampled randomness.
+ total_slow_timer.start();
+ betaInitPhase1(i, alpha, beta, r_0, r_1, relu_rou);
+ predicatePhase1(i);
+
+ total_timer.start();
+ if (cur.need_phase2) {
+ r_v[i].resize(cur.max_bl_v);
+ for (int j = 0; j < cur.max_bl_v; ++j) r_v[i][j].setByCSPRNG();
+
+ total_timer.stop();
+ total_slow_timer.stop();
+
+ p->sumcheckInitPhase2();
+ previousRandom = F_ZERO;
+ for (u32 j = 0; j < cur.max_bl_v; ++j) {
+ quadratic_poly poly = p->sumcheckUpdate2(previousRandom);
+
+ total_timer.start();
+ total_slow_timer.start();
+ if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) {
+ fprintf(stderr, "Verification fail, phase2, circuit level %d, current bit %d, total is %d\n", i, j,
+ cur.max_bl_v);
+ return false;
+ }
+
+ previousRandom = r_v[i][j];
+ previousSum = poly.eval(previousRandom);
+ total_timer.stop();
+ total_slow_timer.stop();
+ }
+ p->sumcheckFinalize2(previousRandom, final_claim_v0[i], final_claim_v1);
+
+ total_slow_timer.start();
+ betaInitPhase2(i);
+ predicatePhase2(i);
+ total_timer.start();
+ }
+ // Semi-final check: the last running claim must equal the predicate
+ // combination of the prover's four claimed evaluations.
+ F test_value = getFinalValue(final_claim_u0[i], final_claim_u1, final_claim_v0[i], final_claim_v1);
+
+ if (previousSum != test_value) {
+ std::cerr << test_value << ' ' << previousSum << std::endl;
+ fprintf(stderr, "Verification fail, semi final, circuit level %d\n", i);
+ return false;
+ } else fprintf(stderr, "Verification Pass, semi final, circuit level %d\n", i);
+
+ // Prepare the next layer's claim: FFT/IFFT layers pass the u-claim
+ // through; others take a fresh random combination alpha*u1 + beta*v1
+ // (coefficients are zeroed when the operand does not exist, i.e. its
+ // bit length is the -1 sentinel).
+ if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT)
+ previousSum = final_claim_u1;
+ else {
+ if (~cur.bit_length_u[1])
+ alpha.setByCSPRNG();
+ else alpha.clear();
+ if ((~cur.bit_length_v[1]) || cur.ty == layerType::FFT)
+ beta.setByCSPRNG();
+ else beta.clear();
+ previousSum = alpha * final_claim_u1 + beta * final_claim_v1;
+ }
+
+ r_0 = r_u[i].begin();
+ r_1 = r_v[i].begin();
+
+ total_timer.stop();
+ total_slow_timer.stop();
+ beta_u.clear();
+ beta_v.clear();
+ }
+ return true;
+}
+
+// Verify the input layer with one combined sumcheck (Liu-style): all
+// per-layer claims on layer 0 are folded into a single claim with fresh
+// random coefficients sig_u / sig_v, reduced to one evaluation of the input
+// polynomial at r_u[0], and checked against the locally computed predicate
+// combination gr. The evaluation itself (eval_in) is checked later against
+// the polynomial commitment in verifyInput().
+bool verifier::verifyFirstLayer() {
+ total_slow_timer.start();
+ total_timer.start();
+
+ auto &cur = C.circuit[0];
+
+ // Random combination coefficients, one pair per inner layer.
+ vector sig_u(C.size - 1);
+ for (int i = 0; i < C.size - 1; ++i) sig_u[i].setByCSPRNG();
+ vector sig_v(C.size - 1);
+ for (int i = 0; i < C.size - 1; ++i) sig_v[i].setByCSPRNG();
+ r_u[0].resize(cur.bit_length);
+ for (int i = 0; i < cur.bit_length; ++i) r_u[0][i].setByCSPRNG();
+ auto r_0 = r_u[0].begin();
+
+ // Initial claim: random combination of every layer's claims on layer 0
+ // (a bit length of -1 marks a non-existent operand; ~(-1) == 0 skips it).
+ F previousSum = F_ZERO;
+ for (int i = 1; i < C.size; ++i) {
+ if (~C.circuit[i].bit_length_u[0])
+ previousSum = previousSum + sig_u[i - 1] * final_claim_u0[i];
+ if (~C.circuit[i].bit_length_v[0])
+ previousSum = previousSum + sig_v[i - 1] * final_claim_v0[i];
+ }
+ total_timer.stop();
+ total_slow_timer.stop();
+
+ p->sumcheckLiuInit(sig_u, sig_v);
+ F previousRandom = F_ZERO;
+ for (int j = 0; j < cur.bit_length; ++j) {
+ auto poly = p -> sumcheckLiuUpdate(previousRandom);
+ if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) {
+ fprintf(stderr, "Liu fail, circuit 0, current bit %d\n", j);
+ return false;
+ }
+ previousRandom = r_0[j];
+ previousSum = poly.eval(previousRandom);
+ }
+
+ F gr = F_ZERO;
+ p->sumcheckLiuFinalize(previousRandom, eval_in);
+
+ beta_g.resize(1ULL << cur.bit_length);
+
+ // Recompute the combined predicate gr = sum over layers of
+ // sig * beta(r_layer, .) mapped through the original-id tables.
+ total_slow_timer.start();
+ initBetaTable(beta_g, cur.bit_length, r_0, F_ONE);
+ for (int i = 1; i < C.size; ++i) {
+ if (~C.circuit[i].bit_length_u[0]) {
+ beta_u.resize(1ULL << C.circuit[i].bit_length_u[0]);
+ initBetaTable(beta_u, C.circuit[i].bit_length_u[0], r_u[i].begin(), sig_u[i - 1]);
+ for (u32 j = 0; j < C.circuit[i].size_u[0]; ++j)
+ gr = gr + beta_g[C.circuit[i].ori_id_u[j]] * beta_u[j];
+ }
+
+ if (~C.circuit[i].bit_length_v[0]) {
+ beta_v.resize(1ULL << C.circuit[i].bit_length_v[0]);
+ initBetaTable(beta_v, C.circuit[i].bit_length_v[0], r_v[i].begin(), sig_v[i - 1]);
+ for (u32 j = 0; j < C.circuit[i].size_v[0]; ++j)
+ gr = gr + beta_g[C.circuit[i].ori_id_v[j]] * beta_v[j];
+ }
+ }
+
+ beta_u.clear();
+ beta_v.clear();
+
+ // Final identity: claimed input evaluation times predicate must match.
+ total_timer.start();
+ if (eval_in * gr != previousSum) {
+ fprintf(stderr, "Liu fail, semi final, circuit 0.\n");
+ return false;
+ }
+
+ total_timer.stop();
+ total_slow_timer.stop();
+ output_tb[PT_OUT_ID] = to_string_wp(p->proveTime());
+ output_tb[VT_OUT_ID] = to_string_wp(verifierTime());
+ output_tb[PS_OUT_ID] = to_string_wp(p -> proofSize());
+
+ fprintf(stderr, "Verification pass\n");
+ fprintf(stderr, "Prove Time %lf\n", p->proveTime());
+ fprintf(stderr, "verify time %lf = %lf + %lf(slow)\n", verifierSlowTime(), verifierTime(), verifierSlowTime() - verifierTime());
+ fprintf(stderr, "proof size = %lf kb\n", p -> proofSize());
+
+ // Release all scratch tables; keep r_u[0] for the commitment check.
+ beta_g.clear();
+ beta_gs.clear();
+ beta_u.clear();
+ beta_v.clear();
+ r_u.resize(1);
+ r_v.clear();
+
+ sig_u.clear();
+ sig_v.clear();
+ return true;
+}
+
+// Final stage: check the claimed input evaluation eval_in at r_u[0] against
+// the Hyrax polynomial commitment, then record the timing/size statistics.
+bool verifier::verifyInput() {
+ if (!poly_v -> verify(r_u[0], eval_in)) {
+ fprintf(stderr, "Verification fail, final input check fail.\n");
+ return false;
+ }
+
+ fprintf(stderr, "poly pt = %.5f, vt = %.5f, ps = %.5f\n", p -> polyProverTime(), poly_v -> getVT(), p -> polyProofSize());
+ output_tb[POLY_PT_OUT_ID] = to_string_wp(p -> polyProverTime());
+ output_tb[POLY_VT_OUT_ID] = to_string_wp(poly_v -> getVT());
+ output_tb[POLY_PS_OUT_ID] = to_string_wp(p -> polyProofSize());
+ output_tb[TOT_PT_OUT_ID] = to_string_wp(p -> polyProverTime() + p->proveTime());
+ output_tb[TOT_VT_OUT_ID] = to_string_wp(poly_v -> getVT() + verifierTime());
+ output_tb[TOT_PS_OUT_ID] = to_string_wp(p -> polyProofSize() + p -> proofSize());
+ return true;
+}
diff --git a/src/verifier.hpp b/src/verifier.hpp
new file mode 100644
index 0000000..d9af504
--- /dev/null
+++ b/src/verifier.hpp
@@ -0,0 +1,47 @@
+//
+// Created by 69029 on 3/9/2021.
+//
+
+#ifndef ZKCNN_CONVVERIFIER_HPP
+#define ZKCNN_CONVVERIFIER_HPP
+
+#include "prover.hpp"
+
+using std::unique_ptr;
+// GKR verifier driving a prover over a layered circuit: per-layer sumchecks,
+// a combined input-layer sumcheck, and a final polynomial-commitment check.
+class verifier {
+public:
+ prover *p; // non-owning; must outlive the verifier
+ const layeredCircuit &C;
+
+ verifier(prover *pr, const layeredCircuit &cir);
+
+ // Run the full protocol; returns false on any failed check.
+ bool verify();
+
+ // total_timer counts the verifier's cheap work; total_slow_timer also
+ // includes predicate-table construction.
+ timer total_timer, total_slow_timer;
+ double verifierTime() const { return total_timer.elapse_sec(); }
+ double verifierSlowTime() const { return total_slow_timer.elapse_sec(); }
+
+private:
+ vector> r_u, r_v; // per-layer sumcheck randomness
+ vector final_claim_u0, final_claim_v0; // claims on layer 0, per layer
+ bool verifyInnerLayers();
+ bool verifyFirstLayer();
+ bool verifyInput();
+
+ vector beta_g; // predicate table over the gate index
+ void betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou);
+ void betaInitPhase2(u8 depth);
+
+ F uni_value[2]; // unary-gate predicate evaluations
+ F bin_value[3]; // binary-gate predicate evaluations
+ void predicatePhase1(u8 layer_id);
+ void predicatePhase2(u8 layer_id);
+
+ // Combine claimed evaluations with the predicate values.
+ F getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1);
+
+ F eval_in; // claimed input evaluation, checked by verifyInput
+ unique_ptr poly_v; // Hyrax commitment verifier
+};
+
+
+#endif //ZKCNN_CONVVERIFIER_HPP