from typing import Dict, Iterator, List, Any
from ast import parse
from ..op import OpType, CodeOp
from .base import Formula
from .graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula
from itertools import chain, combinations
from ..point import Point
from ..mod import Mod
[docs]
def subnode_lists(graph: FormulaGraph):
    """Return every subset of the graph's non-root subtraction nodes.

    Candidate nodes are taken from ``graph.nodes`` in order, skipping anything
    in ``graph.roots`` and keeping only nodes whose ``is_sub`` is true. The
    result is the list produced by ``powerlist``: tuples of nodes, starting
    with the empty tuple.
    """
    # Membership in the roots is tested first, so ``is_sub`` is only read on
    # non-root nodes (same short-circuit order as before).
    candidates = (
        node for node in graph.nodes if node not in graph.roots and node.is_sub
    )
    return powerlist(candidates)
[docs]
def switch_sign(graph: FormulaGraph, node_combination) -> FormulaGraph:
    """Return a copy of *graph* with the operands of the selected nodes swapped.

    The original graph is not modified: node positions are recorded, the graph
    is deep-copied, and the corresponding nodes of the copy are rewritten with
    ``change_sides`` (for a subtraction this negates its result). Each
    negation is then pushed to the consumers via ``switch_sign_propagate``,
    while ``output_signs`` records which outputs end up negated. Finally
    ``sign_test`` checks the resulting sign pattern against the coordinate
    model. Any exception raised by the propagation or by ``sign_test``
    reaches the caller unchanged.

    :param graph: Formula graph to derive the variant from.
    :param node_combination: Nodes of *graph* to flip (presumably subtraction
        nodes, e.g. one entry of ``subnode_lists(graph)``).
    :return: The modified deep copy.
    """
    # Indices identify the same nodes inside the deep copy below.
    indices = [graph.node_index(node) for node in node_combination]
    graph = graph.deepcopy()
    flipped = set(graph.nodes[i] for i in indices)
    output_signs = dict.fromkeys(graph.output_names, 1)

    pending = []
    for node in flipped:
        change_sides(node)
        if node.output_node:
            output_signs[node.result] = -1
        pending += [(child, node.result) for child in node.outgoing_nodes]

    while pending:
        # pop() takes from the end; newly produced work is placed at the
        # front, so it is handled after everything already pending.
        target, variable = pending.pop()
        pending = switch_sign_propagate(target, variable, output_signs) + pending

    sign_test(output_signs, graph.coordinate_model)
    return graph
[docs]
def sign_test(output_signs: Dict[str, int], coordinate_model: Any):
    """Check that a pattern of output sign flips yields an equivalent point.

    Output names are grouped by the digits they contain (e.g. ``X3``, ``Y3``,
    ``Z3`` form group ``"3"``). For each group, a point is built in
    *coordinate_model*. The point's coordinates are the recorded signs
    (``+1``/``-1``) reduced modulo a small prime. The coordinate name is the
    output name with the group's digit suffix stripped; names whose remaining
    prefix is not purely alphabetic are skipped.

    The point is converted with ``to_affine()``. If that is not implemented,
    the model's ``"z"`` formula (or ``"scale"`` if there is no ``"z"``) is
    applied instead. The sign pattern is accepted only if every resulting
    affine coordinate equals 1.

    :param output_signs: Mapping from output variable name to ``1`` or ``-1``.
    :param coordinate_model: Coordinate model providing ``formulas`` and used
        to construct the test point.
    :raises BadSignSwitch: If some group's affine coordinates are not all 1,
        or if affine conversion is not implemented and the model has no
        ``"z"``/``"scale"`` formula to fall back on.
    """
    scale = coordinate_model.formulas.get("z", None)
    if scale is None:
        scale = coordinate_model.formulas.get("scale", None)
    # Any prime p > 2 keeps +1 and -1 distinct (presumably why a small
    # prime like 7 suffices here).
    p = 7
    out_inds = {"".join(ch for ch in out if ch.isdigit()) for out in output_signs}
    for ind in out_inds:
        point_dict = {}
        for out, sign in output_signs.items():
            if not out.endswith(ind):
                continue
            # Strip exactly the trailing index. ``out.index(ind)`` would cut
            # at the FIRST occurrence of the digits, e.g. "X11" with index
            # "1" would become "X" and overwrite the entry taken from "X1".
            out_var = out[: len(out) - len(ind)]
            if not out_var.isalpha():
                continue
            point_dict[out_var] = Mod(sign, p)
        point = Point(coordinate_model, **point_dict)
        try:
            apoint = point.to_affine()
        except NotImplementedError:
            # No direct affine conversion: fall back to the model's scaling
            # formula. Without one the switch cannot be validated, so it is
            # rejected rather than ignored.
            if scale is None:
                raise BadSignSwitch
            apoint = scale(p, point)[0]
        if set(apoint.coords.values()) != {Mod(1, p)}:
            raise BadSignSwitch
[docs]
class BadSignSwitch(Exception):
    """Raised when a proposed sign switch is rejected by the sign test."""
[docs]
def switch_sign_propagate(
    node: CodeOpNode, variable: str, output_signs: Dict[str, int]
):
    """Adjust *node* after its input *variable* has been negated.

    The node's operation is either rewritten so that its result is unchanged,
    or the result is itself negated and the negation must travel further.

    :param node: Consumer of the negated variable; may be modified in place.
    :param variable: Name of the input whose sign was flipped.
    :param output_signs: Per-output sign record. Entries for output nodes are
        multiplied by ``-1`` when their result becomes negated.
    :return: ``(child, node.result)`` pairs to process next, or an empty list
        when the negation is absorbed by this node.
    :raises ValueError: If the node's operation is not one handled here.
    """
    if node.is_add:
        if variable == node.incoming_nodes[1].result:
            # a + (-b) == a - b
            node.op = change_operator(node.op, OpType.Sub)
            return []
        # (-a) + b == b - a: move the negated operand to the right first.
        change_sides(node)
        node.op = change_operator(node.op, OpType.Sub)
        return []
    if node.is_id or node.is_neg:
        # The result itself is negated. Only output nodes have an entry in
        # output_signs; an intermediate node would otherwise raise KeyError.
        if node.output_node:
            output_signs[node.result] *= -1
        return [(child, node.result) for child in node.outgoing_nodes]
    if node.is_sqr:
        # (-a)^2 == a^2
        return []
    if node.is_sub:
        if node.incoming_nodes[0].result == variable:
            # (-a) - b == -(a + b): rewrite to addition, result flips sign.
            node.op = change_operator(node.op, OpType.Add)
            if node.output_node:
                output_signs[node.result] *= -1
            return [(child, node.result) for child in node.outgoing_nodes]
        # a - (-b) == a + b
        node.op = change_operator(node.op, OpType.Add)
        return []
    if node.is_pow:
        exponent = next(
            filter(lambda n: isinstance(n, ConstantNode), node.incoming_nodes)
        )
        if exponent.value % 2 == 0:
            # Even powers absorb the sign.
            return []
        # Odd powers keep the sign: handled like mul/inv/div below.
    # Validate before touching output_signs so a failure leaves it untouched.
    if not (node.is_mul or node.is_pow or node.is_inv or node.is_div):
        raise ValueError(f"Cannot propagate a sign change through {node.op}")
    # Multiplication, odd power, inversion and division carry the sign over.
    if node.output_node:
        output_signs[node.result] *= -1
    return [(child, node.result) for child in node.outgoing_nodes]
[docs]
def change_operator(op, new_operator):
    """Return a new ``CodeOp`` with the operands of *op* and *new_operator*.

    The op is rebuilt from source text: ``"<result> = <left><op_str><right>"``.
    A missing (``None``) operand is rendered as an empty string. Any ``^`` is
    rewritten to Python's ``**`` before the text is parsed with ``ast.parse``.
    *op* itself is not modified.
    """
    lhs = "" if op.left is None else op.left
    rhs = "" if op.right is None else op.right
    code = f"{op.result} = {lhs}{new_operator.op_str}{rhs}"
    return CodeOp(parse(code.replace("^", "**")))
[docs]
def change_sides(node):
    """Swap the two operands of *node* in place.

    ``node.op`` is replaced by a freshly parsed ``CodeOp`` whose left and
    right operands are exchanged, keeping the result name and operator.
    ``node.incoming_nodes[0]`` and ``[1]`` are swapped to match. For a
    subtraction this negates the computed value. Unary ops (one operand
    ``None``) would produce unparsable text, so this presumably is only used
    on binary ops.
    """
    op = node.op
    new_left = "" if op.right is None else op.right
    new_right = "" if op.left is None else op.left
    code = f"{op.result} = {new_left}{op.operator.op_str}{new_right}"
    node.op = CodeOp(parse(code.replace("^", "**")))
    node.incoming_nodes[0], node.incoming_nodes[1] = (
        node.incoming_nodes[1],
        node.incoming_nodes[0],
    )
[docs]
def powerlist(iterable: Iterator) -> List:
    """Return all subsets of *iterable* as tuples, ordered by size.

    The iterable is consumed once. Subsets of the same size appear in the
    order produced by :func:`itertools.combinations`. The result starts with
    the empty tuple and ends with the tuple of all items.

    >>> powerlist([1, 2])
    [(), (1,), (2,), (1, 2)]
    """
    items = list(iterable)
    subsets = []
    for size in range(len(items) + 1):
        subsets.extend(combinations(items, size))
    return subsets