Source code for pyecsca.ec.mod

"""
Provides several implementations of an element of ℤₙ.

The base class :py:class:`Mod` dynamically
dispatches to the implementation chosen by the runtime configuration of the library
(see :py:class:`pyecsca.misc.cfg.Config`). A Python integer based implementation is available under
:py:class:`RawMod`. A symbolic implementation based on sympy is available under :py:class:`SymbolicMod`. If
`gmpy2` is installed, a GMP based implementation is available under :py:class:`GMPMod`.
"""
import random
import secrets
from functools import wraps, lru_cache
from typing import Type, Dict, Any, Tuple, Union

from public import public
from sympy import Expr, FF

from .error import raise_non_invertible, raise_non_residue
from .context import ResultAction
from ..misc.cfg import getconfig

has_gmp = False
try:
    import gmpy2

    has_gmp = True
except ImportError:
    gmpy2 = None


[docs]@public def gcd(a, b): """Euclid's greatest common denominator algorithm.""" if abs(a) < abs(b): return gcd(b, a) while abs(b) > 0: _, r = divmod(a, b) a, b = b, r return a
[docs]@public def extgcd(a, b): """Compute the extended Euclid's greatest common denominator algorithm.""" if abs(b) > abs(a): x, y, d = extgcd(b, a) return y, x, d if abs(b) == 0: return 1, 0, a x1, x2, y1, y2 = 0, 1, 1, 0 while abs(b) > 0: q, r = divmod(a, b) x = x2 - q * x1 y = y2 - q * y1 a, b, x2, x1, y2, y1 = b, r, x1, x, y1, y return x2, y2, a
[docs]@public def jacobi(x: int, n: int) -> int: """Jacobi symbol.""" if n <= 0: raise ValueError("'n' must be a positive integer.") if n % 2 == 0: raise ValueError("'n' must be odd.") x %= n r = 1 while x != 0: while x % 2 == 0: x //= 2 nm8 = n % 8 if nm8 in (3, 5): r = -r x, n = n, x if x % 4 == 3 and n % 4 == 3: r = -r x %= n return r if n == 1 else 0
[docs]@public @lru_cache def miller_rabin(n: int, rounds: int = 50) -> bool: """Miller-Rabin probabilistic primality test.""" if n in (2, 3): return True if n % 2 == 0: return False r, s = 0, n - 1 while s % 2 == 0: r += 1 s //= 2 for _ in range(rounds): a = random.randrange(2, n - 1) x = pow(a, s, n) if x in (1, n - 1): continue for _ in range(r - 1): x = pow(x, 2, n) if x == n - 1: break else: return False return True
def _check(func): @wraps(func) def method(self, other): if self.__class__ is not type(other): other = self.__class__(other, self.n) elif self.n != other.n: raise ValueError return func(self, other) return method
[docs]@public class RandomModAction(ResultAction): """A random sampling from Z_n.""" order: int def __init__(self, order: int): super().__init__() self.order = order def __repr__(self): return f"{self.__class__.__name__}({self.order:x})"
_mod_classes: Dict[str, Type] = {}
[docs]@public class Mod: """An element x of ℤₙ.""" x: Any n: Any __slots__ = ("x", "n") def __new__(cls, *args, **kwargs) -> "Mod": if cls != Mod: return cls.__new__(cls, *args, **kwargs) if not _mod_classes: raise ValueError("Cannot find any working Mod class.") selected_class = getconfig().ec.mod_implementation if selected_class not in _mod_classes: # Fallback to something selected_class = next(iter(_mod_classes.keys())) return _mod_classes[selected_class].__new__( _mod_classes[selected_class], *args, **kwargs ) @_check def __add__(self, other) -> "Mod": return self.__class__((self.x + other.x) % self.n, self.n) @_check def __radd__(self, other) -> "Mod": return self + other @_check def __sub__(self, other) -> "Mod": return self.__class__((self.x - other.x) % self.n, self.n) @_check def __rsub__(self, other) -> "Mod": return -self + other def __neg__(self) -> "Mod": return self.__class__(self.n - self.x, self.n)
[docs] def inverse(self) -> "Mod": """ Invert the element. :return: The inverse. :raises: :py:class:`NonInvertibleError` if the element is not invertible. """ raise NotImplementedError
def __invert__(self) -> "Mod": return self.inverse()
[docs] def is_residue(self) -> bool: """Whether this element is a quadratic residue (only implemented for prime modulus).""" raise NotImplementedError
[docs] def sqrt(self) -> "Mod": """ Compute the modular square root of this element (only implemented for prime modulus). Uses the `Tonelli-Shanks <https://en.wikipedia.org/wiki/Tonelli–Shanks_algorithm>`_ algorithm. """ raise NotImplementedError
@_check def __mul__(self, other) -> "Mod": return self.__class__((self.x * other.x) % self.n, self.n) @_check def __rmul__(self, other) -> "Mod": return self * other @_check def __truediv__(self, other) -> "Mod": return self * ~other @_check def __rtruediv__(self, other) -> "Mod": return ~self * other @_check def __floordiv__(self, other) -> "Mod": return self * ~other @_check def __rfloordiv__(self, other) -> "Mod": return ~self * other @_check def __divmod__(self, divisor) -> Tuple["Mod", "Mod"]: q, r = divmod(self.x, divisor.x) return self.__class__(q, self.n), self.__class__(r, self.n) def __bytes__(self) -> bytes: raise NotImplementedError def __int__(self) -> int: raise NotImplementedError
[docs] @classmethod def random(cls, n: int) -> "Mod": """ Generate a random :py:class:`Mod` in ℤₙ. :param n: The order. :return: The random :py:class:`Mod`. """ with RandomModAction(n) as action: return action.exit(cls(secrets.randbelow(n), n))
def __pow__(self, n) -> "Mod": return NotImplemented def __str__(self): return str(self.x)
[docs]@public class RawMod(Mod): """An element x of ℤₙ (implemented using Python integers).""" x: int n: int __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self, x: int, n: int): self.x = x % n self.n = n
[docs] def inverse(self) -> "RawMod": if self.x == 0: raise_non_invertible() x, _, d = extgcd(self.x, self.n) if d != 1: raise_non_invertible() return RawMod(x, self.n)
[docs] def is_residue(self): if not miller_rabin(self.n): raise NotImplementedError if self.x == 0: return True if self.n == 2: return self.x in (0, 1) legendre_symbol = jacobi(self.x, self.n) return legendre_symbol == 1
[docs] def sqrt(self) -> "RawMod": if not miller_rabin(self.n): raise NotImplementedError if self.x == 0: return RawMod(0, self.n) if not self.is_residue(): raise_non_residue() if self.n % 4 == 3: return self ** int((self.n + 1) // 4) q = self.n - 1 s = 0 while q % 2 == 0: q //= 2 s += 1 z = 2 while RawMod(z, self.n).is_residue(): z += 1 m = s c = RawMod(z, self.n) ** q t = self ** q r_exp = (q + 1) // 2 r = self ** r_exp while t != 1: i = 1 while not (t ** (2 ** i)) == 1: i += 1 two_exp = m - (i + 1) b = c ** int(RawMod(2, self.n) ** two_exp) m = int(RawMod(i, self.n)) c = b ** 2 t *= c r *= b return r
def __bytes__(self): return self.x.to_bytes((self.n.bit_length() + 7) // 8, byteorder="big") def __int__(self): return self.x def __eq__(self, other): if type(other) is int: return self.x == (other % self.n) if type(other) is not RawMod: return False return self.x == other.x and self.n == other.n def __ne__(self, other): return not self == other def __repr__(self): return str(self.x) def __hash__(self): return hash(("RawMod", self.x, self.n)) def __pow__(self, n) -> "RawMod": if type(n) is not int: raise TypeError if n == 0: return RawMod(1, self.n) if n < 0: return self.inverse() ** (-n) if n == 1: return RawMod(self.x, self.n) return RawMod(pow(self.x, n, self.n), self.n)
_mod_classes["python"] = RawMod
[docs]@public class Undefined(Mod): """A special undefined element.""" __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self): self.x = None self.n = None def __add__(self, other): return NotImplemented def __radd__(self, other): return NotImplemented def __sub__(self, other): return NotImplemented def __rsub__(self, other): return NotImplemented def __neg__(self): raise NotImplementedError
[docs] def inverse(self): raise NotImplementedError
[docs] def sqrt(self): raise NotImplementedError
[docs] def is_residue(self): raise NotImplementedError
def __invert__(self): raise NotImplementedError def __mul__(self, other): return NotImplemented def __rmul__(self, other): return NotImplemented def __truediv__(self, other): return NotImplemented def __rtruediv__(self, other): return NotImplemented def __floordiv__(self, other): return NotImplemented def __rfloordiv__(self, other): return NotImplemented def __divmod__(self, divisor): return NotImplemented def __bytes__(self): raise NotImplementedError def __int__(self): raise NotImplementedError def __eq__(self, other): return False def __ne__(self, other): return False def __repr__(self): return "Undefined" def __hash__(self): return hash("Undefined") + 1 def __pow__(self, n): return NotImplemented
@lru_cache def __ff_cache(n): return FF(n) def _symbolic_check(func): @wraps(func) def method(self, other): if type(self) is not type(other): if type(other) is int: other = self.__class__(__ff_cache(self.n)(other), self.n) else: other = self.__class__(other, self.n) else: if self.n != other.n: raise ValueError return func(self, other) return method
[docs]@public class SymbolicMod(Mod): """A symbolic element x of ℤₙ (implemented using sympy).""" x: Expr n: int __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self, x: Expr, n: int): self.x = x self.n = n @_symbolic_check def __add__(self, other) -> "SymbolicMod": return self.__class__((self.x + other.x), self.n) @_symbolic_check def __radd__(self, other) -> "SymbolicMod": return self + other @_symbolic_check def __sub__(self, other) -> "SymbolicMod": return self.__class__((self.x - other.x), self.n) @_symbolic_check def __rsub__(self, other) -> "SymbolicMod": return -self + other def __neg__(self) -> "SymbolicMod": return self.__class__(-self.x, self.n)
[docs] def inverse(self) -> "SymbolicMod": return self.__class__(self.x ** (-1), self.n)
[docs] def sqrt(self) -> "SymbolicMod": raise NotImplementedError
[docs] def is_residue(self): raise NotImplementedError
def __invert__(self) -> "SymbolicMod": return self.inverse() @_symbolic_check def __mul__(self, other) -> "SymbolicMod": return self.__class__(self.x * other.x, self.n) @_symbolic_check def __rmul__(self, other) -> "SymbolicMod": return self * other @_symbolic_check def __truediv__(self, other) -> "SymbolicMod": return self * ~other @_symbolic_check def __rtruediv__(self, other) -> "SymbolicMod": return ~self * other @_symbolic_check def __floordiv__(self, other) -> "SymbolicMod": return self * ~other @_symbolic_check def __rfloordiv__(self, other) -> "SymbolicMod": return ~self * other def __divmod__(self, divisor) -> "SymbolicMod": return NotImplemented def __bytes__(self): return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big") def __int__(self): return int(self.x) def __eq__(self, other): if type(other) is int: return self.x == other % self.n if type(other) is not SymbolicMod: return False return self.x == other.x and self.n == other.n def __ne__(self, other): return not self == other def __repr__(self): return str(self.x) def __hash__(self): return hash(("SymbolicMod", self.x, self.n)) def __pow__(self, n) -> "SymbolicMod": return self.__class__(pow(self.x, n), self.n)
_mod_classes["symbolic"] = SymbolicMod if has_gmp: @lru_cache def _is_prime(x) -> bool: return gmpy2.is_prime(x)
[docs] @public class GMPMod(Mod): """An element x of ℤₙ. Implemented by GMP.""" x: gmpy2.mpz n: gmpy2.mpz __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz], ensure: bool = True): if ensure: self.n = gmpy2.mpz(n) self.x = gmpy2.mpz(x % self.n) else: self.n = n self.x = x
[docs] def inverse(self) -> "GMPMod": if self.x == 0: raise_non_invertible() if self.x == 1: return GMPMod(gmpy2.mpz(1), self.n, ensure=False) try: res = gmpy2.invert(self.x, self.n) except ZeroDivisionError: raise_non_invertible() res = gmpy2.mpz(0) return GMPMod(res, self.n, ensure=False)
[docs] def is_residue(self) -> bool: if not _is_prime(self.n): raise NotImplementedError if self.x == 0: return True if self.n == 2: return self.x in (0, 1) return gmpy2.legendre(self.x, self.n) == 1
[docs] def sqrt(self) -> "GMPMod": if not _is_prime(self.n): raise NotImplementedError if self.x == 0: return GMPMod(gmpy2.mpz(0), self.n, ensure=False) if not self.is_residue(): raise_non_residue() if self.n % 4 == 3: return self ** int((self.n + 1) // 4) q = self.n - 1 s = 0 while q % 2 == 0: q //= 2 s += 1 z = gmpy2.mpz(2) while GMPMod(z, self.n, ensure=False).is_residue(): z += 1 m = s c = GMPMod(z, self.n, ensure=False) ** int(q) t = self ** int(q) r_exp = (q + 1) // 2 r = self ** int(r_exp) while t != 1: i = 1 while not (t ** (2 ** i)) == 1: i += 1 two_exp = m - (i + 1) b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp) m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False)) c = b ** 2 t *= c r *= b return r
@_check def __add__(self, other) -> "GMPMod": return GMPMod((self.x + other.x) % self.n, self.n, ensure=False) @_check def __sub__(self, other) -> "GMPMod": return GMPMod((self.x - other.x) % self.n, self.n, ensure=False) def __neg__(self) -> "GMPMod": return GMPMod(self.n - self.x, self.n, ensure=False) @_check def __mul__(self, other) -> "GMPMod": return GMPMod((self.x * other.x) % self.n, self.n, ensure=False) @_check def __divmod__(self, divisor) -> Tuple["GMPMod", "GMPMod"]: q, r = gmpy2.f_divmod(self.x, divisor.x) return GMPMod(q, self.n, ensure=False), GMPMod(r, self.n, ensure=False) def __bytes__(self): return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big") def __int__(self): return int(self.x) def __eq__(self, other): if type(other) is int: return self.x == (gmpy2.mpz(other) % self.n) if type(other) is not GMPMod: return False return self.x == other.x and self.n == other.n def __ne__(self, other): return not self == other def __repr__(self): return str(int(self.x)) def __hash__(self): return hash(("GMPMod", self.x, self.n)) def __pow__(self, n) -> "GMPMod": if type(n) not in (int, gmpy2.mpz): raise TypeError if n == 0: return GMPMod(gmpy2.mpz(1), self.n, ensure=False) if n < 0: return self.inverse() ** (-n) if n == 1: return GMPMod(self.x, self.n, ensure=False) return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n, ensure=False)
_mod_classes["gmp"] = GMPMod