# Source code for pyecsca.ec.mult

from abc import ABC, abstractmethod
from copy import copy
from typing import Mapping, Tuple, Optional, MutableMapping, ClassVar, Set, Type

from public import public

from .context import Action
from .formula import (Formula, AdditionFormula, DoublingFormula, DifferentialAdditionFormula,
                      ScalingFormula, LadderFormula, NegationFormula)
from .naf import naf, wnaf
from .params import DomainParameters
from .point import Point

@public
class ScalarMultiplicationAction(Action):
    """A scalar multiplication of a point on a curve by a scalar."""

    point: Point  # the point being multiplied
    scalar: int  # the scalar the point is multiplied by

    def __init__(self, point: Point, scalar: int):
        super().__init__()
        self.point = point
        self.scalar = scalar

    def __repr__(self):
        return f"{self.__class__.__name__}({self.point}, {self.scalar})"

class ScalarMultiplier(ABC):
    """
    A scalar multiplication algorithm.

    :param short_circuit: Whether the use of formulas will be guarded by short-circuit on inputs
                          of the point at infinity.
    :param formulas: Formulas this instance will use.
    """
    requires: ClassVar[Set[Type[Formula]]]
    optionals: ClassVar[Set[Type[Formula]]]
    short_circuit: bool
    formulas: Mapping[str, Formula]
    _group: DomainParameters
    _point: Point
    _initialized: bool = False

    def __init__(self, short_circuit=True, **formulas: Optional[Formula]):
        # All non-None formulas must share exactly one coordinate model.
        if len(set(formula.coordinate_model for formula in formulas.values() if
                   formula is not None)) != 1:
            raise ValueError
        self.short_circuit = short_circuit
        # Formulas passed as None (e.g. an omitted optional scl) are dropped.
        self.formulas = {k: v for k, v in formulas.items() if v is not None}

    def _add(self, one: Point, other: Point) -> Point:
        """Add two points with the "add" formula, short-circuiting on the neutral point."""
        if "add" not in self.formulas:
            raise NotImplementedError
        if self.short_circuit:
            if one == self._group.neutral:
                return copy(other)
            if other == self._group.neutral:
                return copy(one)
        return self.formulas["add"](one, other, **self._group.curve.parameters)[0]

    def _dbl(self, point: Point) -> Point:
        """Double a point with the "dbl" formula, short-circuiting on the neutral point."""
        if "dbl" not in self.formulas:
            raise NotImplementedError
        if self.short_circuit:
            if point == self._group.neutral:
                return copy(point)
        return self.formulas["dbl"](point, **self._group.curve.parameters)[0]

    def _scl(self, point: Point) -> Point:
        """Apply the "scl" formula to a point."""
        if "scl" not in self.formulas:
            raise NotImplementedError
        return self.formulas["scl"](point, **self._group.curve.parameters)[0]

    def _ladd(self, start: Point, to_dbl: Point, to_add: Point) -> Tuple[Point, ...]:
        """
        Perform one ladder step with the "ladd" formula.

        Returns the pair ``(2 * to_dbl, to_dbl + to_add)``. In the ladder multipliers ``start`` is
        the base point, which is the difference of the two other inputs.
        """
        if "ladd" not in self.formulas:
            raise NotImplementedError
        if self.short_circuit:
            if to_dbl == self._group.neutral:
                # 2*O = O and O + to_add = to_add.
                return copy(to_dbl), copy(to_add)
            if to_add == self._group.neutral:
                # to_dbl + O = to_dbl; only the doubling remains.
                return self._dbl(to_dbl), copy(to_dbl)
        return self.formulas["ladd"](start, to_dbl, to_add, **self._group.curve.parameters)

    def _dadd(self, start: Point, one: Point, other: Point) -> Point:
        """
        Differentially add two points with the "dadd" formula.

        ``start`` is passed to the formula as the known difference of ``one`` and ``other``.
        """
        if "dadd" not in self.formulas:
            raise NotImplementedError
        if self.short_circuit:
            if one == self._group.neutral:
                return copy(other)
            if other == self._group.neutral:
                return copy(one)
        return self.formulas["dadd"](start, one, other, **self._group.curve.parameters)[0]

    def _neg(self, point: Point) -> Point:
        """Negate a point with the "neg" formula."""
        if "neg" not in self.formulas:
            raise NotImplementedError
        return self.formulas["neg"](point, **self._group.curve.parameters)[0]

    def init(self, group: DomainParameters, point: Point):
        """Initialize the scalar multiplier with a group and a point."""
        # All formulas share one coordinate model (checked in __init__), so any of them will do.
        coord_model = set(self.formulas.values()).pop().coordinate_model
        if group.curve.coordinate_model != coord_model or point.coordinate_model != coord_model:
            raise ValueError
        self._group = group
        self._point = point
        self._initialized = True

    @abstractmethod
    def multiply(self, scalar: int) -> Point:
        """Multiply the point with the scalar."""
        ...

@public
class LTRMultiplier(ScalarMultiplier):
    """
    Classic double and add scalar multiplication algorithm, that scans the scalar left-to-right (msb to lsb)

    The always parameter determines whether the double and add always method is used.
    """
    requires = {AdditionFormula, DoublingFormula}
    optionals = {ScalingFormula}
    always: bool
    complete: bool

    def __init__(self, add: AdditionFormula, dbl: DoublingFormula,
                 scl: ScalingFormula = None, always: bool = False, complete: bool = True,
                 short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
        self.always = always
        self.complete = complete

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        With ``complete`` the loop runs over as many bits as the group order has, starting from
        the neutral point. Otherwise it starts from the point itself and processes only the bits
        below the scalar's top bit.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            # The addend is always the base point P. The accumulator must be doubled and then
            # have P (not 2P) added for every set bit.
            q = self._point
            if self.complete:
                r = copy(self._group.neutral)
                top = self._group.order.bit_length() - 1
            else:
                # The top bit of the scalar is consumed by starting the accumulator at P.
                r = copy(self._point)
                top = scalar.bit_length() - 2
            for i in range(top, -1, -1):
                r = self._dbl(r)
                if scalar & (1 << i) != 0:
                    r = self._add(r, q)
                elif self.always:
                    # Dummy addition whose result is discarded, so the operation sequence does
                    # not depend on the bit value.
                    self._add(r, q)
            if "scl" in self.formulas:
                r = self._scl(r)
            return r

@public
class RTLMultiplier(ScalarMultiplier):
    """
    Classic double and add scalar multiplication algorithm, that scans the scalar right-to-left (lsb to msb)

    The always parameter determines whether the double and add always method is used.
    """
    requires = {AdditionFormula, DoublingFormula}
    optionals = {ScalingFormula}
    always: bool

    def __init__(self, add: AdditionFormula, dbl: DoublingFormula,
                 scl: ScalingFormula = None, always: bool = False, short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
        self.always = always

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            q = self._point  # holds 2^j * P while bit j is processed
            r = copy(self._group.neutral)
            k = scalar  # consume a local copy instead of mutating the argument
            while k > 0:
                if k & 1 != 0:
                    r = self._add(r, q)
                elif self.always:
                    # Dummy addition whose result is discarded (double-and-add-always).
                    self._add(r, q)
                q = self._dbl(q)
                k >>= 1
            if "scl" in self.formulas:
                r = self._scl(r)
            return r

class CoronMultiplier(ScalarMultiplier):
    """
    Coron's double and add resistant against SPA, from:

    Resistance against Differential Power Analysis for Elliptic Curve Cryptosystems
    (Jean-Sébastien Coron, CHES 1999)
    """
    requires = {AdditionFormula, DoublingFormula}
    optionals = {ScalingFormula}

    def __init__(self, add: AdditionFormula, dbl: DoublingFormula, scl: ScalingFormula = None,
                 short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            q = self._point
            # The top bit of the scalar is consumed by starting from P.
            p0 = copy(q)
            for i in range(scalar.bit_length() - 2, -1, -1):
                p0 = self._dbl(p0)
                # The addition is performed for every bit; its result is kept only for set bits.
                p1 = self._add(p0, q)
                if scalar & (1 << i) != 0:
                    p0 = p1
            if "scl" in self.formulas:
                p0 = self._scl(p0)
            return p0

@public
class LadderMultiplier(ScalarMultiplier):
    """
    Montgomery ladder multiplier, using a three input, two output ladder formula.
    """
    requires = {LadderFormula}
    optionals = {DoublingFormula, ScalingFormula}
    complete: bool

    def __init__(self, ladd: LadderFormula, dbl: DoublingFormula = None, scl: ScalingFormula = None,
                 complete: bool = True, short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, ladd=ladd, dbl=dbl, scl=scl)
        self.complete = complete
        # The non-complete variant computes 2P up front, which needs a doubling formula.
        # NOTE(review): with short_circuit, _ladd's neutral-input path also calls _dbl, so the
        # complete variant may need dbl as well; confirm before relying on dbl=None.
        if not complete and dbl is None:
            raise ValueError

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        The ladder keeps ``p1 - p0 == P`` throughout; the base point is passed to every ladder
        step as that difference.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            q = self._point
            if self.complete:
                p0 = copy(self._group.neutral)
                p1 = self._point
                top = self._group.order.bit_length() - 1
            else:
                # The top bit of the scalar is consumed by starting at (P, 2P).
                p0 = copy(q)
                p1 = self._dbl(q)
                top = scalar.bit_length() - 2
            for i in range(top, -1, -1):
                if scalar & (1 << i) == 0:
                    # (p0, p1) <- (2*p0, p0 + p1)
                    p0, p1 = self._ladd(q, p0, p1)
                else:
                    # (p1, p0) <- (2*p1, p1 + p0)
                    p1, p0 = self._ladd(q, p1, p0)
            if "scl" in self.formulas:
                p0 = self._scl(p0)
            return p0

@public
class SimpleLadderMultiplier(ScalarMultiplier):
    """
    Montgomery ladder multiplier, using addition and doubling formulas.
    """
    requires = {AdditionFormula, DoublingFormula}
    optionals = {ScalingFormula}
    complete: bool

    def __init__(self, add: AdditionFormula, dbl: DoublingFormula, scl: ScalingFormula = None,
                 complete: bool = True, short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
        self.complete = complete

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        Every bit costs one addition and one doubling; the bit only selects which register
        is doubled.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            if self.complete:
                top = self._group.order.bit_length() - 1
            else:
                top = scalar.bit_length() - 1
            p0 = copy(self._group.neutral)
            p1 = copy(self._point)
            for i in range(top, -1, -1):
                if scalar & (1 << i) == 0:
                    p1 = self._add(p0, p1)
                    p0 = self._dbl(p0)
                else:
                    p0 = self._add(p0, p1)
                    p1 = self._dbl(p1)
            if "scl" in self.formulas:
                p0 = self._scl(p0)
            return p0

@public
class DifferentialLadderMultiplier(ScalarMultiplier):
    """
    Montgomery ladder multiplier, using differential addition and doubling formulas.
    """
    requires = {DifferentialAdditionFormula, DoublingFormula}
    optionals = {ScalingFormula}
    complete: bool

    def __init__(self, dadd: DifferentialAdditionFormula, dbl: DoublingFormula,
                 scl: ScalingFormula = None, complete: bool = True, short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, dadd=dadd, dbl=dbl, scl=scl)
        self.complete = complete

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        The ladder keeps ``p1 - p0 == P``, which is the difference handed to each differential
        addition.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            if self.complete:
                top = self._group.order.bit_length() - 1
            else:
                top = scalar.bit_length() - 1
            q = self._point
            p0 = copy(self._group.neutral)
            p1 = copy(q)
            for i in range(top, -1, -1):
                if scalar & (1 << i) == 0:
                    p1 = self._dadd(q, p0, p1)
                    p0 = self._dbl(p0)
                else:
                    p0 = self._dadd(q, p0, p1)
                    p1 = self._dbl(p1)
            if "scl" in self.formulas:
                p0 = self._scl(p0)
            return p0

@public
class BinaryNAFMultiplier(ScalarMultiplier):
    """
    Binary NAF (Non Adjacent Form) multiplier, left-to-right.
    """
    requires = {AdditionFormula, DoublingFormula, NegationFormula}
    optionals = {ScalingFormula}
    _point_neg: Point  # -P, computed once in init()

    def __init__(self, add: AdditionFormula, dbl: DoublingFormula,
                 neg: NegationFormula, scl: ScalingFormula = None, short_circuit: bool = True):
        super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, neg=neg, scl=scl)

    def init(self, group: DomainParameters, point: Point):
        """Initialize with a group and a point, and precompute the negated point."""
        super().init(group, point)
        self._point_neg = self._neg(point)

    def multiply(self, scalar: int) -> Point:
        """
        Multiply the initialized point by ``scalar``.

        Digits of the NAF are processed in the order ``naf`` yields them: double, then add P for
        a digit 1 or -P for a digit -1.

        :raises ValueError: If :meth:`init` has not been called.
        """
        if not self._initialized:
            raise ValueError("ScalarMultiplier not initialized.")
        with ScalarMultiplicationAction(self._point, scalar):
            if scalar == 0:
                return copy(self._group.neutral)
            bnaf = naf(scalar)
            q = copy(self._group.neutral)
            for val in bnaf:
                q = self._dbl(q)
                if val == 1:
                    q = self._add(q, self._point)
                elif val == -1:
                    q = self._add(q, self._point_neg)
            if "scl" in self.formulas:
                q = self._scl(q)
            return q

[docs]@public
class WindowNAFMultiplier(ScalarMultiplier):
"""
Window NAF (Non Adjacent Form) multiplier, left-to-right.
"""
_points: MutableMapping[int, Point]
_points_neg: MutableMapping[int, Point]
precompute_negation: bool = False
width: int

neg: NegationFormula, width: int, scl: ScalingFormula = None,
precompute_negation: bool = False, short_circuit: bool = True):
self.width = width
self.precompute_negation = precompute_negation

[docs]    def init(self, group: DomainParameters, point: Point):
super().init(group, point)
self._points = {}
self._points_neg = {}
current_point = point
double_point = self._dbl(point)
for i in range(1, (self.width + 1) // 2 + 1):
self._points[2 ** i - 1] = current_point
if self.precompute_negation:
self._points_neg[2 ** i - 1] = self._neg(current_point)

[docs]    def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
with ScalarMultiplicationAction(self._point, scalar):
if scalar == 0:
return copy(self._group.neutral)
naf = wnaf(scalar, self.width)
q = copy(self._group.neutral)
for val in naf:
q = self._dbl(q)
if val > 0: