forked from microsoft/SEAL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
scalingvariant.cpp
161 lines (151 loc) · 7.82 KB
/
scalingvariant.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "seal/encryptor.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/util/scalingvariant.h"
#include "seal/util/uintarith.h"
using namespace std;
namespace seal
{
namespace util
{
// Adds the coefficients of plain, without any scaling, into destination.
// For every modulus in coeff_modulus, each of the first plain.coeff_count()
// plaintext coefficients is reduced modulo that modulus and added (modularly)
// to the coefficient at the same position of the matching destination
// component. Destination coefficients beyond plain.coeff_count() are not written.
//
// When SEAL_DEBUG is defined, std::invalid_argument is thrown if plain has more
// coefficients than poly_modulus_degree, or if destination's poly_modulus_degree
// differs from the one in the encryption parameters.
void add_plain_without_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    auto &enc_parms = context_data.parms();
    auto &moduli = enc_parms.coeff_modulus();
    const size_t coeff_count = plain.coeff_count();
    const size_t moduli_count = moduli.size();
#ifdef SEAL_DEBUG
    if (coeff_count > enc_parms.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != enc_parms.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    // Half-open range of plaintext coefficients consumed for every modulus.
    const auto plain_first = plain.data();
    const auto plain_last = plain_first + coeff_count;

    SEAL_ITERATE(iter(destination, moduli), moduli_count, [&](auto I) {
        // get<0>(I) is the destination component (read and written in place),
        // get<1>(I) is the modulus paired with it.
        std::transform(
            plain_first, plain_last, get<0>(I), get<0>(I),
            [&](uint64_t plain_coeff, uint64_t dest_coeff) -> uint64_t {
                const uint64_t reduced = barrett_reduce_64(plain_coeff, get<1>(I));
                return add_uint_mod(dest_coeff, reduced, get<1>(I));
            });
    });
}
// Subtracts the coefficients of plain, without any scaling, from destination.
// For every modulus in coeff_modulus, each of the first plain.coeff_count()
// plaintext coefficients is reduced modulo that modulus and subtracted
// (modularly) from the coefficient at the same position of the matching
// destination component. Destination coefficients beyond plain.coeff_count()
// are not written.
//
// When SEAL_DEBUG is defined, std::invalid_argument is thrown if plain has more
// coefficients than poly_modulus_degree, or if destination's poly_modulus_degree
// differs from the one in the encryption parameters.
void sub_plain_without_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    auto &enc_parms = context_data.parms();
    auto &moduli = enc_parms.coeff_modulus();
    const size_t coeff_count = plain.coeff_count();
    const size_t moduli_count = moduli.size();
#ifdef SEAL_DEBUG
    if (coeff_count > enc_parms.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != enc_parms.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    // Half-open range of plaintext coefficients consumed for every modulus.
    const auto plain_first = plain.data();
    const auto plain_last = plain_first + coeff_count;

    SEAL_ITERATE(iter(destination, moduli), moduli_count, [&](auto I) {
        // get<0>(I) is the destination component (read and written in place),
        // get<1>(I) is the modulus paired with it.
        std::transform(
            plain_first, plain_last, get<0>(I), get<0>(I),
            [&](uint64_t plain_coeff, uint64_t dest_coeff) -> uint64_t {
                const uint64_t reduced = barrett_reduce_64(plain_coeff, get<1>(I));
                return sub_uint_mod(dest_coeff, reduced, get<1>(I));
            });
    });
}
// Adds the scaled plaintext to destination.
// Each coefficient m of plain is multiplied by coeff_modulus q, divided by
// plain_modulus t, and rounded to the nearest integer (rounded up in case of a
// tie), i.e. floor((q * m + floor((t+1) / 2)) / t). The result is added, for
// every modulus in coeff_modulus, to the coefficient with the same index in
// the matching destination component.
//
// plain         Plaintext; its first plain.coeff_count() coefficients are used.
// context_data  Supplies the encryption parameters together with
//               coeff_div_plain_modulus, plain_upper_half_threshold and
//               coeff_modulus_mod_plain_modulus.
// destination   RNS polynomial updated in place.
//
// When SEAL_DEBUG is defined, std::invalid_argument is thrown if plain has more
// coefficients than poly_modulus_degree, or if destination's poly_modulus_degree
// differs from the one in the encryption parameters.
void multiply_add_plain_with_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    auto &parms = context_data.parms();
    const size_t plain_coeff_count = plain.coeff_count();
    auto &coeff_modulus = parms.coeff_modulus();
    const size_t coeff_modulus_size = coeff_modulus.size();
    // Only the numeric value of t is needed below; read it once, outside the
    // per-coefficient loop.
    const auto plain_modulus_value = parms.plain_modulus().value();
    auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus();
    const uint64_t plain_upper_half_threshold = context_data.plain_upper_half_threshold();
    const uint64_t q_mod_t = context_data.coeff_modulus_mod_plain_modulus();
#ifdef SEAL_DEBUG
    if (plain_coeff_count > parms.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != parms.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    // Coefficients of plain m multiplied by coeff_modulus q, divided by plain_modulus t,
    // and rounded to the nearest integer (rounded up in case of a tie). Equivalent to
    // floor((q * m + floor((t+1) / 2)) / t).
    // The second iterator, starting at size_t(0), supplies the running coefficient index.
    SEAL_ITERATE(iter(plain.data(), size_t(0)), plain_coeff_count, [&](auto I) {
        // Compute numerator = (q mod t) * m[i] + (t+1)/2 as a 128-bit value
        unsigned long long prod[2]{ 0, 0 };
        uint64_t numerator[2]{ 0, 0 };
        multiply_uint64(get<0>(I), q_mod_t, prod);
        unsigned char carry = add_uint64(*prod, plain_upper_half_threshold, numerator);
        // High word: high word of the product plus the carry out of the low-word addition
        numerator[1] = static_cast<uint64_t>(prod[1]) + static_cast<uint64_t>(carry);

        // Compute fix[0] = floor(numerator / t)
        uint64_t fix[2] = { 0, 0 };
        divide_uint128_inplace(numerator, plain_modulus_value, fix);

        // Add to ciphertext: floor(q / t) * m + increment
        size_t coeff_index = get<1>(I);
        SEAL_ITERATE(
            iter(destination, coeff_modulus, coeff_div_plain_modulus), coeff_modulus_size, [&](auto J) {
                // get<0>(J): destination component, get<1>(J): its modulus,
                // get<2>(J): the matching coeff_div_plain_modulus entry
                uint64_t scaled_rounded_coeff = multiply_add_uint_mod(get<0>(I), get<2>(J), fix[0], get<1>(J));
                get<0>(J)[coeff_index] = add_uint_mod(get<0>(J)[coeff_index], scaled_rounded_coeff, get<1>(J));
            });
    });
}
// Subtracts the scaled plaintext from destination.
// Each coefficient m of plain is multiplied by coeff_modulus q, divided by
// plain_modulus t, and rounded to the nearest integer (rounded up in case of a
// tie), i.e. floor((q * m + floor((t+1) / 2)) / t). The result is subtracted,
// for every modulus in coeff_modulus, from the coefficient with the same index
// in the matching destination component.
//
// plain         Plaintext; its first plain.coeff_count() coefficients are used.
// context_data  Supplies the encryption parameters together with
//               coeff_div_plain_modulus, plain_upper_half_threshold and
//               coeff_modulus_mod_plain_modulus.
// destination   RNS polynomial updated in place.
//
// When SEAL_DEBUG is defined, std::invalid_argument is thrown if plain has more
// coefficients than poly_modulus_degree, or if destination's poly_modulus_degree
// differs from the one in the encryption parameters.
void multiply_sub_plain_with_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    auto &parms = context_data.parms();
    const size_t plain_coeff_count = plain.coeff_count();
    auto &coeff_modulus = parms.coeff_modulus();
    const size_t coeff_modulus_size = coeff_modulus.size();
    // Only the numeric value of t is needed below; read it once, outside the
    // per-coefficient loop.
    const auto plain_modulus_value = parms.plain_modulus().value();
    auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus();
    const uint64_t plain_upper_half_threshold = context_data.plain_upper_half_threshold();
    const uint64_t q_mod_t = context_data.coeff_modulus_mod_plain_modulus();
#ifdef SEAL_DEBUG
    if (plain_coeff_count > parms.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != parms.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    // Coefficients of plain m multiplied by coeff_modulus q, divided by plain_modulus t,
    // and rounded to the nearest integer (rounded up in case of a tie). Equivalent to
    // floor((q * m + floor((t+1) / 2)) / t).
    // The second iterator, starting at size_t(0), supplies the running coefficient index.
    SEAL_ITERATE(iter(plain.data(), size_t(0)), plain_coeff_count, [&](auto I) {
        // Compute numerator = (q mod t) * m[i] + (t+1)/2 as a 128-bit value
        unsigned long long prod[2]{ 0, 0 };
        uint64_t numerator[2]{ 0, 0 };
        multiply_uint64(get<0>(I), q_mod_t, prod);
        unsigned char carry = add_uint64(*prod, plain_upper_half_threshold, numerator);
        // High word: high word of the product plus the carry out of the low-word addition
        numerator[1] = static_cast<uint64_t>(prod[1]) + static_cast<uint64_t>(carry);

        // Compute fix[0] = floor(numerator / t)
        uint64_t fix[2] = { 0, 0 };
        divide_uint128_inplace(numerator, plain_modulus_value, fix);

        // Subtract from ciphertext: floor(q / t) * m + increment
        size_t coeff_index = get<1>(I);
        SEAL_ITERATE(
            iter(destination, coeff_modulus, coeff_div_plain_modulus), coeff_modulus_size, [&](auto J) {
                // get<0>(J): destination component, get<1>(J): its modulus,
                // get<2>(J): the matching coeff_div_plain_modulus entry
                uint64_t scaled_rounded_coeff = multiply_add_uint_mod(get<0>(I), get<2>(J), fix[0], get<1>(J));
                get<0>(J)[coeff_index] = sub_uint_mod(get<0>(J)[coeff_index], scaled_rounded_coeff, get<1>(J));
            });
    });
}
} // namespace util
} // namespace seal