diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5f9036f --- /dev/null +++ b/.gitignore @@ -0,0 +1,58 @@ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +*.abi3.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ +.pytest_cache + +# Sphinx documentation +docs/_build/ + +# virtualenv +venv/ +ENV/ +.envrc + +# Temp files +__* +!__init__.py diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..7f0cc0e --- /dev/null +++ b/README.rst @@ -0,0 +1,12 @@ +petrelic +======== + +Simple Python wrapper around RELIC + + +Development +----------- + +To start developing on `petrelic` create a local installation: + + pip3 install -v -e . diff --git a/petrelic/__init__.py b/petrelic/__init__.py new file mode 100644 index 0000000..8e4c322 --- /dev/null +++ b/petrelic/__init__.py @@ -0,0 +1,8 @@ +__version__ = "0.0.0" +__title__ = "petrelic" +__author__ = "Wouter Lueks" +__email__ = "wouter.lueks@epfl.ch" +__url__ = "https://github.com/spring-epfl/petrelic" +__license__ = "Apache 2.0" +__description__ = "A Python binding for the RELIC cryptographic library" +__copyright__ = "EPFL, Spring Lab, 2019" diff --git a/petrelic/_cffi_src/petrelic.c b/petrelic/_cffi_src/petrelic.c new file mode 100644 index 0000000..9905f42 --- /dev/null +++ b/petrelic/_cffi_src/petrelic.c @@ -0,0 +1,15 @@ +#include + +const int CONST_RLC_POS = RLC_POS; +const int CONST_RLC_NEG = RLC_NEG; +const int CONST_RLC_LT = RLC_LT; +const int CONST_RLC_EQ = RLC_EQ; +const int CONST_RLC_GT = RLC_GT; + +unsigned int get_rlc_dig(void) { + return RLC_DIG; +} + +int get_rlc_ok(void) { + return RLC_OK; +} diff --git a/petrelic/_cffi_src/petrelic.h b/petrelic/_cffi_src/petrelic.h new file mode 100644 index 0000000..39705e0 --- /dev/null +++ b/petrelic/_cffi_src/petrelic.h @@ -0,0 +1,127 @@ +int core_init(void); +int core_clean(void); +int pc_param_set_any(); +void pc_param_print(); + + +unsigned int get_rlc_dig(void); +int get_rlc_ok(void); +const int CONST_RLC_POS; +const int CONST_RLC_NEG; +const int CONST_RLC_LT; +const int CONST_RLC_EQ; +const int CONST_RLC_GT; + + +typedef uint64_t dig_t; +typedef struct { + /** The number of digits allocated to this multiple precision integer. */ + int alloc; + /** The number of digits actually used. */ + int used; + /** The sign of this multiple precision integer. */ + int sign; + /** The sequence of contiguous digits that forms this integer. */ + // rlc_align dig_t dp[RLC_BN_SIZE]; HACKITYHACK + // HACK: had to hardcode the constant, and remove rlc_align + // ALLOC = AUTO + dig_t dp[34]; +} bn_st; + +typedef bn_st bn_t[1]; +void g1_get_ord(bn_t order); + + +void bn_new(bn_t a); +void bn_copy(bn_t c, const bn_t a); + +void bn_abs(bn_t c, const bn_t a); +void bn_neg(bn_t c, const bn_t a); +int bn_sign(const bn_t a); +void bn_zero(bn_t a); +int bn_is_zero(const bn_t a); +int bn_is_even(const bn_t a); +int bn_bits(const bn_t a); +int bn_get_bit(const bn_t a, int bit); +void bn_set_bit(bn_t a, int bit, int value); +void bn_get_dig(dig_t *digit, const bn_t a); +void bn_set_2b(bn_t a, int b); +void bn_set_dig(bn_t a, dig_t digit); +void bn_rand(bn_t a, int sign, int bits); +void bn_rand_mod(bn_t a, bn_t b); + +void bn_print(const bn_t a); +int bn_size_str(const bn_t a, int radix); +void bn_read_str(bn_t a, const char *str, int len, int radix); +void bn_write_str(char *str, int len, const bn_t a, int radix); +int bn_size_bin(const bn_t a); +void bn_read_bin(bn_t a, const uint8_t *bin, int len); +void bn_write_bin(uint8_t *bin, int len, const bn_t a); + +int bn_cmp_abs(const bn_t a, const bn_t b); +int bn_cmp_dig(const bn_t a, dig_t b); +int bn_cmp(const bn_t a, const bn_t b); + +void bn_add(bn_t c, const bn_t a, const bn_t b); +void bn_add_dig(bn_t c, const bn_t a, dig_t b); +void bn_sub(bn_t c, const bn_t a, const bn_t b); +void bn_sub_dig(bn_t c, const bn_t a, const dig_t b); + +void bn_mul(bn_t c, const bn_t a, const bn_t b); +void bn_mul_dig(bn_t c, const bn_t a, dig_t b); + +void bn_sqr(bn_t c, const bn_t a); +void bn_dbl(bn_t c, const bn_t a); +void bn_hlv(bn_t c, const bn_t a); +void bn_lsh(bn_t c, const bn_t a, int bits); +void bn_rsh(bn_t c, const bn_t a, int bits); + +void bn_div(bn_t c, const bn_t a, const bn_t b); +void bn_div_rem(bn_t c, bn_t d, const bn_t a, const bn_t b); + +void bn_mod_2b(bn_t c, const bn_t a, int b); +void bn_mod(bn_t c, const bn_t a, const bn_t m); + +void bn_gcd(bn_t c, const bn_t a, const bn_t b); +void bn_gcd_ext(bn_t c, bn_t d, bn_t e, const bn_t a, const bn_t b); + +int bn_is_prime(const bn_t a); +void bn_gen_prime(bn_t a, int bits); +void bn_gen_prime_safep(bn_t a, int bits); +void bn_gen_prime_stron(bn_t a, int bits); + + +// HACK: rlc_align removed, hardcoded size of array +// ORIG: typedef rlc_align dig_t fp_t[RLC_FP_DIGS + RLC_PAD(RLC_FP_BYTES)/(RLC_DIG / 8)]; +// ORIG: typedef rlc_align dig_t fp_st[RLC_FP_DIGS + RLC_PAD(RLC_FP_BYTES)/(RLC_DIG / 8)]; +typedef dig_t fp_t[6]; +typedef dig_t fp_st[6]; + +typedef uint8_t appel; + +typedef struct { + /** The first coordinate. */ + fp_st x; + /** The second coordinate. */ + fp_st y; + /** The third coordinate (projective representation). */ + fp_st z; + /** Flag to indicate that this point is normalized. */ + int norm; +} ep_st; + +typedef ep_st ep_t[1]; + +typedef ep_t g1_t; +// typedef ep2_t g2_t; +// typedef fp12_t gt_t; + +typedef ep_st g1_st; +// typedef ep2_st g2_st; +// typedef fp12_st gt_st; + + +/* void g1_null(g1_t p) */ +void g1_new(g1_t p); +void g1_rand(g1_t p); +void g1_print(g1_t p); diff --git a/petrelic/bindings.py b/petrelic/bindings.py new file mode 100644 index 0000000..d3b5287 --- /dev/null +++ b/petrelic/bindings.py @@ -0,0 +1,24 @@ +from petrelic._petrelic import ffi, lib + +_FFI = ffi +_C = lib + +RLC_OK = _C.get_rlc_ok() + +class RelicInitializer: + initialized = False + + def __init__(self): + if not RelicInitializer.initialized: + self.__initialize_relic() + RelicInitializer.initialized = True + + def __initialize_relic(self): + print("Initializing RELIC") + if _C.core_init() != RLC_OK: + raise RuntimeError("Could not initialize RELIC") + +# Initializing RELIC +RelicInitializer() + + diff --git a/petrelic/bn.py b/petrelic/bn.py new file mode 100644 index 0000000..1e6f157 --- /dev/null +++ b/petrelic/bn.py @@ -0,0 +1,669 @@ +from petrelic.bindings import _FFI, _C +import petrelic.constants as consts + +import functools +import re + +def force_Bn(n): + """A decorator that coerces the nth input to be a Big Number""" + + def decorator_force_Bn(func): + # pylint: disable=star-args + @functools.wraps(func) + def wrapper(*args, **kwargs): + new_args = args + if n < len(args) and not isinstance(args[n], Bn): + if isinstance(args[n], int): + new_args = list(args) + new_args[n] = Bn(new_args[n]) + new_args = tuple(new_args) + else: + # Don't know how to convert + raise TypeError("Cannot convert argument ", n); + + return func(*new_args, **kwargs) + return wrapper + return decorator_force_Bn + +def force_Bn_other(func): + return force_Bn(1)(func) + + +class Bn(object): + + __slots__ = ["bn"] + + @staticmethod + def from_num(num): + if isinstance(num, int): + return Bn(num) + elif isinstance(num, Bn): + return num + else: + # raise TypeError("Cannot coerce %s into a BN." % num) + return NotImplemented + + @staticmethod + def _from_radix_string(sinput, radix): + neg = False + if sinput[0] == '-': + neg = True + sinput = sinput[1:] + + ret = Bn() + s = sinput.encode("utf8") + _C.bn_read_str(ret.bn, s, len(s), radix) + + if neg: + return ret.__neg__() + else: + return ret + + + @staticmethod + def from_decimal(sdec): + """Creates a Big Number from a decimal string. + + Args: + sdec (string): numeric string possibly starting with minus. + + See Also: + str() produces a decimal string from a big number. + + Example: + >>> hundred = Bn.from_decimal("100") + >>> str(hundred) + '100' + """ + + if not re.match("-?[0-9]+$", sdec): + raise Exception("String must only contain digits 0--9 and sign") + + return Bn._from_radix_string(sdec, 10) + + + @staticmethod + def from_hex(shex): + """Creates a Big Number from a hexadecimal string. + + Args: + shex (string): hex (0-F) string possibly starting with minus. + + See Also: + hex() produces a hexadecimal representation of a big number. + + Example: + >>> Bn.from_hex("FF") + Bn(255) + """ + + if not re.match("-?[0-9a-fA-F]+$", shex): + raise Exception("String must only contain digits 0--9,a--f and sign") + + shex = shex.lower() + return Bn._from_radix_string(shex, 16) + + @staticmethod + def from_binary(sbin): + """ Restore number given its Big-endian representation. + + Creates a Big Number from a byte sequence representing the number in + Big-endian 8 byte atoms. Only positive values can be represented as + byte sequence, and the library user should store the sign bit + separately. + + Args: + sbin (string): a byte sequence. + + Example: + >>> from binascii import unhexlify + >>> byte_seq = unhexlify(b"010203") + >>> Bn.from_binary(byte_seq) + Bn(66051) + >>> (1 * 256**2) + (2 * 256) + 3 + 66051 + """ + ret = Bn() + _C.bn_read_bin(ret.bn, sbin, len(sbin)) + return ret + + @staticmethod + def get_prime(bits, safe=1): + """ + Builds a prime Big Number of length bits. + + Args: + bits (int) -- the number of bits. + safe (int) -- 1 for a safe prime, otherwise 0. + """ + + ret = Bn() + if safe == 1: + _C.bn_gen_prime_safep(ret.bn, bits) + else: + _C.bn_gen_prime(ret.bn, bits) + + return ret + + def __init__(self, num=0): + """Initialize a new Bn, initialized with a small integer""" + self.bn = _FFI.new("bn_t") + _C.bn_new(self.bn) + + # TODO: Fix parsing, to support converting bigger numbers + if num >= consts.DIGIT_MAXIMUM: + raise Exception("num does not fit directly inside Bn") + _C.bn_set_dig(self.bn, abs(num)) + + if num < 0: + _C.bn_neg(self.bn, self.bn) + + def copy(self): + """Returns a copy of the Bn object.""" + return self.__copy__() + + def __copy__(self): + # 'Copies the big number. Support for copy module' + other = Bn() + _C.bn_copy(other.bn, self.bn) + return other + + def __deepcopy__(self, memento): + # 'Deepcopy is the same as copy' + # pylint: disable=unused-argument + return self.__copy__() + + # ------------ Comparisons ------------ + + @force_Bn_other + def __inner_cmp__(self, other): + sig = int(_C.bn_cmp(self.bn, other.bn)) + return sig + + def __lt__(self, other): + return self.__inner_cmp__(other) < 0 + + def __le__(self, other): + return self.__inner_cmp__(other) <= 0 + + def __eq__(self, other): + return self.__inner_cmp__(other) == 0 + + def __ne__(self, other): + return self.__inner_cmp__(other) != 0 + + def __gt__(self, other): + return self.__inner_cmp__(other) > 0 + + def __ge__(self, other): + return self.__inner_cmp__(other) >= 0 + + def bool(self): + """Turn Bn into boolean. False if zero, True otherwise. + + Examples: + >>> bool(Bn(0)) + False + >>> bool(Bn(1337)) + True + + """ + return self.__bool__() + + def __bool__(self): + # 'Turn into boolean' + return not bool(_C.bn_is_zero(self.bn)) + + + def repr(self): + return self.__repr__() + + def __repr__(self): + # TODO: return value may be too big, in which case it cannot be recovered + return 'Bn({})'.format(self.repr_in_base(10)) + + def __str__(self): + return self.repr_in_base(10) + + def int(self): + """A native python integer representation of the Big Number. + Synonym for int(bn). + """ + return self.__int__() + + def __int__(self): + return int(self.repr_in_base(10)) + + def __index__(self): + return self.__int__() + + def hex(self): + """The representation of the string in hexadecimal. + Synonym for hex(n).""" + return self.__hex__() + + def __hex__(self): + # """The representation of the string in hexadecimal""" + return self.repr_in_base(16) + + # NOT_IMPLEMENTED: hex() + + def binary(self): + """A byte array representing the absolute value + + A byte array representation of the absolute value in Big-Endian format + (with 8 bit atomic elements). You are responsible for extracting the + sign bit separately. + + Example: + >>> from binascii import hexlify + + >>> bin = Bn(66051).binary() + >>> hexlify(bin) == b'010203' + True + + >>> bin = Bn(1337).binary() + >>> hexlify(bin) == b'0539' + True + + """ + if self < 0: + raise Exception("Cannot represent negative numbers") + + length = _C.bn_size_bin(self.bn) + buf = _FFI.new("char[]", length) + _C.bn_write_bin(buf, length, self.bn) + return _FFI.unpack(buf, length) + + def repr_in_base(self, radix): + """ Represent number as string in given base + + Args: + radix (int): The number of unique digits (2 <= radix <= 62) + + Examples: + >>> Bn(42).repr_in_base(16) + '2A' + >>> Bn(-1024).repr_in_base(2) + '-10000000000' + """ + length = _C.bn_size_str(self.bn, radix) + buf = _FFI.new("char[]", length) + _C.bn_write_str(buf, length, self.bn, radix) + return _FFI.string(buf).decode("utf8") + + def test(self): + """ + >>> b = Bn() + >>> b.repr_in_base(2) + '0' + """ + + def random(self): + """Returns a random number 0 < rand < self + + TODO: currently it excludes 0, update to include 0 + + Example: + #>>> r = Bn(100).random() + #>>> 0 <= r && r < 100 + True + """ + rnd = Bn() + _C.bn_rand_mod(rnd.bn, self.bn) + return rnd + + # ---------- Arithmetic -------------- + def int_neg(self): + """Returns the negative of this number. Synonym with -self. + + Example: + + >>> one100 = Bn(100) + >>> one100.int_neg() + Bn(-100) + >>> -one100 + Bn(-100) + """ + return self.__neg__() + + def int_add(self, other): + """Returns the sum of this number with another. Synonym for self + other. + + Example: + >>> one100 = Bn(100) + >>> two100 = Bn(200) + >>> two100.int_add(one100) # Function syntax + Bn(300) + >>> two100 + one100 # Operator syntax + Bn(300) + """ + return self.__add__(other) + + def __radd__(self, other): + return self.__add__(other) + + @force_Bn_other + def __add__(self, other): + r = Bn() + _C.bn_add(r.bn, self.bn, other.bn) + return r + + def int_sub(self, other): + """Returns the difference between this number and another. + Synonym for self - other. + + Example: + >>> one100 = Bn(100) + >>> two100 = Bn(200) + >>> two100.int_sub(one100) # Function syntax + Bn(100) + >>> two100 - one100 # Operator syntax + Bn(100) + """ + return self.__sub__(other) + + @force_Bn_other + def __rsub__(self, other): + return other - self + + @force_Bn_other + def __sub__(self, other): + r = Bn() + _C.bn_sub(r.bn, self.bn, other.bn) + return r + + def int_mul(self, other): + """Returns the product of this number with another. + Synonym for self * other. + + Example: + >>> one100 = Bn(100) + >>> two100 = Bn(200) + >>> one100.int_mul(two100) # Function syntax + Bn(20000) + >>> one100 * two100 # Operator syntax + Bn(20000) + """ + return self.__mul__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + @force_Bn_other + def __mul__(self, other): + r = Bn() + _C.bn_mul(r.bn, self.bn, other.bn) + return r + + def __neg__(self): + ret = Bn() + _C.bn_neg(ret.bn, self.bn) + return ret + + # ------------------ Mod arithmetic ------------------------- + + @force_Bn(1) + @force_Bn(2) + def mod_add(self, other, m): + """ Returns the sum of self and other modulo m. + + Example: + >>> Bn(10).mod_add(2, 11) + Bn(1) + >>> Bn(10).mod_add(Bn(2), Bn(11)) + Bn(1) + """ + + r = Bn() + _C.bn_add(r.bn, self.bn, other.bn) + _C.bn_mod(r.bn, r.bn, m.bn) + return r + + @force_Bn(1) + @force_Bn(2) + def mod_sub(self, other, m): + """ Returns the difference of self and other modulo m. + + Example: + >>> Bn(10).mod_sub(Bn(2), Bn(11)) + Bn(8) + """ + + r = Bn() + _C.bn_sub(r.bn, self.bn, other.bn) + _C.bn_mod(r.bn, r.bn, m.bn) + return r + + @force_Bn(1) + @force_Bn(2) + def mod_mul(self, other, m): + """ Return the product of self and other modulo m. + + Example: + >>> Bn(10).mod_mul(Bn(2), Bn(11)) + Bn(9) + """ + + r = Bn() + _C.bn_mul(r.bn, self.bn, other.bn) + _C.bn_mod(r.bn, r.bn, m.bn) + return r + + @force_Bn_other + def mod_inverse(self, m): + """ Compute the inverse mod m, such that self * res == 1 mod m. + + Example: + >>> Bn(10).mod_inverse(m = Bn(11)) + Bn(10) + >>> Bn(10).mod_mul(Bn(10), m = Bn(11)) == Bn(1) + True + """ + gcd = Bn() + inv = Bn() + _C.bn_gcd_ext(gcd.bn, inv.bn, _FFI.NULL, self.bn, m.bn) + _C.bn_mod(inv.bn, inv.bn, m.bn) + + if gcd != Bn(1): + raise Exception("No inverse for ", self, "modulo ", m) + + return inv + + def mod_pow(self, other, m, ctx=None): + """ Performs the modular exponentiation of self ** other % m. + + This function is _not_ constant time. + + Example: + >>> one100 = Bn(100) + >>> one100.mod_pow(2, 3) # Modular exponentiation + Bn(1) + """ + return self.__pow__(other, m) + + def divmod(self, other): + """Returns the integer division and remainder of this number by another. + + Example: + >>> Bn(13).divmod(Bn(9)) + (Bn(1), Bn(4)) + """ + return self.__divmod__(other) + + def __rdivmod__(self, other): + return Bn(other).__divmod__(self) + + @force_Bn_other + def __divmod__(self, other): + quot = Bn() + rem = Bn() + _C.bn_div_rem(quot.bn, rem.bn, self.bn, other.bn) + return (quot, rem) + + def int_div(self, other): + """Returns the integer division of this number by another. + Synonym of self / other. + + Example: + >>> one100 = Bn(100) + >>> two100 = Bn(200) + >>> two100.int_div(one100) # Function syntax + Bn(2) + >>> two100 // one100 # Operator syntax + Bn(2) + """ + return self.__floordiv__(other) + + @force_Bn_other + def __rfloordiv__(self, other): + return other.__floordiv__(self) + + @force_Bn_other + def __floordiv__(self, other): + quot = Bn() + _C.bn_div(quot.bn, self.bn, other.bn) + return quot + + def mod(self, other): + """Returns the remainder of this number modulo another. + Synonym for self % other. + + Example: + >>> one100 = Bn(100) + >>> two100 = Bn(200) + >>> two100.mod(one100) # Function syntax + Bn(0) + >>> two100 % one100 # Operator syntax + Bn(0) + """ + return self.__mod__(other) + + @force_Bn_other + def __rmod__(self, other): + return other.__mod__(self) + + @force_Bn_other + def __mod__(self, other): + rem = Bn() + _C.bn_mod(rem.bn, self.bn, other.bn) + return rem + + def pow(self, other, modulo=None): + """Returns the number raised to the power other optionally modulo a third number. + Synonym with pow(self, other, modulo). + + Example: + >>> one100 = Bn(100) + >>> one100.pow(2) # Function syntax + Bn(10000) + >>> one100 ** 2 # Operator syntax + Bn(10000) + >>> one100.pow(2, 3) # Modular exponentiation + Bn(1) + """ + return self.__pow__(other, modulo) + + @force_Bn_other + def __rpow__(self, other): + return other.__pow__(self) + + @force_Bn_other + def __pow__(self, n, modulo=None): + if n < 0 and modulo is None: + raise ArithmeticError("Negative exponent only supported when modulus is set") + + # TODO: fix coercions later + if type(modulo) == int: + modulo = Bn(modulo) + + base = Bn() + _C.bn_copy(base.bn, self.bn) + + if _C.bn_sign(n.bn) == _C.CONST_RLC_NEG: + base = base.mod_inverse(modulo) + _C.bn_neg(n.bn) + + if _C.bn_is_zero(n.bn) == 1: + return Bn(1) + + res = Bn(1) + if modulo is None: + while not bool(_C.bn_is_zero(n.bn)): + if n.is_odd(): + res = res * base + + _C.bn_sqr(base.bn, base.bn) + _C.bn_hlv(n.bn, n.bn) + return res + else: + # TODO: rewrite using bn_mxp + while not bool(_C.bn_is_zero(n.bn)): + if n.is_odd(): + res = res.mod_mul(base, modulo) + + base = base.mod_mul(base, modulo) + _C.bn_hlv(n.bn, n.bn) + return res + + + def is_prime(self): + """Returns True if the number is prime, with negligible prob. of error. + + Examples: + >>> Bn(37).is_prime() + True + >>> Bn(10).is_prime() + False + """ + return _C.bn_is_prime(self.bn) == 1 + + def is_odd(self): + """Returns True if the number is odd. + + Examples: + >>> Bn(2).is_odd() + False + >>> Bn(1337).is_odd() + True + """ + + return not bool(_C.bn_is_even(self.bn)) + + def is_even(self): + """Returns True if the number is even. + + Examples: + >>> Bn(2).is_even() + True + >>> Bn(1337).is_even() + False + """ + + return bool(_C.bn_is_even(self.bn)) + + def is_bit_set(self, n): + """Returns True if the nth bit is set + + Examples: + >>> a = Bn(17) # in binary 10001 + >>> a.is_bit_set(0) + True + >>> a.is_bit_set(1) + False + >>> a.is_bit_set(4) + True + """ + bit = _C.bn_get_bit(self.bn, n) + return bool(bit) + + def num_bits(self): + """Returns the number of bits representing this Big Number""" + return int(_C.bn_bits(self.bn)) + + + def __hash__(self): + return int(self).__hash__() + + + diff --git a/petrelic/compile.py b/petrelic/compile.py new file mode 100644 index 0000000..e4eda8a --- /dev/null +++ b/petrelic/compile.py @@ -0,0 +1,27 @@ +from cffi import FFI +import os + +CURRENT_PATH = os.path.abspath(os.path.dirname(__file__)) +RELIC_BINDINGS_PATH = os.path.join(CURRENT_PATH, '_cffi_src') + +def get_bindings_file(filename): + src_path = os.path.join(RELIC_BINDINGS_PATH, filename) + with open(src_path) as src_file: + return src_file.read() + + + +relic_bindings_defs = get_bindings_file("petrelic.h") +relic_bindings_src = get_bindings_file("petrelic.c") + +_FFI = FFI() +_FFI.set_source("petrelic._petrelic", relic_bindings_src, + library_dirs=["/home/vagrant/local/lib"], + include_dirs=["/home/vagrant/local/include"], + libraries=['relic']) +_FFI.cdef(relic_bindings_defs) + + +if __name__ == "__main__": + print("Compiling petrelic bindings for RELIC") + _FFI.compile(verbose=True) diff --git a/petrelic/constants.py b/petrelic/constants.py new file mode 100644 index 0000000..8eb5bb0 --- /dev/null +++ b/petrelic/constants.py @@ -0,0 +1,6 @@ +from .bindings import _FFI, _C + +RLC_DIG = _C.get_rlc_dig() +RLC_OK = _C.get_rlc_ok() + +DIGIT_MAXIMUM = 2 ** RLC_DIG - 1 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2bed0f3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --doctest-modules \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..14edc94 --- /dev/null +++ b/setup.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +import os +import re + +from setuptools import setup + + +PACKAGE_NAME = "petrelic" + +SETUP_REQUIRES = ["pytest-runner", "cffi>=1.0.0"] + +TEST_REQUIRES = ["pytest"] + +INSTALL_REQUIRES = ["cffi>=1.0.0"] + +DEV_REQUIRES = TEST_REQUIRES + ["sphinx", "sphinx_rtd_theme", "black"] + +CFFI_MODULES = "petrelic/compile.py:_FFI" + + +# Import README as long description +here = os.path.abspath(os.path.dirname(__file__)) +with open(os.path.join(here, "README.rst")) as f: + long_description = f.read() + + +# Obtain settings from __init__.py +with open(os.path.join(here, PACKAGE_NAME, "__init__.py")) as f: + matches = re.findall(r"(__.+__) = \"(.*)\"", f.read()) + for var_name, var_value in matches: + globals()[var_name] = var_value + + +setup( + name=__title__, + version=__version__, + description=__description__, + long_description=long_description, + author=__author__, + author_email=__email__, + packages=[PACKAGE_NAME], + license=__license__, + url=__url__, + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, + tests_require=TEST_REQUIRES, + extras_require={"dev": DEV_REQUIRES, "test": TEST_REQUIRES}, + cffi_modules=CFFI_MODULES, + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Security :: Cryptography", + "License :: OSI Approved :: Apache Software License", + ], + zip_safe=False, +) diff --git a/tests/test_bn.py b/tests/test_bn.py new file mode 100644 index 0000000..a4b528a --- /dev/null +++ b/tests/test_bn.py @@ -0,0 +1,195 @@ +from petrelic.bn import Bn + +import pytest +from copy import copy, deepcopy + +def test_bn_constructors(): + assert Bn.from_decimal("100") == 100 + assert Bn.from_decimal("-100") == -100 + + with pytest.raises(Exception): + Bn.from_decimal("100ABC") + + with pytest.raises(Exception): + Bn.from_hex("100ABCZ") + + assert Bn.from_hex(Bn(-100).hex()) == -100 + assert Bn(15).hex() == Bn(15).hex() + + with pytest.raises(Exception) as excinfo: + Bn(-100).binary() + assert 'negative' in str(excinfo.value) + + #assert Bn.from_binary(Bn(-100).binary()) == 100 + assert Bn.from_binary(Bn(100).binary()) == Bn(100) + assert Bn.from_binary(Bn(100).binary()) == 100 + + with pytest.raises(Exception) as excinfo: + s = 2**65 + Bn(s) + assert 'does not fit' in str(excinfo.value) + + #assert Bn.from_binary(Bn(-100).binary()) != Bn(50) + assert int(Bn(-100)) == -100 + + assert repr(Bn(5)) == Bn(5).repr() + assert repr(Bn(5)) == Bn(5).repr() == "Bn(5)" + assert range(10)[Bn(4)] == 4 + + d = {Bn(5): 5, Bn(6): 6} + assert Bn(5) in d + + +def test_bn_prime(): + p = Bn.get_prime(128) + assert p > Bn(0) + assert p.is_prime() + assert not Bn(16).is_prime() + assert p.num_bits() > 127 + + +def test_bn_arithmetic(): + assert (Bn(1) + Bn(1) == Bn(2)) + assert (Bn(1).int_add(Bn(1)) == Bn(2)) + + assert (Bn(1) + 1 == Bn(2)) + # assert (1 + Bn(1) == Bn(2)) + + assert (Bn(1) + Bn(-1) == Bn(0)) + assert (Bn(10) + Bn(10) == Bn(20)) + assert (Bn(-1) * Bn(-1) == Bn(1)) + assert (Bn(-1).int_mul(Bn(-1)) == Bn(1)) + + assert (Bn(10) * Bn(10) == Bn(100)) + assert (Bn(10) - Bn(10) == Bn(0)) + assert (Bn(10) - Bn(100) == Bn(-90)) + assert (Bn(10) + (-Bn(10)) == Bn(0)) + s = -Bn(100) + assert (Bn(10) + s == Bn(-90)) + assert (Bn(10) - (-Bn(10)) == Bn(20)) + assert -Bn(-10) == 10 + assert Bn(-10).int_neg() == 10 + + assert divmod(Bn(10), Bn(3)) == (Bn(3), Bn(1)) + assert Bn(10).divmod(Bn(3)) == (Bn(3), Bn(1)) + + assert Bn(10) // Bn(3) == Bn(3) + assert Bn(10).int_div(Bn(3)) == Bn(3) + + assert Bn(10) % Bn(3) == Bn(1) + assert Bn(10).mod(Bn(3)) == Bn(1) + + assert Bn(2) ** Bn(8) == Bn(2 ** 8) + assert pow(Bn(2), Bn(8), Bn(27)) == Bn(2 ** 8 % 27) + + pow(Bn(10), Bn(10)).binary() + + assert pow(Bn(2), 8, 27) == 2 ** 8 % 27 + + assert Bn(3).mod_inverse(16) == 11 + + with pytest.raises(Exception) as excinfo: + Bn(3).mod_inverse(0) + print("Got inverse") + assert 'No inverse' in str(excinfo.value) + + with pytest.raises(Exception) as excinfo: + x = Bn(0).mod_inverse(Bn(13)) + print("!!! Got inverse", x) + assert 'No inverse' in str(excinfo.value) + + # with pytest.raises(Exception) as excinfo: + # x = Bn(0).mod_inverse(Bn(13)) + # print("Got inverse", x) + #assert 'No inverse' in str(excinfo.value) + + assert Bn(10).mod_add(10, 15) == (10 + 10) % 15 + assert Bn(10).mod_sub(100, 15) == (10 - 100) % 15 + assert Bn(10).mod_mul(10, 15) == (10 * 10) % 15 + assert Bn(-1).bool() + + +def test_bn_right_arithmetic(): + assert (1 + Bn(1) == Bn(2)) + + assert (-1 * Bn(-1) == Bn(1)) + + assert (10 * Bn(10) == Bn(100)) + assert (10 - Bn(10) == Bn(0)) + assert (10 - Bn(100) == Bn(-90)) + assert (10 + (-Bn(10)) == Bn(0)) + s = -Bn(100) + assert (10 + s == Bn(-90)) + assert (10 - (-Bn(10)) == Bn(20)) + + assert divmod(10, Bn(3)) == (Bn(3), Bn(1)) + + assert 10 // Bn(3) == Bn(3) + + assert 10 % Bn(3) == Bn(1) + assert 2 ** Bn(8) == Bn(2 ** 8) + + assert 100 == Bn(100) + + pow(10, Bn(10)) + + +def test_bn_allocate(): + # Test allocation + n0 = Bn(10) + assert True + + assert str(Bn()) == "0" + assert str(Bn(1)) == "1" + assert str(Bn(-1)) == "-1" + + assert Bn(15).hex() == "F" + assert Bn(-15).hex() == "-F" + + assert int(Bn(5)) == 5 + assert Bn(5).int() == 5 + + assert 0 <= Bn(15).random() < 15 + + # Test copy + o0 = copy(n0) + o1 = deepcopy(n0) + + assert o0 == n0 + assert o1 == n0 + + # Test nonzero + assert not Bn() + assert not Bn(0) + assert Bn(1) + assert Bn(100) + + +def test_bn_cmp(): + assert Bn(1) < Bn(2) + assert Bn(1) <= Bn(2) + assert Bn(2) <= Bn(2) + assert Bn(2) == Bn(2) + assert Bn(2) <= Bn(3) + assert Bn(2) < Bn(3) + + +def test_extras(): + two = Bn(2) + two2 = two.copy() + assert two == two2 + + +def test_odd(): + assert Bn(1).is_odd() + assert Bn(1).is_bit_set(0) + assert not Bn(1).is_bit_set(1) + + assert Bn(3).is_odd() + assert Bn(3).is_bit_set(0) + assert Bn(3).is_bit_set(1) + + assert not Bn(0).is_odd() + assert not Bn(2).is_odd() + + assert Bn(100).is_bit_set(Bn(100).num_bits() - 1)