RPA-based reverse-engineering

This notebook showcases the RPA-based reverse-engineering technique for scalar multipliers.

[ ]:
from collections import Counter
from math import sqrt
import numpy as np
import xarray as xr
import holoviews as hv
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from functools import partial, lru_cache
from scipy.stats import bernoulli
from concurrent.futures import ProcessPoolExecutor, as_completed

from IPython.display import HTML, display
from tqdm.auto import tqdm, trange
import tabulate
from anytree import LevelOrderGroupIter, RenderTree

from pyecsca.ec.model import ShortWeierstrassModel
from pyecsca.ec.coordinates import AffineCoordinateModel
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.params import DomainParameters, get_params
from pyecsca.ec.formula import FormulaAction
from pyecsca.ec.point import Point
from pyecsca.ec.mod import Mod
from pyecsca.ec.mult import *
from pyecsca.misc.utils import silent, TaskExecutor
from pyecsca.sca.trace.sampling import downsample_average, downsample_max
from pyecsca.sca.trace.process import normalize, rolling_mean, absolute
from pyecsca.sca.trace.combine import average, subtract
from pyecsca.sca.trace.test import welch_ttest
from pyecsca.sca.attack.leakage_model import HammingWeight, NormalNoice
from pyecsca.ec.context import DefaultContext, local
from pyecsca.sca.re.rpa import MultipleContext, rpa_distinguish, RPA
from pyecsca.sca.trace import Trace
from pyecsca.sca.trace.plot import plot_trace, plot_traces

from eval import (eval_tree_symmetric, eval_tree_asymmetric,
                    success_rate_symmetric, success_rate_asymmetric,
                    query_rate_symmetric, query_rate_asymmetric,
                    precise_rate_symmetric, precise_rate_asymmetric,
                    amount_rate_symmetric, amount_rate_asymmetric,
                    success_rate_vs_majority_symmetric, success_rate_vs_majority_asymmetric,
                    success_rate_vs_query_rate_symmetric, load, store)
[ ]:
# Select the ipympl matplotlib backend for the figures created further below.
%matplotlib ipympl
# Select bokeh as the holoviews rendering backend.
hv.extension("bokeh")
[ ]:
# Curve model and coordinate systems shared by every cell below:
# `coordsaff` is the affine model, `coords` the projective one used by the multipliers.
model = ShortWeierstrassModel()
coordsaff = AffineCoordinateModel(model)
coords = model.coordinates["projective"]
add = coords.formulas["add-2007-bl"]  # The formulas are irrelevant for this method
dbl = coords.formulas["dbl-2007-bl"]
neg = coords.formulas["neg"]

# A 64-bit prime order curve for testing things out
# (field prime p, coefficients a and b, generator (gx, gy), order n, cofactor h).
p = 0xc50de883f0e7b167
a = Mod(0x4833d7aa73fa6694, p)
b = Mod(0xa6c44a61c5323f6a, p)
gx = Mod(0x5fd1f7d38d4f2333, p)
gy = Mod(0x21f43957d7e20ceb, p)
n = 0xc50de885003b80eb
h = 1

# A (0, y) RPA point on the above curve, in affine coords.
P0_aff = Point(coordsaff, x=Mod(0, p), y=Mod(0x1742befa24cd8a0d, p))

# Neutral point (0 : 1 : 0) and the generator, both in projective coordinates.
infty = Point(coords, X=Mod(0, p), Y=Mod(1, p), Z=Mod(0, p))
g = Point(coords, X=gx, Y=gy, Z=Mod(1, p))

# `params` is the small test curve used by the exploration and the trace simulation below.
curve = EllipticCurve(model, coords, p, infty, dict(a=a,b=b))
params = DomainParameters(curve, g, n, h)

# And P-256 for eval
# (this is what the build_tree cells below pass to RPA.build_tree).
p256 = get_params("secg", "secp256r1", "projective")

Exploration

First select a bunch of multipliers. We will be trying to distinguish among these.

[ ]:
# Candidate scalar multipliers that the reverse-engineering tries to tell apart.
# NOTE: the order of this list matters. Later cells address entries by index
# (simulate_mult_id defaults to 0, and the error-rate evaluation uses
# multipliers[0] and multipliers[7]), so append new entries at the end.
multipliers = [
    LTRMultiplier(add, dbl, None, False, AccumulationOrder.PeqPR, True, True),
    LTRMultiplier(add, dbl, None, True, AccumulationOrder.PeqPR, True, True),
    RTLMultiplier(add, dbl, None, False, AccumulationOrder.PeqPR, True),
    RTLMultiplier(add, dbl, None, True, AccumulationOrder.PeqPR, False),
    SimpleLadderMultiplier(add, dbl, None, True, True),
    BinaryNAFMultiplier(add, dbl, neg, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BinaryNAFMultiplier(add, dbl, neg, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    WindowNAFMultiplier(add, dbl, neg, 3, None, AccumulationOrder.PeqPR, True, True),
    WindowNAFMultiplier(add, dbl, neg, 4, None, AccumulationOrder.PeqPR, True, True),
    WindowNAFMultiplier(add, dbl, neg, 5, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 3, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 4, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 5, None, AccumulationOrder.PeqPR, True, True),
    SlidingWindowMultiplier(add, dbl, 3, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 4, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 5, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 3, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 4, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 5, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 3, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 4, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 5, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 8, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 16, None, AccumulationOrder.PeqPR, True),
    FullPrecompMultiplier(add, dbl, None, True, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True, True),
    FullPrecompMultiplier(add, dbl, None, False, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True, True),
    BGMWMultiplier(add, dbl, 2, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 3, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 4, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 5, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 2, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 3, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 4, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 5, None, AccumulationOrder.PeqPR, True)
]
# Number of candidates in the set.
print(len(multipliers))

Then select a scalar and simulate the computation using all of the multipliers, track the multiples each of them computes, and display them in a table.

[ ]:
# Scalar used throughout the exploration and the trace simulation below.
# Its leading bits are 1101101 (= 109), so a left-to-right double-and-add
# computes the multiple 108 = 2 * 54 on the way; the `k = 108` example in the
# next cells relies on that, so change `k` as well if you change the scalar.
scalar = 123456789123456789

# Other scalars that are interesting to try (uncomment exactly one):
# scalar = 0b1000000000000000000000000000000000000000000000000
# scalar = 0b1111111111111111111111111111111111111111111111111
# scalar = 0b1010101010101010101010101010101010101010101010101
# scalar = 0b1111111111111111111111110000000000000000000000000
# scalar = 8750920244948492046

# multiples is a mapping from a multiple (integer) to a set of scalar multipliers that compute said multiple when doing [scalar]P
multiples = {}

table = [["Multiplier", "multiples"]]

for mult in multipliers:
    # MultipleContext records, for every point computed, which multiple of the input it is.
    with local(MultipleContext()) as ctx:
        mult.init(params, g)
        mult.multiply(scalar)
    computed = list(ctx.points.values())
    for m in computed:
        multiples.setdefault(m, set()).add(mult)
    table.append([str(mult), str(computed)])

display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))

Pick a multiple k that is computed by some multiplier for the scalar, invert it mod n, and compute [k^-1]P0 to obtain a point P0_target such that [k]P0_target = P0, where P0 has a zero coordinate.

[ ]:
# The multiple to target; it has to be one that some multiplier computes for `scalar`.
k = 108
# Inverse of k modulo the group order n.
kinv = Mod(k, n).inverse()
# [k^-1]P0, converted from affine to the projective coordinates used by the multipliers.
P0_target = curve.affine_multiply(P0_aff, int(kinv)).to_model(coords, curve)

print("Original P0", P0_aff)
print("P0_target  ", P0_target.to_affine())
# Sanity check: multiplying the target by k should give back the original P0.
print("Verify P0  ", curve.affine_multiply(P0_target.to_affine(), k))

Now go over the multipliers with P0_target and the original scalar as input. Then look whether a zero coordinate point was computed. Also look at whether the multiple “k” was computed. These two should be the same.

[ ]:
def _bold_if_true(flag):
    """Wrap a truthy flag in <b> tags for the HTML table; leave a falsy one untouched."""
    return f"<b>{flag}</b>" if flag else flag


table = [["Multiplier", "zero present", "multiple computed"]]

for mult in multipliers:
    # Run the multiplier on the crafted point while recording the computed points.
    with local(MultipleContext()) as ctx:
        mult.init(params, P0_target)
        res = mult.multiply(scalar)
    # Was any computed point one with a zero X or Y coordinate?
    zero_seen = any(P.X == 0 or P.Y == 0 for P in ctx.points.keys())
    # Was the k-th multiple of the input among the computed points?
    multiple_seen = k in ctx.points.values()
    table.append([str(mult), _bold_if_true(zero_seen), _bold_if_true(multiple_seen)])

rendered = tabulate.tabulate(
    table,
    tablefmt="unsafehtml",
    headers="firstrow",
    colalign=("left", "center", "center"),
)
display(HTML(rendered))

Now let's look at the relation of multiples to multipliers.

[ ]:
# One row per multiple: its binary form and the class names of the multipliers computing it.
table = [["Multiple", "Multipliers"]]
table.extend(
    [bin(multiple), [mult.__class__.__name__ for mult in mults]]
    for multiple, mults in multiples.items()
)

rendered = tabulate.tabulate(table, tablefmt="html", headers="firstrow")
display(HTML(rendered))

Note that all of the exploration so far was in the context of a fixed scalar. Even though for a given scalar some multipliers might be indistinguishable from the perspective of the multiples they compute, there may be other scalars that distinguish them.

Reverse-engineering

Oracle simulation

The simulated_oracle function simulates an RPA oracle that detects a zero coordinate point in the scalar multiplication. This can be used by the rpa_distinguish function to distinguish the true scalar multiplier. The oracle is parametrized with the simulated multiplier index in the table of multipliers (it simulates this “real” multiplier). Furthermore, let's also examine a noisy_oracle (with a flip probability) and a biased_oracle (with asymmetric flip probability).

Note that the oracle has two additional parameters measure_init and measure_multiply which determine whether the oracle considers the zero coordinate point in scalar multiplier initialization (precomputation) and in scalar multiplier multiplication, respectively. This is important for scalar multipliers with precomputation as there one might be able to separate the precomputation and multiplication stages and obtain oracle answers on both separately.

[ ]:
def simulated_oracle(scalar, affine_point, simulate_mult_id=0, measure_init=True, measure_multiply=True, randomize=False):
    """Noise-free RPA oracle simulated for ``multipliers[simulate_mult_id]``.

    Runs the multiplier's ``init`` and ``multiply`` on ``affine_point`` (converted
    to the coordinate model of ``params.curve``) under a ``MultipleContext`` and
    checks whether any point recorded as a parent in each phase has a zero X or Y
    coordinate.

    :param scalar: Scalar passed to ``multiply``.
    :param affine_point: Input point in affine coordinates.
    :param simulate_mult_id: Index into ``multipliers`` of the simulated multiplier.
    :param measure_init: Whether a zero seen during ``init`` counts.
    :param measure_multiply: Whether a zero seen during ``multiply`` counts.
    :param randomize: Forwarded as ``randomized`` to the coordinate conversion.
    :return: Truthy if a zero coordinate was seen in a measured phase.
    """
    def has_zero_coordinate(P):
        return P.X == 0 or P.Y == 0

    target_mult = multipliers[simulate_mult_id]
    target_curve = params.curve
    point = affine_point.to_model(target_curve.coordinate_model, target_curve, randomized=randomize)

    # Phase 1: initialization (precomputation).
    with local(MultipleContext()) as ctx:
        target_mult.init(params, point)
    init_points = set(ctx.parents.keys())
    init_parents = set(sum((ctx.parents[pt] for pt in init_points), []))
    init_zero = any(has_zero_coordinate(P) for P in init_parents)

    # Phase 2: the multiplication itself, continuing in the same context so that
    # the points introduced by this phase can be told apart from the init ones.
    with local(ctx) as ctx:
        target_mult.multiply(scalar)
    multiply_points = set(ctx.parents.keys()) - init_points
    multiply_parents = set(sum((ctx.parents[pt] for pt in multiply_points), []))
    multiply_zero = any(has_zero_coordinate(P) for P in multiply_parents)

    return (init_zero and measure_init) or (multiply_zero and measure_multiply)

def noisy_oracle(oracle, flip_proba=0):
    """Wrap ``oracle`` so that each answer is flipped with probability ``flip_proba``.

    :param oracle: Callable returning a boolean answer.
    :param flip_proba: Probability of inverting the answer on each call.
    :return: A callable taking the same arguments as ``oracle`` and returning a ``bool``.
    """
    def noisy(*args, **kwargs):
        answer = oracle(*args, **kwargs)
        # A fresh Bernoulli draw per query decides whether this answer is inverted.
        flip = bernoulli(flip_proba).rvs()
        flipped = answer ^ flip
        return bool(flipped)

    return noisy

def biased_oracle(oracle, flip_0=0, flip_1=0):
    """Wrap ``oracle`` with an answer-dependent flip probability.

    :param oracle: Callable returning a boolean answer.
    :param flip_0: Probability of flipping a falsy answer.
    :param flip_1: Probability of flipping a truthy answer.
    :return: A callable taking the same arguments as ``oracle`` and returning a ``bool``.
    """
    def biased(*args, **kwargs):
        answer = oracle(*args, **kwargs)
        # The flip probability depends on what the true answer was.
        if answer:
            flip = bernoulli(flip_1).rvs()
        else:
            flip = bernoulli(flip_0).rvs()
        return bool(answer ^ flip)

    return biased

We can see how the RPA-RE method distinguishes a given multiplier:

[ ]:
# Run the RPA-RE distinguisher against the noise-free simulated oracle.
# With its default simulate_mult_id=0 the oracle answers as multipliers[0] would.
res = rpa_distinguish(params, multipliers, simulated_oracle)

Let’s see if the result is correct. You can replace the simulated_oracle above with noisy_oracle(simulated_oracle, flip_proba=0.2) or with biased_oracle(simulated_oracle, flip_0=0.2, flip_1=0.1) to see how the process and result change with noise.

[ ]:
# multipliers[0] is the multiplier simulated_oracle impersonates by default,
# so a correct run keeps it among the returned candidates.
print(multipliers[0] in res)

We can also have a look at the distinguishing tree that the method builds for this set of multipliers.

[ ]:
# Build the distinguishing tree for the whole multiplier set on P-256.
# NOTE(review): the global name `re` would shadow the stdlib `re` module if that
# were ever imported in this notebook; the next cell reads this variable.
re = RPA(set(multipliers))
# silent() wraps the build, presumably to suppress its progress output — TODO confirm.
with silent():
    re.build_tree(p256, tries=10)
print(re.tree.describe())

We can also look at the rough tree structure.

[ ]:
# Uses the tree built in the previous cell (`re` must already be defined).
print(re.tree.render_basic())

What about (symmetric) noise?

Now we can examine how the method performs in the presence of noise and with various majority vote parameters. Note that the code below spawns several processes (num_cores) and saturates their CPU fully, so set this to something appropriate.

[ ]:
def build_tree(cfgs):
    """Build an RPA distinguishing tree on P-256 for the given multiplier configurations.

    :param cfgs: Iterable of multipliers; it is turned into a set before use.
    :return: The ``tree`` attribute of the ``RPA`` object after building.
    """
    with silent():
        rpa = RPA(set(cfgs))
        rpa.build_tree(p256, tries=10)
        tree = rpa.tree
        return tree

# Long-running: evaluates freshly built trees under symmetric oracle noise.
# num_cores=30 worker processes are spawned; lower it to fit your machine.
correct_rate, precise_rate, amount_rate, query_rate = eval_tree_symmetric(set(multipliers), build_tree, num_trees=100, num_tries=100, num_cores=30)

We can plot several heatmaps:

- One for the average number of queries to the oracle.
- One for the success rate of the reverse-engineering.
- One for the precision of the reverse-engineering.

[ ]:
# Each helper returns a figure that is written to a PDF next to the notebook.
# 100 / len(multipliers) is presumably the random-guess baseline in percent — TODO confirm.
success_rate_symmetric(correct_rate, 100 / len(multipliers)).savefig("rpa_re_success_rate_symmetric.pdf", bbox_inches="tight")
query_rate_symmetric(query_rate).savefig("rpa_re_query_rate_symmetric.pdf", bbox_inches="tight")
precise_rate_symmetric(precise_rate).savefig("rpa_re_precise_rate_symmetric.pdf", bbox_inches="tight")
amount_rate_symmetric(amount_rate).savefig("rpa_re_amount_rate_symmetric.pdf", bbox_inches="tight")

Another way to look at these metrics is a scatter plot.

[ ]:
# Success rate against query rate, and success rate against the majority-vote parameter.
success_rate_vs_query_rate_symmetric(query_rate, correct_rate).savefig("rpa_re_scatter_symmetric.pdf", bbox_inches="tight")
success_rate_vs_majority_symmetric(correct_rate).savefig("rpa_re_plot_symmetric.pdf", bbox_inches="tight")

And save the results for later.

[ ]:
# Persist the four symmetric-noise result arrays (the .nc suffix suggests netCDF — TODO confirm).
store("rpa_re_symmetric.nc", correct_rate, precise_rate, amount_rate, query_rate)
[ ]:
# Reload previously stored results; this overwrites the in-memory arrays and
# lets a later session skip the expensive evaluation cell above.
correct_rate, precise_rate, amount_rate, query_rate = load("rpa_re_symmetric.nc")

What about (asymmetric) noise?

The oracle may be not only noisy but also biased; this computation evaluates that case. Beware: for the same parameters this is about 6x slower because of the extra dimension (two error probabilities instead of one).

[ ]:
def build_tree(cfgs):
    """Build an RPA distinguishing tree on P-256 for the given multiplier configurations.

    Same definition as the ``build_tree`` of the symmetric-noise section, repeated
    here so the asymmetric evaluation cell can be run on its own.

    :param cfgs: Iterable of multipliers; it is turned into a set before use.
    :return: The ``tree`` attribute of the ``RPA`` object after building.
    """
    with silent():
        rpa = RPA(set(cfgs))
        rpa.build_tree(p256, tries=10)
        tree = rpa.tree
        return tree

# Long-running: same evaluation as the symmetric one, but for a biased oracle.
# num_cores=30 worker processes are spawned; lower it to fit your machine.
correct_rate_b, precise_rate_b, amount_rate_b, query_rate_b = eval_tree_asymmetric(set(multipliers), build_tree, num_trees=100, num_tries=100, num_cores=30)
[ ]:
# Each helper returns a figure that is written to a PDF next to the notebook.
# 100 / len(multipliers) is presumably the random-guess baseline in percent — TODO confirm.
success_rate_asymmetric(correct_rate_b, 100 / len(multipliers)).savefig("rpa_re_success_rate_asymmetric.pdf", bbox_inches="tight")
query_rate_asymmetric(query_rate_b).savefig("rpa_re_query_rate_asymmetric.pdf", bbox_inches="tight")
precise_rate_asymmetric(precise_rate_b).savefig("rpa_re_precise_rate_asymmetric.pdf", bbox_inches="tight")
amount_rate_asymmetric(amount_rate_b).savefig("rpa_re_amount_rate_asymmetric.pdf", bbox_inches="tight")
success_rate_vs_majority_asymmetric(correct_rate_b).savefig("rpa_re_plot_asymmetric.pdf", bbox_inches="tight")

And save the results for later.

[ ]:
# Persist the four asymmetric-noise result arrays (the .nc suffix suggests netCDF — TODO confirm).
store("rpa_re_asymmetric.nc", correct_rate_b, precise_rate_b, amount_rate_b, query_rate_b)
[ ]:
# Reload previously stored results; this overwrites the in-memory arrays and
# lets a later session skip the expensive evaluation cell above.
correct_rate_b, precise_rate_b, amount_rate_b, query_rate_b = load("rpa_re_asymmetric.nc")

Method simulation

The simulate_trace function simulates a Hamming weight leakage trace of a given multiplier computing a scalar multiple. This is used by the simulated_rpa_trace function that does the RPA attack on simulated traces and returns the differential trace. This is in turn used to build the simulated_rpa_oracle which can be used by the rpa_distinguish function to perform RPA-RE and distinguish the true scalar multiplier. The oracle is parametrized with the simulated multiplier index in the table of multipliers (it simulates this “real” multiplier).

[ ]:
def simulate_trace(mult, scalar, point):
    """Simulate a Hamming-weight leakage trace of ``mult`` computing ``[scalar]point``.

    The multiplier is initialized and run on ``params`` under a ``DefaultContext``;
    afterwards the recorded actions are walked and every intermediate value of
    every ``FormulaAction`` contributes one sample.

    :param mult: The scalar multiplier to run.
    :param scalar: The scalar to multiply by.
    :param point: The input point.
    :return: A ``Trace`` holding the leakage samples in walk order.
    """
    with local(DefaultContext()) as ctx:
        mult.init(params, point)
        mult.multiply(scalar)

    leakage = HammingWeight()
    samples = []

    def record(action):
        # Only formula executions leak; other actions are skipped.
        if not isinstance(action, FormulaAction):
            return
        samples.extend(leakage(intermediate.value) for intermediate in action.op_results)

    ctx.actions.walk(record)
    return Trace(np.array(samples))

def simulated_rpa_trace(mult, scalar, affine_point, noise, num_target=10, num_random=10):
    """Simulate the RPA differential trace for ``mult`` and ``affine_point``.

    Two groups of normalized, noised traces are simulated: one on random curve
    points and one on ``affine_point``. Each input is converted to the curve's
    coordinate model with ``randomized=True``. The result is the average of the
    random group minus the average of the target group.

    :param mult: The scalar multiplier to simulate.
    :param scalar: The scalar to multiply by.
    :param affine_point: The (affine) point under test.
    :param noise: Callable applied to each normalized trace.
    :param num_target: Number of traces simulated on ``affine_point``.
    :param num_random: Number of traces simulated on random points.
    :return: The differential trace.
    """
    sim_curve = params.curve

    def noisy_trace(affine):
        projective = affine.to_model(sim_curve.coordinate_model, sim_curve, randomized=True)
        return noise(normalize(simulate_trace(mult, scalar, projective)))

    # The random group is simulated before the target group.
    random_traces = []
    for _ in range(num_random):
        random_traces.append(noisy_trace(sim_curve.affine_random()))

    target_traces = []
    for _ in range(num_target):
        target_traces.append(noisy_trace(affine_point))

    return subtract(average(*random_traces), average(*target_traces))

def simulated_rpa_oracle(scalar, affine_point, simulate_mult_id = 0, variance=1, threshold=4, num_target=10, num_random=10):
    """RPA oracle built on simulated traces of ``multipliers[simulate_mult_id]``.

    Simulates the RPA differential trace for ``affine_point``, normalizes it and
    reports whether it contains at least one peak of the given height.

    :param scalar: Scalar used in the simulated multiplication.
    :param affine_point: The (affine) point under test.
    :param simulate_mult_id: Index into ``multipliers`` of the simulated multiplier.
    :param variance: Second argument handed to ``NormalNoice``.
        NOTE(review): the error analysis further below passes a standard deviation
        in this position, so the name may be misleading — confirm against NormalNoice.
    :param threshold: Minimum peak height in the normalized differential trace
        (``height`` argument of ``find_peaks``). Defaults to the previously
        hard-coded 4.
    :param num_target: Number of traces simulated on ``affine_point``.
    :param num_random: Number of traces simulated on random points.
    :return: ``True`` if at least one peak was found.
    """
    real_mult = multipliers[simulate_mult_id]
    noise = NormalNoice(0, variance)
    diff_trace = normalize(
        simulated_rpa_trace(real_mult, scalar, affine_point, noise,
                            num_target=num_target, num_random=num_random)
    )
    # Only the peak positions matter here; the peak properties are not needed.
    peaks, _ = find_peaks(diff_trace.samples, height=threshold)
    return len(peaks) != 0
[ ]:
# Reverse-engineer every multiplier in turn using the trace-based oracle and
# tabulate whether the true one survives and how many candidates remain.
table = [["True multiplier", "Reversed", "Correct", "Remaining"]]
with silent():
    for i, mult in tqdm(enumerate(multipliers), total=len(multipliers)):
        oracle = partial(simulated_rpa_oracle, simulate_mult_id = i)
        res = rpa_distinguish(params, multipliers, oracle)
        row = [mult, res, mult in res, len(res)]
        table.append(row)

rendered = tabulate.tabulate(table, tablefmt="html", headers="firstrow")
display(HTML(rendered))

Note that the oracle function above has several parameters, like noise standard deviation, amount of traces simulated, and peak finding height threshold. Below we analyze how these parameters influence the resulting error probabilities.

[ ]:
def eval(threshold, sdev, tries, num_traces):
    """Estimate the detection rates of the peak-based oracle for one parameter set.

    For each try, one differential trace is simulated for ``multipliers[0]`` and
    one for ``multipliers[7]``, both on the affine form of ``P0_target``. A peak
    in the former is counted as a true positive and a peak in the latter as a
    false positive.

    NOTE(review): this assumes multipliers[0] computes the targeted multiple for
    ``scalar`` and multipliers[7] does not — confirm if either is changed.
    The function name shadows the ``eval`` builtin.

    :param threshold: Peak height passed to ``find_peaks``.
    :param sdev: Second argument of ``NormalNoice``.
    :param tries: Number of repetitions.
    :param num_traces: Number of traces in each of the random and target groups.
    :return: Tuple of rates ``(true_pos, true_neg, false_pos, false_neg)``.
    """
    noise = NormalNoice(0, sdev)
    aff = P0_target.to_affine()

    def peak_detected(mult):
        diff = normalize(simulated_rpa_trace(mult, scalar, aff, noise, num_random=num_traces, num_target=num_traces))
        peaks, _ = find_peaks(diff.samples, height=threshold)
        return len(peaks) > 0

    true_pos = 0
    false_pos = 0
    for _ in range(tries):
        # Simulation order (expected-peak case first) is kept within each try.
        true_pos += peak_detected(multipliers[0])
        false_pos += peak_detected(multipliers[7])

    true_neg = tries - false_pos
    false_neg = tries - true_pos
    return true_pos / tries, true_neg / tries, false_pos / tries, false_neg / tries

# Parameter grid: peak threshold x noise standard deviation x traces per group.
threshold_range = [4]
sdev_range = list(range(0, 11))
traces_range = list(range(10, 101, 10))


def _empty_error_grid(name):
    """Zero-filled (threshold, sdev, traces) array labelled with the parameter values."""
    shape = (len(threshold_range), len(sdev_range), len(traces_range))
    return xr.DataArray(
        np.zeros(shape),
        dims=("threshold", "sdev", "traces"),
        coords={"threshold": threshold_range, "sdev": sdev_range, "traces": traces_range},
        name=name,
    )


# e0 receives the false-positive rate, e1 the false-negative rate.
e0 = _empty_error_grid("e0")
e1 = _empty_error_grid("e1")
tries = 200
grid = [
    (threshold, sdev, num_traces)
    for threshold in threshold_range
    for sdev in sdev_range
    for num_traces in traces_range
]
with TaskExecutor(max_workers=30) as pool:
    # Submit one task per grid point, keyed by its parameters.
    for threshold, sdev, num_traces in grid:
        pool.submit_task((threshold, sdev, num_traces),
                         eval, threshold, sdev, tries, num_traces)
    # Collect results as they finish and store them at their grid position.
    for (threshold, sdev, num_traces), future in tqdm(pool.as_completed(), total=len(pool.tasks), smoothing=0):
        true_pos, true_neg, false_pos, false_neg = future.result()
        e0.loc[threshold, sdev, num_traces] = false_pos
        e1.loc[threshold, sdev, num_traces] = false_neg
[ ]:
def _annotate_error_cells(ax, errors, threshold):
    """Write every cell's error probability as text on top of the heatmap in ``ax``.

    Keeping the loops inside a function stops their variables from leaking into
    the notebook namespace (the original loops overwrote the global ``k``).
    """
    for sdev in sdev_range:
        for traces in traces_range:
            # .item() yields a plain float, so formatting does not depend on
            # xarray's own __format__ support.
            val = errors.sel(threshold=threshold, sdev=sdev, traces=traces).item()
            sval = f"{val:.2f}"
            color = "white" if val < 0.5 else "black"
            if sval == "0.00":
                color = "grey"
            ax.text(traces, sdev, sval.lstrip("0"), ha="center", va="center", color=color)


# Side-by-side heatmaps of the two error probabilities, e0 on the left and e1 on
# the right, sharing one colorbar.
fig, axs = plt.subplots(ncols=2, figsize=(10, 4), sharey="row")
for threshold in threshold_range:
    res0 = e0.sel(threshold=threshold).plot(ax=axs[0], vmin=0, vmax=1, cmap="plasma", add_colorbar=False)
    _annotate_error_cells(axs[0], e0, threshold)
    axs[0].set_title("$e_0$")
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    axs[0].set_ylabel(r"noise $\sigma$")
    axs[0].set_xlabel("traces per group")

    e1.sel(threshold=threshold).plot(ax=axs[1], vmin=0, vmax=1, cmap="plasma", add_colorbar=False)
    _annotate_error_cells(axs[1], e1, threshold)
    axs[1].set_title("$e_1$")
    axs[1].set_ylabel(r"noise $\sigma$")
    axs[1].set_xlabel("traces per group")

    # Leave room on the right for a colorbar shared by both panels; both use the
    # same colormap and value range, so the e0 mappable serves for both.
    fig.tight_layout(h_pad=1.5, rect=(0, 0, 0.9, 1))
    cbar_ax = fig.add_axes((0.92, 0.145, 0.02, 0.77))
    cbar = fig.colorbar(res0, cax=cbar_ax, label="error probability")
    cbar.ax.yaxis.set_label_coords(2.8, 0.5)
    cbar.ax.set_ylabel("error probability", rotation=-90, va="bottom")
[ ]:
# Save the two-panel error heatmap figure created in the previous cell.
fig.savefig("rpa_re_errors.pdf", bbox_inches="tight")
[ ]:
# One small panel per (noise sdev, traces per group) combination showing the
# combined error e0 + e1 along the remaining dimension (the peak threshold).
# The noise dimension of e0/e1 is called "sdev"; the former `var_range` / `var=`
# names do not exist in this notebook and raised a NameError.
fig, axs = plt.subplots(ncols=len(traces_range), nrows=len(sdev_range), figsize=(10, 12), sharex="col", sharey="row")
for i, traces in enumerate(traces_range):
    for j, sdev in enumerate(sdev_range):
        total_error = e0.sel(traces=traces, sdev=sdev) + e1.sel(traces=traces, sdev=sdev)
        total_error.plot(ax=axs[j, i], label="e0 + e1")
        axs[j, i].set_ylabel("error")
        axs[j, i].set_ylim((0, 1))
fig.tight_layout()
[ ]:
# Save the per-combination error grid created in the previous cell.
fig.savefig("rpa_re_errors_all.pdf", bbox_inches="tight")
[ ]: