EPA-based reverse engineering

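This notebook showcases the EPA-based (exceptional procedure attack) reverse-engineering technique for determining the coordinate system and the addition/doubling formulas used by an implementation.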

[ ]:
import io
import tabulate
import secrets
from tqdm.notebook import tqdm, trange
from functools import partial
from itertools import product
from IPython.display import HTML, display
from sympy.ntheory import factorint
from sympy.ntheory.modular import crt
from anytree import Node

from pyecsca.ec.model import ShortWeierstrassModel
from pyecsca.ec.coordinates import AffineCoordinateModel
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.params import DomainParameters, load_params_ectester
from pyecsca.ec.mod import Mod, miller_rabin, gcd
from pyecsca.ec.point import Point, InfinityPoint
from pyecsca.ec.error import NonInvertibleError
from pyecsca.ec.mult import LTRMultiplier, AccumulationOrder
from pyecsca.ec.context import local
from pyecsca.ec.error import UnsatisfiedAssumptionError
from pyecsca.ec.formula.base import *
from pyecsca.ec.formula.fake import FakeAdditionFormula, FakeDoublingFormula, FakePoint
from pyecsca.ec.formula.unroll import unroll_formula_expr
from pyecsca.sca.re.tree import Map, Tree
from pyecsca.sca.re.rpa import MultipleContext
from pyecsca.misc.utils import log, warn

A few curves with a composite modulus “p”. The comment above each curve gives $\varphi(p)/p$ for its modulus.

[ ]:
# Test curves with a composite modulus "p", given as ECTester-style CSV strings
# (presumably hex fields p,a,b,gx,gy,n,h -- confirm against load_params_ectester).
# The comment above each entry gives phi(p)/p for its modulus p.
# curves[4] is used for the exploration below and curves[3] for the reverse-engineering demo.
curves = [
    # phi(p)/p =
    # 0.8286039438617044
    "dfb2da5e1b7bd7bb098cb975966293ed,d9c4372806e8131b18d0036e8f832749,bcae41be8e808acdc04bb769dead91e2,0e2f983c0f852bef381f567448f0d488,1599bba77ed1cb8dec41555098958492,10fcabd48fffc71e6300d44acc236157d,0001",
    # 0.633325621696952
    "cca6f6718a06cad7094962b2a35f067d,67aa9464eb493fbb7b509d29381b9a9d,cafc69aa517b654a6a608644996cc8d1,4c092beb06cc00751eec39675f680cb8,82800378a47dd6f26ff6a50f69e4c4e6,18a22d20b6de3ff6bdc49329c21163f77,0001",
    # 0.8508806646440022
    "b3755d654bad73114e4191e9f5f36af9,9fe4f88cfbacba71f4b767ace8580c74,4610526fdcfbd69aed453ac2ee6efeef,542d8e0bbafe40dae36f25cbc350add6,68a65f5a5dc304bfd0d8fe963c250206,118a34a1ea295e78b3a3c960b6f680ee1,0001",

    # 0.9845701775215489 (has a = 0 for a subcurve)
    "de1406450d5d7e91d81907956019c0c1,5fbe46b9f1086011e18f5d823c6110ce,a859c36ceeadb39c7a978f7b1b0563ee,1cba89c3f099c29401ecf3fe1806e822,345d7282a0114070be91f95fe3db1faa,0fcd24d24e57a40547814b6766b9ea735,0001",
    # 0.980582605794486  (does not have a = 0 for any subcurve)
    "cab298b495875d4ab2c8ee3eb03016a7,a7c4f56f286d9eae44424c85c8b2fcb9,5e8c439d939273fdcb5503acbda7d3f8,816c9f865c831223067a88046bf00d75,972ce29ed18d5d73f15cef31187659be,0b0e97ff8c3e72e7ae75eb3f5e759fe03,0001",

    # 0.9547100843537808
    "f1a8a441b6d0e9600e33ccf16f9b8291,b3f55185bd6a63528e3d560c6a7b729a,c2fee2d65350e870eda0ac5e2b96b810,29b3e793822fad03a3c2ebca3cf62c12,b937d5389b6c5d0212d0f53e26843092,1153442389f9e1da8dd130bc93c6ef42b,0001",
    # 0.7214369438844093
    "a4dfa4b6b065c40b45980474266c9fbb,2c3486e725755b44a7c119473c5b9c64,329078ab070fc18edc6ce53047e00a39,9f6209be91b66943d9e8e0b61c4aae4e,05271c9ac628351b9add9e1be69a9fa4,0cef2e52ffe86ebc6dd323912ac7d9a87,0001",
    # 0.4716485170445178 (160-bit)
    "db49063db56b7783fa01dd62077c5a88dfa28009,aee572fdd4790bcd4729bb3b612b52a573df46e9,dab9e68366a593ca1df9cb2f20890a578729d6ef,d4a3aaf43bdb25be7c308b69ae54f639e6e32e8c,7b6c82140bb427ac6e2a64507f60775949b2c8ce,34a9fbe62b272f930b2e5027780a32300feb0dd8f,0001"
]
[ ]:
# Short Weierstrass curve model, shared by all cells below, and its affine coordinate system.
# `affine` is used to build affine points when lifting results back to the top curve via the CRT.
model = ShortWeierstrassModel()
affine = AffineCoordinateModel(model)

Exploration

Now, let’s define some functions for picking random scalars mod $n$ and random points on the curve. There are several ways to do so. Some guarantee that the scalars will be “trivial” w.r.t. the curve order $n$ (i.e. $\gcd(k, n) = 1$). Others go further and guarantee that all the multiples computed by a given scalar multiplication algorithm are trivial w.r.t. the curve order.

[ ]:
def random_scalar(n):
    """Return a uniformly random integer from the range [0, n), using a CSPRNG."""
    value = secrets.randbelow(n)
    return value

def random_scalar_trivial(n):
    """Rejection-sample a random integer in [0, n) until it is coprime to n, then return it."""
    while True:
        candidate = secrets.randbelow(n)
        if gcd(candidate, n) == 1:
            return candidate

def random_scalar_fully_trivial(n, mult):
    """
    Draw random scalars coprime to n until one is found for which every multiple that `mult`
    computes (as recorded by a MultipleContext) is also coprime to n, and return it.

    `mult` must already be initialized, since only `mult.multiply(scalar)` is called here.
    """
    while True:
        candidate = random_scalar_trivial(n)
        with local(MultipleContext()) as ctx:
            mult.multiply(candidate)
        if all(gcd(multiple, n) == 1 for multiple in ctx.points.values()):
            return candidate

def fixed_point(params):
    """Return the generator of `params`, i.e. always the same input point."""
    generator = params.generator
    return generator

def random_point(splitted, top, randomized=False):
    """
    Sample a random point on the composite-"p" curve of `top`.

    A random affine point is drawn on every subcurve of `splitted` (a mapping prime factor -> params).
    Their x and y coordinates are combined with the CRT into an affine point mod p. That point is then
    converted to the coordinate model of `top`, optionally with randomized coordinates.
    """
    moduli = []
    sub_points = []
    for factor, sub_params in splitted.items():
        moduli.append(factor)
        sub_points.append(sub_params.curve.affine_random())
    prime = top.curve.prime
    lifted_x = Mod(int(crt(moduli, [int(pt.x) for pt in sub_points])[0]), prime)
    lifted_y = Mod(int(crt(moduli, [int(pt.y) for pt in sub_points])[0]), prime)
    lifted = Point(affine, x=lifted_x, y=lifted_y)
    return lifted.to_model(top.curve.coordinate_model, top.curve, randomized=randomized)

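Let’s also define three more helpers:

- a way to project points down to a subcurve;
- a way to split the curve into subcurves (one per prime factor of $p$);
- a scalar multiplication that correctly computes on the top curve. It splits the computation over the subcurves and recombines the results via the CRT.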

[ ]:
def project_down(point, subcurve):
    """Reduce every coordinate of `point` modulo the prime of `subcurve` and return the resulting point on it."""
    modulus = subcurve.prime
    reduced = {}
    for coord_name, coord_value in point.coords.items():
        reduced[coord_name] = Mod(int(coord_value), modulus)
    return Point(subcurve.coordinate_model, **reduced)

def split_params(params):
    """
    Split composite "p" params into subcurves, one per prime factor p_i of p.

    The curve parameters and the generator are reduced mod each p_i. The order of each subcurve's
    generator is then guessed among the distinct prime factors of `params.order`. A candidate q is
    assigned when affine multiplication of the generator by q raises NonInvertibleError (presumably
    because the point at infinity was reached). Each candidate is assigned to at most one subcurve.

    Raises:
        ValueError: if p is not squarefree.

    Returns:
        dict mapping each prime factor p_i (in ascending order) to its DomainParameters. If no order
        candidate matches a subcurve, a warning is emitted. That subcurve keeps the placeholder order
        passed at construction, which scalar multiplication on it cannot use.
    """
    factors = factorint(params.curve.prime)
    if set(factors.values()) != {1}:
        raise ValueError("Not squarefree")
    results = {}
    # Construct the curves
    for p_i in sorted(factors.keys()):
        parameters_i = {name: Mod(int(value), p_i) for name, value in params.curve.parameters.items()}
        curve_i = EllipticCurve(params.curve.model, params.curve.coordinate_model, p_i, params.curve.neutral, parameters_i)
        generator_i = project_down(params.generator, curve_i)
        # Order is filled in below.
        results[p_i] = DomainParameters(curve_i, generator_i, 0, 1)
    # Now map the orders to the curves. Candidates are tried in ascending order. Presumably, multiplying
    # by a prime smaller than the true order never hits infinity, so the first failure identifies it.
    remaining_orders = sorted(factorint(params.order).keys())
    for factor_i, params_i in results.items():
        for order in remaining_orders:
            try:
                params_i.curve.affine_multiply(params_i.generator.to_affine(), order)
            except NonInvertibleError:
                params_i.order = order
                # Safe despite iterating over the list: we break right after removing.
                remaining_orders.remove(order)
                break
        else:
            # Do not raise: callers such as random_point do not need the orders.
            warn(f"Could not determine the order of the subcurve mod {factor_i}.")
    return results

def split_scalarmult(splitted, top, point, scalar):
    """
    Compute scalar * point on the top (composite "p") curve by splitting over the subcurves.

    On each subcurve, the point is projected down and multiplied in affine coordinates by scalar mod the
    subcurve order. The partial results are recombined with the CRT into an affine point mod p. If any
    partial result is the point at infinity, the point at infinity of the top curve is returned.
    """
    partial_results = {}
    for factor, sub_params in splitted.items():
        projected = project_down(point, sub_params.curve)
        sub_scalar = scalar % sub_params.order
        if sub_scalar != 0:
            partial_results[factor] = sub_params.curve.affine_multiply(projected.to_affine(), sub_scalar)
        else:
            partial_results[factor] = InfinityPoint(sub_params.curve.coordinate_model)
    if any(isinstance(partial, InfinityPoint) for partial in partial_results.values()):
        # Strictly undefined when only some partial results are at infinity; treated as infinity anyway.
        return InfinityPoint(top.curve.coordinate_model)
    moduli = list(partial_results)
    x_residues = [int(partial_results[modulus].x) for modulus in moduli]
    y_residues = [int(partial_results[modulus].y) for modulus in moduli]
    prime = top.curve.prime
    lifted_x = Mod(int(crt(moduli, x_residues)[0]), prime)
    lifted_y = Mod(int(crt(moduli, y_residues)[0]), prime)
    return Point(affine, x=lifted_x, y=lifted_y)

With all of that, we can now explore the behavior of the formulas, focusing on projective coordinates for now.

[ ]:
# Explore the formulas of the projective coordinate system on curves[4].
which = "projective"
coords = model.coordinates[which]

params = load_params_ectester(io.BytesIO(curves[4].encode()), which)
curve = params.curve
p = curve.prime
g = params.generator
n = params.order

# Candidate addition and doubling formulas of the chosen coordinate system, and all (add, dbl) pairs.
adds = [formula for formula in coords.formulas.values() if formula.name.startswith("add")]
dbls = [formula for formula in coords.formulas.values() if formula.name.startswith("dbl")]
formula_pairs = list(product(adds, dbls))

# Fake formulas and multiplier. With a MultipleContext they record which multiples
# of the input point a scalar multiplication computes.
fake_add = FakeAdditionFormula(curve.coordinate_model)
fake_dbl = FakeDoublingFormula(curve.coordinate_model)
fake_mult = LTRMultiplier(fake_add, fake_dbl, None, False, AccumulationOrder.PeqPR, True, True)
fake_mult.init(params, FakePoint(curve.coordinate_model))
[ ]:
def _simulate_table_reference(scalars, points, split, params, fake_mult):
    """Compute the split-scalarmult reference result and gcd-triviality flags for every (scalar, point) input."""
    order = params.order
    results = []
    gcds = []
    fgcds = []
    for scalar, point in tqdm(zip(scalars, points), desc="Precomp", total=len(scalars)):
        try:
            result = split_scalarmult(split, params, point, scalar)
        except NonInvertibleError:
            result = None
        results.append(result)
        # The fake multiplier records which multiples of the input point get computed for this scalar.
        with local(MultipleContext()) as ctx:
            fake_mult.multiply(scalar)
        multiples = list(ctx.points.values())
        gcds.append(gcd(scalar, order) == 1)
        fgcds.append(all(gcd(multiple, order) == 1 for multiple in multiples))
    return results, gcds, fgcds


def _simulate_table_pair(pair, scalars, points, results, params):
    """Run one (add, dbl) pair on all inputs; return per-input lists (final Z invertible, all Zs invertible, correct)."""
    prime = params.curve.prime
    mult = LTRMultiplier(*pair, None, False, AccumulationOrder.PeqPR, True, True)
    inv = []
    zs = []
    correct = []
    for scalar, point, result in tqdm(zip(scalars, points, results), leave=None, total=len(scalars)):
        mult.init(params, point)
        with local(MultipleContext()) as ctx:
            res = mult.multiply(scalar)
        # ctx.points keys are the computed intermediate points; check each one's Z coordinate mod p.
        zs.append(all(gcd(int(computed.Z), prime) == 1 for computed in ctx.points.keys()))
        result_invertible_z = False
        result_correct = False
        try:
            res_aff = res.to_affine()
            result_invertible_z = True
            if res_aff == result:
                result_correct = True
        except NonInvertibleError:
            # A non-invertible final Z is an expected outcome here, recorded as False in both flags.
            pass
        inv.append(result_invertible_z)
        correct.append(result_correct)
    return inv, zs, correct


def simulate_table(scalars, points, split, params, formula_pairs, adds, dbls):
    """
    Run every (add, dbl) formula pair on the given inputs and tabulate how they behave on a composite "p" curve.

    Args:
        scalars: scalars to multiply by, paired element-wise with `points`.
        points: input points in the coordinate model of `params`.
        split: subcurve params from `split_params(params)`, used to compute reference results.
        params: the top domain parameters. Their order and prime are used for all gcd checks.
        formula_pairs: iterable of (add, dbl) formula pairs to evaluate.
        adds, dbls: lists of the addition/doubling formulas, used to index `pair_table`.

    Returns:
        (table, pair_table).
        - `table` is a header row followed by one row of counts per pair.
        - `pair_table` is an adds x dbls grid (with a header row and column). Each cell holds the number
          of inputs whose final result could be converted to affine.

    Also prints one line per pair, with one character per input: "x" if the final result could be
    converted to affine, "." otherwise.
    """
    # Build the fake multiplier from `params`, instead of relying on a notebook global set up for another params.
    fake_mult = LTRMultiplier(FakeAdditionFormula(params.curve.coordinate_model),
                              FakeDoublingFormula(params.curve.coordinate_model),
                              None, False, AccumulationOrder.PeqPR, True, True)
    fake_mult.init(params, FakePoint(params.curve.coordinate_model))
    results, gcds, fgcds = _simulate_table_reference(scalars, points, split, params, fake_mult)

    table = [["Pair", "scalars with trivial gcd", "scalars with all multiples with trivial gcds", "scalars with invertible final zs", "scalars with all multiples's zs invertible", "scalars with correct result"]]
    pair_table = [[None for _ in dbls] for _ in adds]
    for pair in tqdm(formula_pairs):
        inv, zs, correct = _simulate_table_pair(pair, scalars, points, results, params)
        pair_table[adds.index(pair[0])][dbls.index(pair[1])] = sum(inv)
        print("".join("x" if i else "." for i in inv))
        table.append([f"{pair[0].name}, {pair[1].name}", sum(gcds), sum(fgcds), sum(inv), sum(zs), sum(correct)])
    for pl, add in zip(pair_table, adds):
        pl.insert(0, add.name)
    pair_table.insert(0, [None] + [dbl.name for dbl in dbls])
    return table, pair_table
[ ]:
# Compare all formula pairs on the same scalars: once with random points, once always with the generator.
split = split_params(params)
num_inputs = 50
scalars = [random_scalar_trivial(params.order) for _ in trange(num_inputs, desc="Generate scalars")]
random_points = [random_point(split, params, randomized=False) for _ in trange(num_inputs, desc="Generate points")]
fixed_points = [fixed_point(params) for _ in trange(num_inputs, desc="Generate points")]

for input_points in (random_points, fixed_points):
    table, pair_table = simulate_table(scalars, input_points, split, params, formula_pairs, adds, dbls)
    display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))
    display(HTML(tabulate.tabulate(pair_table, tablefmt="html", headers="firstrow")))

Reverse-engineering

[ ]:
def simulate_epa_oracle(affine_params, affine_point, scalar, real_coord_name="projective", real_add_name="add-2007-bl", real_dbl_name="dbl-2007-bl"):
    """
    Simulated EPA oracle: multiply `affine_point` by `scalar` on `affine_params` with a fixed "real" implementation.

    The implementation is an LTR multiplier (assumed already known) using the coordinate system
    `real_coord_name` and the formulas `real_add_name` / `real_dbl_name`. Change those parameters to
    simulate a different target.

    Returns True if the result can be converted to affine coordinates, False if that raises NonInvertibleError.
    """
    coordinate_system = model.coordinates[real_coord_name]
    add_formula = coordinate_system.formulas[real_add_name]
    dbl_formula = coordinate_system.formulas[real_dbl_name]
    multiplier = LTRMultiplier(add_formula, dbl_formula, None, False, AccumulationOrder.PeqPR, True, True)
    target_params = affine_params.to_coords(coordinate_system)
    start_point = affine_point.to_model(coordinate_system, target_params.curve)
    multiplier.init(target_params, start_point)
    output = multiplier.multiply(scalar)
    try:
        output.to_affine()
    except NonInvertibleError:
        return False
    return True
[ ]:
def epa_precomp(affine_params, mult_factory, mult_class, model, queries=30):
    """
    Precompute a map of (cfg) -> set of indices into inputs for which the given cfg oracle will answer True,
    where inputs is a list of `queries` random (scalar, point) pairs on `affine_params`.

    Configurations:
        A cfg is a tuple of formulas from a single coordinate system of `model`. It holds one formula per
        class in mult_class.requires, in the order Addition, DifferentialAddition, Doubling, Ladder, Negation.
        Only formulas whose names start with "add" or "dbl" are considered. `mult_factory(*cfg)` must build
        the corresponding multiplier.

    Oracle answer:
        The oracle of a cfg answers True for an input if the result of the multiplication can be converted
        to affine, i.e. `to_affine()` did not raise NonInvertibleError.

    Returns the list of inputs, the mapping and all of the considered cfgs.
    Note that the mapping might be restricted over a subset of the cfgs: a cfg whose multiplication raised
    UnsatisfiedAssumptionError on some input is among the returned cfgs but not in the mapping.
    Coordinate systems to which `affine_params` cannot be converted are skipped entirely.
    """
    split = split_params(affine_params)
    # Scalars are drawn w.r.t. the order of `affine_params` itself, not a notebook-global `n`,
    # which may belong to a different curve.
    order = affine_params.order
    scalars = [random_scalar_trivial(order) for _ in trange(queries, desc="Generate scalars")]
    random_points = [random_point(split, affine_params, randomized=False) for _ in trange(queries, desc="Generate points")]
    formula_classes = [klass for klass in (AdditionFormula, DifferentialAdditionFormula, DoublingFormula, LadderFormula, NegationFormula)
                       if klass in mult_class.requires]
    results = {}
    inputs = list(zip(scalars, random_points))
    configs = set()
    for coord_name, coords in tqdm(model.coordinates.items(), desc="Precompute for coord systems"):
        try:
            params = affine_params.to_coords(coords)
        except UnsatisfiedAssumptionError:
            log(f"Skipping {coords.name}, does not fit.")
            continue
        log(f"Precomputing {coords.name}.")
        mapped_inputs = [(scalar, point.to_model(coords, params.curve)) for scalar, point in inputs]

        formula_groups = [
            [formula for formula in coords.formulas.values()
             if isinstance(formula, formula_class) and formula.name.startswith(("add", "dbl"))]
            for formula_class in formula_classes
        ]
        formula_combinations = list(product(*formula_groups))

        for formulas in tqdm(formula_combinations, desc=coord_name, leave=False):
            cfg = tuple(formulas)
            configs.add(cfg)
            mult = mult_factory(*formulas)
            result = set()
            for i, (scalar, point) in enumerate(mapped_inputs):
                mult.init(params, point)
                try:
                    res = mult.multiply(scalar)
                except UnsatisfiedAssumptionError:
                    # The formulas cannot run on this input: abort and keep the cfg out of the mapping.
                    break
                try:
                    res.to_affine()
                except NonInvertibleError:
                    continue
                result.add(i)
            else:
                # Only reached when every input was processed (no break above).
                results[cfg] = result
    return inputs, results, configs
[ ]:
def epa_distinguish_precomp(inputs, precomp, configs, affine_params, oracle):
    """
    Distinguish the coordinate system and formulas using EPA given the precomputation.

    Building and walking the tree:
        A distinguishing tree is built over `configs` from `precomp`, which maps each cfg to the set of
        input indices for which its oracle answers True. The tree is then walked from the root. At each
        inner node, `oracle(affine_params, point, scalar)` is queried on the input that the node selects,
        and the child whose response matches the answer is followed.

    Returns:
        The candidate cfgs at the reached leaf. If the root has no children, this is the list of cfgs
        in `precomp`. If the oracle's answer matches none of a node's children, a warning is emitted and
        an empty set is returned, since no candidate is consistent with the oracle.
    """
    dmap = Map.from_sets(configs, precomp)
    tree = Tree.build(configs, dmap)
    log("Built distinguishing tree.")
    log(tree.render())

    current_node = tree.root
    cfgs = list(precomp.keys())
    while current_node.children:
        best_distinguishing_index = current_node.dmap_input
        scalar, point = inputs[best_distinguishing_index]
        response = oracle(affine_params, point, scalar)
        log(f"Oracle response -> {response}")
        for cfg in cfgs:
            # Deeper nodes may hold cfgs without a precomp entry (their precomputation was aborted);
            # log None for those instead of raising KeyError.
            log(cfg, best_distinguishing_index in precomp[cfg] if cfg in precomp else None)
        response_map = {child.response: child for child in current_node.children}
        if response not in response_map:
            warn(f"Oracle response {response} does not match any candidate configuration.")
            return set()
        current_node = response_map[response]
        cfgs = current_node.cfgs
        log(cfgs)
        log()
    return cfgs

Now we can run the precomputation and then the EPA-based reverse-engineering against the simulated oracle.

[ ]:
# Reverse-engineer the simulated implementation on curves[3] using an LTR multiplier.
affine_params = load_params_ectester(io.BytesIO(curves[3].encode()), "affine")
inputs, precomp, configs = epa_precomp(
    affine_params,
    mult_factory=lambda add, dbl: LTRMultiplier(add, dbl, None, False, AccumulationOrder.PeqPR, True, True),
    mult_class=LTRMultiplier,
    model=model,
)
epa_distinguish_precomp(inputs, precomp, configs, affine_params, oracle=simulate_epa_oracle)

Miscellaneous

[ ]:
# Group the configurations into categories by the curve-parameter assumption of their coordinate system
# (named after the assumption, e.g. "a=-3"), and build a Map from each cfg to its category.
param_categories = {
    "a=-1": ["projective-1"],
    "a=-3": ["projective-3", "jacobian-3", "xyzz-3"],
    "a=0": ["jacobian-0"],
    "generic": ["jacobian", "projective", "modified", "xyzz", "xz"],
    "b=0": ["w12-0"]
}
# Built with comprehensions so that no loop variables leak into the notebook namespace.
# A top-level loop would, for example, overwrite the global `coords` from the exploration setup.
cfg_categories = {
    name: {
        cfg
        for coord_model in [model.coordinates[coord_name] for coord_name in coord_names]
        for cfg in configs
        if cfg[0].coordinate_model == coord_model and cfg[1].coordinate_model == coord_model
    }
    for name, coord_names in param_categories.items()
}
category_map = {cfg: {"category": name} for name, category_cfgs in cfg_categories.items() for cfg in category_cfgs}
dmap_categories = Map.from_io_maps(configs, category_map)
[ ]:
# Rebuild the distinguishing map from the precomputation and build a tree over all configs.
# deduplicate=True presumably merges configs with identical responses on all inputs -- confirm in Map.from_sets.
dmap = Map.from_sets(configs, precomp, deduplicate=True)
tree = Tree.build(configs, dmap)
[ ]:
# Print the textual description of the distinguishing tree.
print(tree.describe())