Source code for pyecsca.ec.formula.graph
from .base import Formula
from .code import CodeFormula
from ..op import CodeOp, OpType
import matplotlib.pyplot as plt
import networkx as nx
from ast import parse, Expression
from typing import Dict, List, Tuple, Set, Optional, MutableMapping, Any
from copy import deepcopy
from abc import ABC, abstractmethod
[docs]
class Node(ABC):
def __init__(self):
self.incoming_nodes = []
self.outgoing_nodes = []
self.output_node = False
self.input_node = False
@property
@abstractmethod
def label(self) -> str:
pass
@property
@abstractmethod
def result(self) -> str:
pass
@property
def is_sub(self) -> bool:
return False
@property
def is_mul(self) -> bool:
return False
@property
def is_add(self) -> bool:
return False
@property
def is_id(self) -> bool:
return False
@property
def is_sqr(self) -> bool:
return False
@property
def is_pow(self) -> bool:
return False
@property
def is_inv(self) -> bool:
return False
@property
def is_div(self) -> bool:
return False
@property
def is_neg(self) -> bool:
return False
@abstractmethod
def __repr__(self) -> str:
pass
[docs]
def reconnect_outgoing_nodes(self, destination):
destination.output_node = self.output_node
for out in self.outgoing_nodes:
out.incoming_nodes = [
n if n != self else destination for n in out.incoming_nodes
]
destination.outgoing_nodes.append(out)
[docs]
class ConstantNode(Node):
color = "#b41f44"
def __init__(self, i: int):
super().__init__()
self.value = i
@property
def label(self) -> str:
return str(self.value)
@property
def result(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return f"Node({self.value})"
[docs]
class CodeOpNode(Node):
color = "#1f78b4"
def __init__(self, op: CodeOp):
super().__init__()
self.op = op
if self.op.operator not in [
OpType.Sub,
OpType.Add,
OpType.Id,
OpType.Sqr,
OpType.Mult,
OpType.Div,
OpType.Pow,
OpType.Inv,
OpType.Neg,
]:
raise ValueError
[docs]
@classmethod
def from_str(cls, result: str, left, operator, right):
opstr = f"{result} = {left if left is not None else ''}{operator.op_str}{right if right is not None else ''}"
return cls(CodeOp(parse(opstr.replace("^", "**"))))
@property
def label(self) -> str:
return f"{self.op.result}:{self.op.operator.op_str}"
@property
def result(self) -> str:
return str(self.op.result)
@property
def optype(self) -> OpType:
return self.op.operator
@property
def is_sub(self) -> bool:
return self.optype == OpType.Sub
@property
def is_mul(self) -> bool:
return self.optype == OpType.Mult
@property
def is_add(self) -> bool:
return self.optype == OpType.Add
@property
def is_id(self) -> bool:
return self.optype == OpType.Id
@property
def is_sqr(self) -> bool:
return self.optype == OpType.Sqr
@property
def is_pow(self) -> bool:
return self.optype == OpType.Pow
@property
def is_inv(self) -> bool:
return self.optype == OpType.Inv
@property
def is_div(self) -> bool:
return self.optype == OpType.Div
@property
def is_neg(self) -> bool:
return self.optype == OpType.Neg
def __repr__(self) -> str:
return f"Node({self.op})"
[docs]
class InputNode(Node):
color = "#b41f44"
def __init__(self, input: str):
super().__init__()
self.input = input
self.input_node = True
@property
def label(self) -> str:
return self.input
@property
def result(self) -> str:
return self.input
def __repr__(self) -> str:
return f"Node({self.input})"
[docs]
def formula_input_variables(formula: Formula) -> List[str]:
return (
list(formula.inputs)
+ formula.parameters
+ formula.coordinate_model.curve_model.parameter_names
)
[docs]
class FormulaGraph:
coordinate_model: Any
name: str
shortname: str
parameters: List[str]
assumptions: List[Expression]
nodes: List[Node]
input_nodes: MutableMapping[str, InputNode]
output_names: Set[str]
roots: List[Node]
def __init__(self, formula: Formula, rename=True):
self.name = formula.name
self.shortname = formula.shortname
self.parameters = formula.parameters
self.assumptions = formula.assumptions
self.coordinate_model = formula.coordinate_model
self.output_names = formula.outputs
self.input_nodes = {v: InputNode(v) for v in formula_input_variables(formula)}
self.roots = list(self.input_nodes.values())
self.nodes = self.roots.copy()
discovered_nodes: Dict[str, Node] = self.input_nodes.copy() # type: ignore
constants: Dict[int, Node] = {}
for op in formula.code:
code_node = CodeOpNode(op)
for side in (op.left, op.right):
if side is None:
continue
if isinstance(side, int):
parent_node = constants.get(side, ConstantNode(side))
self.nodes.append(parent_node)
self.roots.append(parent_node)
else:
parent_node = discovered_nodes[side]
parent_node.outgoing_nodes.append(code_node)
code_node.incoming_nodes.append(parent_node)
self.nodes.append(code_node)
discovered_nodes[op.result] = code_node
# flag output nodes
for output_name in self.output_names:
discovered_nodes[output_name].output_node = True
# go through the nodes and make sure that every node is root or has parents
for node in self.nodes:
if not node.incoming_nodes and node not in self.roots:
self.roots.append(node)
if rename:
self.reindex()
[docs]
def to_formula(self, name=None) -> CodeFormula:
code = list(
map(
lambda x: deepcopy(x.op), # type: ignore
filter(lambda n: n not in self.roots, self.nodes),
)
)
parameters = list(self.parameters)
assumptions = [deepcopy(assumption) for assumption in self.assumptions]
for klass in CodeFormula.__subclasses__():
if klass.shortname == self.shortname:
return klass(
self.name if name is None else self.name + "-" + name,
code,
self.coordinate_model,
parameters,
assumptions,
)
raise ValueError(f"Bad formula type: {self.shortname}")
[docs]
def networkx_graph(self) -> nx.DiGraph:
graph = nx.DiGraph()
vertices = list(range(len(self.nodes)))
graph.add_nodes_from(vertices)
stack = self.roots.copy()
while stack:
node = stack.pop()
for out in node.outgoing_nodes:
stack.append(out)
graph.add_edge(self.node_index(node), self.node_index(out))
return graph
[docs]
def levels(self) -> List[List[Node]]:
stack = self.roots.copy()
levels = [(n, 0) for n in stack]
level_counter = 1
while stack:
tmp_stack = []
while stack:
node = stack.pop()
levels.append((node, level_counter))
for out in node.outgoing_nodes:
tmp_stack.append(out)
stack = tmp_stack
level_counter += 1
# separate into lists
level_lists: List[List[Node]] = [[] for _ in range(level_counter)]
discovered = []
for node, l in reversed(levels):
if node not in discovered:
level_lists[l].append(node)
discovered.append(node)
return level_lists
[docs]
def output_nodes(self) -> List[Node]:
return list(filter(lambda x: x.output_node, self.nodes))
[docs]
def planar_positions(self) -> Dict[int, Tuple[float, float]]:
positions = {}
for i, level in enumerate(self.levels()):
for j, node in enumerate(level):
positions[self.node_index(node)] = (
0.1 * float(i**2) + 3 * float(j),
-6 * float(i),
)
return positions
[docs]
def draw(self, filename: Optional[str] = None, figsize: Tuple[int, int] = (12, 12)):
gnx = self.networkx_graph()
pos = nx.rescale_layout_dict(self.planar_positions())
plt.figure(figsize=figsize)
colors = [self.nodes[n].color for n in gnx.nodes]
labels = {n: self.nodes[n].label for n in gnx.nodes}
nx.draw(gnx, pos, node_color=colors, node_size=2000, labels=labels)
if filename:
plt.savefig(filename)
plt.close()
else:
plt.show()
[docs]
def find_all_paths(self) -> List[List[Node]]:
gnx = self.networkx_graph()
index_paths = []
for u in self.roots:
iu = self.node_index(u)
for v in self.output_nodes():
iv = self.node_index(v)
index_paths.extend(nx.all_simple_paths(gnx, iu, iv))
paths = []
for p in index_paths:
paths.append([self.nodes[v] for v in p])
return paths
[docs]
def remove_node(self, node):
self.nodes.remove(node)
if node in self.roots:
self.roots.remove(node)
for in_node in node.incoming_nodes:
in_node.outgoing_nodes = list(
filter(lambda x: x != node, in_node.outgoing_nodes)
)
for out_node in node.outgoing_nodes:
out_node.incoming_nodes = list(
filter(lambda x: x != node, out_node.incoming_nodes)
)
[docs]
def add_node(self, node):
if isinstance(node, ConstantNode):
self.roots.append(node)
self.nodes.append(node)
[docs]
def reindex(self):
results: Dict[str, str] = {}
counter = 0
for node in self.nodes:
if not isinstance(node, CodeOpNode):
continue
op = node.op
result, left, operator, right = (
op.result,
op.left,
op.operator.op_str,
op.right,
)
if left in results:
left = results[left]
if right in results:
right = results[right]
if not node.output_node:
new_result = f"iv{counter}"
counter += 1
else:
new_result = result
opstr = f"{new_result} = {left if left is not None else ''}{operator}{right if right is not None else ''}"
results[result] = new_result
node.op = CodeOp(parse(opstr.replace("^", "**")))
[docs]
def rename_ivs(formula: Formula) -> CodeFormula:
graph = FormulaGraph(formula)
return graph.to_formula()