forked from microsoft/SealPIR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
124 lines (101 loc) · 4.94 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

#include <seal/seal.h>

#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
using namespace std::chrono;
using namespace std;
using namespace seal;
int main(int argc, char *argv[]) {
uint64_t number_of_items = 1 << 12;
uint64_t size_per_item = 288; // in bytes
uint32_t N = 2048;
// Recommended values: (logt, d) = (12, 2) or (8, 1).
uint32_t logt = 12;
uint32_t d = 2;
EncryptionParameters params(scheme_type::BFV);
PirParams pir_params;
// Generates all parameters
cout << "Main: Generating all parameters" << endl;
gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
cout << "Main: Initializing the database (this may take some time) ..." << endl;
// Create test database
auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
// Copy of the database. We use this at the end to make sure we retrieved
// the correct element.
auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
random_device rd;
for (uint64_t i = 0; i < number_of_items; i++) {
for (uint64_t j = 0; j < size_per_item; j++) {
auto val = rd() % 256;
db.get()[(i * size_per_item) + j] = val;
db_copy.get()[(i * size_per_item) + j] = val;
}
}
// Initialize PIR Server
cout << "Main: Initializing server and client" << endl;
PIRServer server(params, pir_params);
// Initialize PIR client....
PIRClient client(params, pir_params);
GaloisKeys galois_keys = client.generate_galois_keys();
// Set galois key for client with id 0
cout << "Main: Setting Galois keys...";
server.set_galois_key(0, galois_keys);
// Measure database setup
auto time_pre_s = high_resolution_clock::now();
server.set_database(move(db), number_of_items, size_per_item);
server.preprocess_database();
cout << "Main: database pre processed " << endl;
auto time_pre_e = high_resolution_clock::now();
auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
// Choose an index of an element in the DB
uint64_t ele_index = rd() % number_of_items; // element in DB at random position
uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext
uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
// Measure query generation
auto time_query_s = high_resolution_clock::now();
PirQuery query = client.generate_query(index);
auto time_query_e = high_resolution_clock::now();
auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
cout << "Main: query generated" << endl;
//To marshall query to send over the network, you can use serialize/deserialize:
//std::string query_ser = serialize_query(query);
//PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
// Measure query processing (including expansion)
auto time_server_s = high_resolution_clock::now();
PirReply reply = server.generate_reply(query, 0);
auto time_server_e = high_resolution_clock::now();
auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
// Measure response extraction
auto time_decode_s = chrono::high_resolution_clock::now();
Plaintext result = client.decode_reply(reply);
auto time_decode_e = chrono::high_resolution_clock::now();
auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
// Convert from FV plaintext (polynomial) to database element at the client
vector<uint8_t> elems(N * logt / 8);
coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
// Check that we retrieved the correct element
for (uint32_t i = 0; i < size_per_item; i++) {
if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
<< (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
cout << "Main: PIR result wrong!" << endl;
return -1;
}
}
// Output results
cout << "Main: PIR result correct!" << endl;
cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms"
<< endl;
cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
return 0;
}