# Source code for pyecsca.sca.re.rpa

"""
Provides functionality inspired by the Refined-Power Analysis attack by Goubin [RPA]_.
"""
from copy import copy, deepcopy

from public import public
from typing import MutableMapping, Optional, Callable, List, Set, cast

from sympy import FF, sympify, Poly, symbols

from .base import RE
from .tree import Tree, Map
from ...ec.coordinates import AffineCoordinateModel
from ...ec.formula import (
    FormulaAction,
    DoublingFormula,
    AdditionFormula,
    TriplingFormula,
    NegationFormula,
    DifferentialAdditionFormula,
    LadderFormula,
)
from ...ec.mod import Mod
from ...ec.mult import (
    ScalarMultiplicationAction,
    PrecomputationAction,
    ScalarMultiplier,
)
from ...ec.params import DomainParameters
from ...ec.model import ShortWeierstrassModel, MontgomeryModel
from ...ec.point import Point
from ...ec.context import Context, Action, local
from ...misc.utils import log, warn


@public
class MultipleContext(Context):
    """Context that traces the multiples of points computed."""

    base: Optional[Point]
    """The base point that all the multiples are counted from."""
    points: MutableMapping[Point, int]
    """The mapping of points to the multiples they represent (e.g., base -> 1)."""
    parents: MutableMapping[Point, List[Point]]
    """The mapping of points to their parents (the formula inputs) they were computed from."""
    formulas: MutableMapping[Point, str]
    """The mapping of points to the shortname of the formula they are a result of."""
    inside: bool
    """Whether a scalar multiplication or precomputation action is currently being traced."""

    def __init__(self):
        self.base = None
        self.points = {}
        self.parents = {}
        self.formulas = {}
        self.inside = False

    def _start_from(self, base: Point) -> None:
        """Forget all traced multiples and start counting from `base` (which is 1 * base)."""
        self.base = base
        self.points = {base: 1}
        self.parents = {base: []}
        self.formulas = {base: ""}

    def _note(self, out: Point, multiple: int, parents: List[Point], formula) -> None:
        """Record that `out` = [multiple]base was computed from `parents` using `formula`."""
        self.points[out] = multiple
        self.parents[out] = parents
        self.formulas[out] = formula.shortname

    def enter_action(self, action: Action) -> None:
        """Start tracing when a scalar multiplication or precomputation begins."""
        if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)):
            # Keep the already traced multiples only if this computation builds on top of
            # the same base point; otherwise forget them and start over from the new base.
            if self.base is None or self.base != action.point:
                self._start_from(action.point)
            self.inside = True

    def exit_action(self, action: Action) -> None:
        """Record the multiple computed by a finished formula (or stop tracing)."""
        if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)):
            self.inside = False
        if not (isinstance(action, FormulaAction) and self.inside):
            return
        action = cast(FormulaAction, action)
        formula = action.formula
        if isinstance(formula, DoublingFormula):
            inp = action.input_points[0]
            self._note(action.output_points[0], 2 * self.points[inp], [inp], formula)
        elif isinstance(formula, TriplingFormula):
            inp = action.input_points[0]
            self._note(action.output_points[0], 3 * self.points[inp], [inp], formula)
        elif isinstance(formula, AdditionFormula):
            one, other = action.input_points
            self._note(
                action.output_points[0],
                self.points[one] + self.points[other],
                [one, other],
                formula,
            )
        elif isinstance(formula, NegationFormula):
            inp = action.input_points[0]
            self._note(action.output_points[0], -self.points[inp], [inp], formula)
        elif isinstance(formula, DifferentialAdditionFormula):
            # The first input (the difference point) does not contribute to the multiple.
            _, one, other = action.input_points
            self._note(
                action.output_points[0],
                self.points[one] + self.points[other],
                [one, other],
                formula,
            )
        elif isinstance(formula, LadderFormula):
            # Ladder step outputs (dbl, add). The addition is recorded first, exactly as
            # before, so the doubling reads `self.points[one]` after that update.
            _, one, other = action.input_points
            dbl, add = action.output_points
            self._note(add, self.points[one] + self.points[other], [one, other], formula)
            self._note(dbl, 2 * self.points[one], [one], formula)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.base!r}, multiples={self.points.values()!r})"
@public
def rpa_point_0y(params: DomainParameters) -> Optional[Point]:
    """
    Construct an (affine) [RPA]_ point (0, y) for given domain parameters.

    :param params: The domain parameters to use.
    :return: The point (0, y), or `None` if the short Weierstrass parameter `b`
             is not a quadratic residue (so no such point exists).
    :raises NotImplementedError: For curve models other than short Weierstrass and Montgomery.
    """
    curve = params.curve
    if isinstance(curve.model, ShortWeierstrassModel):
        b = curve.parameters["b"]
        # With x = 0 the curve equation reduces to y^2 = b.
        if not b.is_residue():
            return None
        root = b.sqrt()  # TODO: We can take the negative as well.
        return Point(AffineCoordinateModel(curve.model), x=Mod(0, curve.prime), y=root)
    if isinstance(curve.model, MontgomeryModel):
        # (0, 0) always lies on a Montgomery curve.
        return Point(
            AffineCoordinateModel(curve.model),
            x=Mod(0, curve.prime),
            y=Mod(0, curve.prime),
        )
    raise NotImplementedError
@public
def rpa_point_x0(params: DomainParameters) -> Optional[Point]:
    """
    Construct an (affine) [RPA]_ point (x, 0) for given domain parameters.

    :param params: The domain parameters to use.
    :return: The point (x, 0), or `None` if the curve has none.
    :raises NotImplementedError: For curve models other than short Weierstrass and Montgomery.
    """
    curve = params.curve
    if isinstance(curve.model, ShortWeierstrassModel):
        # Points (x, 0) are 2-torsion points, so none exist when the full group order is odd.
        if (params.order * params.cofactor) % 2 != 0:
            return None
        field = FF(curve.prime)
        coeff_a = field(int(curve.parameters["a"]))
        coeff_b = field(int(curve.parameters["b"]))
        # The x-coordinates are the roots of x^3 + a*x + b over GF(p).
        rhs = sympify("x^3 + a * x + b", evaluate=False).subs("a", coeff_a).subs("b", coeff_b)
        roots = Poly(rhs, symbols("x"), domain=field).ground_roots()
        if not roots:
            return None
        root = next(iter(roots))
        return Point(
            AffineCoordinateModel(curve.model),
            x=Mod(int(root), curve.prime),
            y=Mod(0, curve.prime),
        )
    if isinstance(curve.model, MontgomeryModel):
        return Point(
            AffineCoordinateModel(curve.model),
            x=Mod(0, curve.prime),
            y=Mod(0, curve.prime),
        )
    raise NotImplementedError
@public
def rpa_input_point(k: Mod, rpa_point: Point, params: DomainParameters) -> Point:
    """
    Construct an (affine) input point P that will lead to an RPA point [k]P.

    :param k: The multiple at which the RPA point should appear (assumed to be a Mod
              modulo the group order, as produced by :class:`RPA`; it must be invertible).
    :param rpa_point: The target RPA point.
    :param params: The domain parameters to use.
    :return: The point P = [k^-1] rpa_point.
    """
    # [k]P == rpa_point  <=>  P == [k^-1] rpa_point.
    return params.curve.affine_multiply(rpa_point, int(k.inverse()))
@public
def rpa_distinguish(
    params: DomainParameters,
    multipliers: List[ScalarMultiplier],
    oracle: Callable[[int, Point], bool],
    bound: Optional[int] = None,
    tries: int = 10,
    majority: int = 1,
    use_init: bool = True,
    use_multiply: bool = True,
) -> Set[ScalarMultiplier]:
    """
    Distinguish the scalar multiplier used (from the possible :paramref:`~.rpa_distinguish.multipliers`) using an
    [RPA]_ :paramref:`~.rpa_distinguish.oracle`.

    :param params: The domain parameters to use.
    :param multipliers: The list of possible multipliers.
    :param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication
                   of the input by the scalar.
    :param bound: A bound on the size of the scalar to consider.
    :param tries: Number of tries to get a non-trivial tree.
    :param majority: Query the oracle up to `majority` times and take the majority vote of the results.
    :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
    :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply
                         (after initialization).
    :return: The set of possible multipliers after distinguishing (ideally just one).
    """
    distinguisher = RPA(set(multipliers))
    distinguisher.build_tree(
        params,
        tries=tries,
        bound=bound,
        use_init=use_init,
        use_multiply=use_multiply,
    )
    return distinguisher.run(oracle, majority=majority)
@public
class RPA(RE):
    """
    Distinguisher of scalar multipliers based on the Refined-Power Analysis [RPA]_.

    Call :meth:`build_tree` to build a distinguishing tree over the candidate
    multipliers (`self.configs`), then :meth:`run` to walk it with an RPA oracle.
    """

    # The domain parameters the tree was built for.
    params: Optional[DomainParameters] = None
    # The (0, y) RPA point used to construct oracle inputs.
    P0: Optional[Point] = None
    # The scalar used for each distinguishing map; indexed by a tree node's `dmap_index`.
    scalars: Optional[List[int]] = None

    @staticmethod
    def _parent_multiples(context: MultipleContext, points: Set[Point], order: int) -> Set[Mod]:
        """Collect the multiples (mod `order`, with both signs) of all formula inputs that produced `points`."""
        parents = {parent for point in points for parent in context.parents[point]}
        multiples = {Mod(context.points[parent], order) for parent in parents}
        return multiples | {-multiple for multiple in multiples}

    @classmethod
    def _used_multiples(
        cls,
        mult: ScalarMultiplier,
        init_context: MultipleContext,
        scalar: int,
        order: int,
        use_init: bool,
        use_multiply: bool,
    ) -> Set[Mod]:
        """Compute the multiples of the base that `mult` feeds into formulas when multiplying by `scalar`."""
        # Copy the context after init to not accumulate multiples by accident across scalars.
        context = deepcopy(init_context)
        init_points = set(context.parents.keys())
        # Computed before multiply runs, in case `local` does not copy the context.
        init_multiples = cls._parent_multiples(context, init_points, order)
        # Now do the multiply and repeat the above, but only for newly computed points.
        with local(context) as ctx:
            mult.multiply(scalar)
        multiply_points = set(ctx.parents.keys()) - init_points
        multiply_multiples = cls._parent_multiples(ctx, multiply_points, order)
        used: Set[Mod] = set()
        if use_init:
            used |= init_multiples
        if use_multiply:
            used |= multiply_multiples
        return used

    @staticmethod
    def _majority_vote(
        oracle: Callable[[int, Point], bool], scalar: int, point: Point, majority: int
    ) -> bool:
        """Query `oracle` until one answer has a strict majority out of `majority` votes."""
        threshold = majority // 2
        responses = []
        for _ in range(majority):
            responses.append(oracle(scalar, point))
            if responses.count(True) > threshold:
                return True
            if responses.count(False) > threshold:
                return False
        # Only reachable if the oracle returns non-boolean values; mirrors the default.
        return True

    def build_tree(
        self,
        params: DomainParameters,
        tries: int = 10,
        bound: Optional[int] = None,
        use_init: bool = True,
        use_multiply: bool = True,
    ):
        """
        Build a distinguishing tree over the candidate multipliers.

        Random scalars are drawn until the tree distinguishes all candidates precisely.
        The loop also stops once more than `tries` imprecise trees have been built.

        :param params: The domain parameters to use.
        :param tries: Number of tries to get a precise tree before giving up (with a warning).
        :param bound: Scalars are drawn below this bound (defaults to the group order).
        :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
        :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply.
        :raises ValueError: If neither `use_init` nor `use_multiply` is set, or the curve has no (0, y) RPA point.
        """
        if not (use_init or use_multiply):
            raise ValueError("Has to use either init or multiply or both.")
        P0 = rpa_point_0y(params)
        if P0 is None:
            raise ValueError("There are no RPA-points on the provided curve.")
        if not bound:
            bound = params.order

        mults = {copy(mult) for mult in self.configs}
        init_contexts = {}
        for mult in mults:
            with local(MultipleContext()) as ctx:
                mult.init(params, params.generator)
                init_contexts[mult] = ctx

        done = 0
        tree = None
        scalars = []
        while True:
            scalar = int(Mod.random(bound))
            log(f"Got scalar {scalar}")
            log([mult.__class__.__name__ for mult in mults])
            mults_to_multiples = {
                mult: self._used_multiples(
                    mult, init_contexts[mult], scalar, params.order, use_init, use_multiply
                )
                for mult in mults
            }
            dmap = Map.from_sets(set(mults), mults_to_multiples)
            if tree is None:
                tree = Tree.build(set(mults), dmap)
            else:
                tree = tree.expand(dmap)
            log("Built distinguishing tree.")
            log(tree.render())
            scalars.append(scalar)
            if tree.precise:
                break
            done += 1
            if done > tries:
                warn(f"Tried more than {tries} times. Aborting. Distinguishing may not be precise.")
                break
        self.scalars = scalars
        self.tree = tree
        self.params = params
        self.P0 = P0

    def run(
        self, oracle: Callable[[int, Point], bool], majority: int = 1
    ) -> Set[ScalarMultiplier]:
        """
        Walk the distinguishing tree by querying the RPA oracle.

        :param oracle: Returns `True` when an RPA point is encountered during scalar multiplication
                       of the input point by the scalar.
        :param majority: Positive odd number of oracle queries per decision; a majority vote is taken.
        :return: The set of multipliers consistent with the oracle responses. It is empty if a
                 response matches none of the candidates.
        :raises ValueError: If the tree was not built yet, or `majority` is not a positive odd number.
        """
        if self.tree is None or self.scalars is None or self.P0 is None or self.params is None:
            raise ValueError("Need to build tree first.")
        if (majority % 2) == 0:
            raise ValueError("Cannot use even majority.")
        if majority < 1:
            # A non-positive majority would never query the oracle and silently answer True.
            raise ValueError("Majority has to be positive.")

        current_node = self.tree.root
        mults = current_node.cfgs
        while current_node.children:
            scalar = self.scalars[current_node.dmap_index]  # type: ignore
            best_distinguishing_multiple: Mod = current_node.dmap_input  # type: ignore
            P0_inverse = rpa_input_point(best_distinguishing_multiple, self.P0, self.params)
            response = self._majority_vote(oracle, scalar, P0_inverse, majority)
            log(f"Oracle response -> {response}")
            response_map = {child.response: child for child in current_node.children}
            if response not in response_map:
                # No candidate multiplier produces this response, so none of them is consistent.
                log(f"Unexpected oracle response {response}, no multiplier matches.")
                return set()
            current_node = response_map[response]
            mults = current_node.cfgs
            log([mult.__class__.__name__ for mult in mults])
            log()
        return mults