forked from jvdsn/crypto-attacks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
partial_integer.py
478 lines (421 loc) · 19.4 KB
/
partial_integer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
from math import log2
class PartialInteger:
    """
    Represents positive integers with some known and some unknown bits.

    The integer is stored as an ordered list of components, from the least significant end to the most
    significant end. Each component is a (value, bit_length) pair; value is None for an unknown component.
    """

    def __init__(self):
        """
        Constructs a new PartialInteger with total bit length 0 and no components.
        """
        # Total number of bits over all components (known and unknown).
        self.bit_length = 0
        # Number of unknown components (not the number of unknown bits).
        self.unknowns = 0
        # (value, bit_length) pairs ordered from lsb to msb; value is None for an unknown component.
        self._components = []

    def add_known(self, value, bit_length):
        """
        Adds a known component to the msb of this PartialInteger.
        :param value: the value of the component
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self._components.append((value, bit_length))
        return self

    def add_unknown(self, bit_length):
        """
        Adds an unknown component to the msb of this PartialInteger.
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self.unknowns += 1
        self._components.append((None, bit_length))
        return self

    def get_known_lsb(self):
        """
        Returns all known lsb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known lsb and the bit length of the known lsb
        """
        lsb = 0
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                return lsb, lsb_bit_length
            lsb = lsb + (value << lsb_bit_length)
            lsb_bit_length += bit_length
        return lsb, lsb_bit_length

    def get_known_msb(self):
        """
        Returns all known msb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known msb and the bit length of the known msb
        """
        msb = 0
        msb_bit_length = 0
        # Walk from the msb downwards, appending each component below the bits gathered so far.
        for value, bit_length in reversed(self._components):
            if value is None:
                return msb, msb_bit_length
            msb = (msb << bit_length) + value
            msb_bit_length += bit_length
        return msb, msb_bit_length

    def get_known_middle(self):
        """
        Returns all known middle bits in this PartialInteger.
        Scanning from the lsb, leading unknown components are skipped; the first run of consecutive known
        components is then collected, stopping at the next unknown component.
        :return: a tuple containing the known middle bits and the bit length of the known middle bits
        """
        middle = 0
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is not None:
                middle = middle + (value << middle_bit_length)
                middle_bit_length += bit_length
            elif middle_bit_length > 0:
                # An unknown component ends the first run of known components.
                return middle, middle_bit_length
        return middle, middle_bit_length

    def get_unknown_lsb(self):
        """
        Returns the bit length of the unknown lsb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown lsb
        """
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is not None:
                return lsb_bit_length
            lsb_bit_length += bit_length
        return lsb_bit_length

    def get_unknown_msb(self):
        """
        Returns the bit length of the unknown msb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown msb
        """
        msb_bit_length = 0
        for value, bit_length in reversed(self._components):
            if value is not None:
                return msb_bit_length
            msb_bit_length += bit_length
        return msb_bit_length

    def get_unknown_middle(self):
        """
        Returns the bit length of the unknown middle bits in this PartialInteger.
        Scanning from the lsb, leading known components are skipped; the first run of consecutive unknown
        components is then measured, stopping at the next known component.
        :return: the bit length of the unknown middle bits
        """
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                middle_bit_length += bit_length
            elif middle_bit_length > 0:
                # A known component ends the first run of unknown components.
                return middle_bit_length
        return middle_bit_length

    def matches(self, i):
        """
        Returns whether this PartialInteger matches an integer, that is, all known bits are equal.
        :param i: the integer
        :return: True if this PartialInteger matches i, False otherwise
        """
        shift = 0
        for value, bit_length in self._components:
            if value is not None and (i >> shift) % (2 ** bit_length) != value:
                return False
            shift += bit_length
        return True

    def sub(self, unknowns):
        """
        Substitutes some values for the unknown components in this PartialInteger.
        These values can be symbolic (e.g. Sage variables)
        :param unknowns: the values for the unknown components, ordered from lsb to msb;
                         must contain exactly self.unknowns entries
        :return: an integer or expression with the unknowns substituted
        """
        assert len(unknowns) == self.unknowns
        i = 0
        j = 0
        shift = 0
        for value, bit_length in self._components:
            if value is None:
                # We don't shift here because the unknown could be a symbolic variable
                i += 2 ** shift * unknowns[j]
                j += 1
            else:
                i += value << shift
            shift += bit_length
        return i

    def get_known_and_unknowns(self):
        """
        Returns i_, o, and l such that this integer i = i_ + sum(2^(o_j) * i_j) with i_j < 2^(l_j).
        :return: a tuple of i_ (the known part), o (the bit offsets of the unknowns),
                 and l (the bit lengths of the unknowns), both lists ordered from lsb to msb
        """
        i_ = 0
        o = []
        l = []
        offset = 0
        for value, bit_length in self._components:
            if value is None:
                o.append(offset)
                l.append(bit_length)
            else:
                i_ += 2 ** offset * value
            offset += bit_length
        return i_, o, l

    def get_unknown_bounds(self):
        """
        Returns a list of bounds on each of the unknowns in this PartialInteger.
        A bound is simply 2^l with l the bit length of the unknown.
        :return: the list of bounds, ordered from lsb to msb
        """
        return [2 ** bit_length for value, bit_length in self._components if value is None]

    def to_int(self):
        """
        Converts this PartialInteger to an int.
        The number of unknowns must be zero.
        :return: the int represented by this PartialInteger
        """
        assert self.unknowns == 0
        return self.sub([])

    def to_string_le(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (little endian).
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        assert (base & (base - 1)) == 0, "Base must be power of two."
        # 0 and 1 pass the power-of-two bit test but give no (or zero) bits per digit.
        assert base >= 2, "Base must be at least 2."
        assert base <= 36
        assert len(symbols) >= base
        bits_per_element = int(log2(base))
        chars = []
        for value, bit_length in self._components:
            assert bit_length % bits_per_element == 0, f"Component with bit length {bit_length} can't be represented by base {base} digits"
            for _ in range(bit_length // bits_per_element):
                if value is None:
                    chars.append('?')
                else:
                    chars.append(symbols[value % base])
                    value //= base
        return chars

    def to_string_be(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (big endian).
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        return self.to_string_le(base, symbols)[::-1]

    def to_bits_le(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (little endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        assert len(symbols) == 2
        return self.to_string_le(2, symbols)

    def to_bits_be(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (big endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        return self.to_bits_le(symbols)[::-1]

    def to_hex_le(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (little endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        assert len(symbols) == 16
        return self.to_string_le(16, symbols)

    def to_hex_be(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (big endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        return self.to_hex_le(symbols)[::-1]

    @staticmethod
    def unknown(bit_length):
        """
        Constructs a PartialInteger consisting of a single unknown component.
        :param bit_length: the bit length of the unknown component
        :return: a PartialInteger with one unknown component of the given bit length
        """
        return PartialInteger().add_unknown(bit_length)

    @staticmethod
    def parse_le(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (little endian).
        Consecutive known digits are merged into one known component, and consecutive unknown digits into
        one unknown component.
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        assert (base & (base - 1)) == 0, "Base must be power of two."
        # 0 and 1 pass the power-of-two bit test but give no (or zero) bits per digit.
        assert base >= 2, "Base must be at least 2."
        assert base <= 36
        bits_per_element = int(log2(base))
        p = PartialInteger()
        rc_k = 0  # run count of consecutive known digits not yet flushed
        rc_u = 0  # run count of consecutive unknown digits not yet flushed
        value = 0  # value of the pending known run
        for digit in digits:
            if digit is None or digit == '?':
                if rc_k > 0:
                    p.add_known(value, rc_k * bits_per_element)
                    rc_k = 0
                    value = 0
                rc_u += 1
            else:
                if isinstance(digit, str):
                    digit = int(digit, base)
                assert 0 <= digit < base
                if rc_u > 0:
                    p.add_unknown(rc_u * bits_per_element)
                    rc_u = 0
                value += digit * base ** rc_k
                rc_k += 1
        # At most one of the two runs is still pending here.
        if rc_k > 0:
            p.add_known(value, rc_k * bits_per_element)
        if rc_u > 0:
            p.add_unknown(rc_u * bits_per_element)
        return p

    @staticmethod
    def parse_be(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (big endian).
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base (must be a power of two between 2 and 36, inclusive)
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        return PartialInteger.parse_le(reversed(digits), base)

    @staticmethod
    def from_bits_le(bits):
        """
        Constructs a PartialInteger from bits (little endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.parse_le(bits, 2)

    @staticmethod
    def from_bits_be(bits):
        """
        Constructs a PartialInteger from bits (big endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.from_bits_le(reversed(bits))

    @staticmethod
    def from_hex_le(hex):
        """
        Constructs a PartialInteger from hex characters (little endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.parse_le(hex, 16)

    @staticmethod
    def from_hex_be(hex):
        """
        Constructs a PartialInteger from hex characters (big endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.from_hex_le(reversed(hex))

    @staticmethod
    def from_lsb(bit_length, lsb, lsb_bit_length):
        """
        Constructs a PartialInteger from some known lsb, setting the msb to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb, must satisfy 0 <= lsb < 2^lsb_bit_length
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        assert bit_length >= lsb_bit_length
        # Strict upper bound: lsb == 2^lsb_bit_length would spill into the unknown msb component.
        assert 0 <= lsb < (2 ** lsb_bit_length)
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(bit_length - lsb_bit_length)

    @staticmethod
    def from_msb(bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known msb, setting the lsb to unknown.
        :param bit_length: the total bit length of the integer
        :param msb: the known msb, must satisfy 0 <= msb < 2^msb_bit_length
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        assert bit_length >= msb_bit_length
        assert 0 <= msb < (2 ** msb_bit_length)
        return PartialInteger().add_unknown(bit_length - msb_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known lsb and msb, setting the middle bits to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb, must satisfy 0 <= lsb < 2^lsb_bit_length
        :param lsb_bit_length: the bit length of the known lsb
        :param msb: the known msb, must satisfy 0 <= msb < 2^msb_bit_length
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        assert bit_length >= lsb_bit_length + msb_bit_length
        assert 0 <= lsb < (2 ** lsb_bit_length)
        assert 0 <= msb < (2 ** msb_bit_length)
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(middle_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from some known middle bits, setting the lsb and msb to unknown.
        :param middle: the known middle bits, must satisfy 0 <= middle < 2^middle_bit_length
        :param middle_bit_length: the bit length of the known middle bits
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        assert 0 <= middle < (2 ** middle_bit_length)
        return PartialInteger().add_unknown(lsb_bit_length).add_known(middle, middle_bit_length).add_unknown(msb_bit_length)

    @staticmethod
    def lsb_of(i, bit_length, lsb_bit_length):
        """
        Constructs a PartialInteger from the lsb of a known integer, setting the msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        lsb = i % (2 ** lsb_bit_length)
        return PartialInteger.from_lsb(bit_length, lsb, lsb_bit_length)

    @staticmethod
    def msb_of(i, bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the msb of a known integer, setting the lsb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_msb(bit_length, msb, msb_bit_length)

    @staticmethod
    def lsb_and_msb_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the lsb and msb of a known integer, setting the middle bits to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        lsb = i % (2 ** lsb_bit_length)
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length)

    @staticmethod
    def middle_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the middle bits of a known integer, setting the lsb and msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        middle = (i >> lsb_bit_length) % (2 ** middle_bit_length)
        return PartialInteger.from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length)