from ast import parse
from typing import Iterator, List, Type, Optional, Set
from ..op import OpType
from .base import Formula
from .graph import FormulaGraph, Node, CodeOpNode, CodeOp, CodeFormula
[docs]
class Fliparoo:
    """
    Fliparoo is a chain of nodes N1->N2->...->Nk in FormulaGraph for k>=2 such that:
     - All Ni are * or All Ni are +/-
     - For every Ni, except for Nk, the only outgoing node is Ni+1
     - Neither of N1,...,Nk-1 is an output node

    This base class checks only the two structural conditions (see
    :meth:`verify_chain`); the operator condition is checked by the subclasses.
    The length condition k>=2 is not enforced by the constructor.
    """

    nodes: List[CodeOpNode]  # the chain N1..Nk, in order
    graph: FormulaGraph  # the graph that owns the chain nodes
    operator: Optional[OpType]  # None here; some subclasses set it

    def __init__(self, chain: List[CodeOpNode], graph: FormulaGraph):
        self.verify_chain(chain)
        self.nodes = chain
        self.graph = graph
        self.operator = None

    def verify_chain(self, nodes: List[CodeOpNode]):
        """
        Check the structural fliparoo conditions on ``nodes``.

        :raises BadFliparoo: if any node except the last one has outgoing nodes
            other than exactly its successor, or is an output node.
        """
        for current, successor in zip(nodes, nodes[1:]):
            if current.outgoing_nodes != [successor]:
                raise BadFliparoo
            if current.output_node:
                raise BadFliparoo

    @property
    def first(self):
        """The first node N1 of the chain."""
        return self.nodes[0]

    @property
    def last(self):
        """The last node Nk of the chain."""
        return self.nodes[-1]

    def __len__(self):
        return len(self.nodes)

    def __repr__(self):
        return "->".join(map(repr, self.nodes))

    def previous(self, node: CodeOpNode) -> Optional[CodeOpNode]:
        """
        Return the node preceding ``node`` in the chain, or None for the first node.

        :raises ValueError: if ``node`` is not in the chain (from ``list.index``).
        """
        i = self.nodes.index(node)
        if i == 0:
            return None
        return self.nodes[i - 1]

    def __getitem__(self, i: int):
        return self.nodes[i]

    def __iter__(self):
        return iter(self.nodes)

    def __eq__(self, other):
        # Same graph and same node sequence; the concrete subclass is not compared.
        # Non-fliparoos are delegated back to Python instead of raising AttributeError.
        if not isinstance(other, Fliparoo):
            return NotImplemented
        return self.graph == other.graph and self.nodes == other.nodes

    # unhashable at the moment
    __hash__ = None  # type: ignore

    def deepcopy(self):
        """Return a fliparoo of the same class over ``self.graph.deepcopy()``, with nodes mirrored into the copy."""
        ngraph = self.graph.deepcopy()
        nchain = [mirror_node(node, self.graph, ngraph) for node in self.nodes]
        return self.__class__(nchain, ngraph)

    def slice(self, i: int, j: int):
        """
        Return ``nodes[i:j]`` as a fliparoo of the same class over the same graph.

        The sub-chain goes through the constructor again, so it is re-validated.
        """
        return self.__class__(self.nodes[i:j], self.graph)
[docs]
class MulFliparoo(Fliparoo):
    """
    A fliparoo in which every node is a multiplication.

    Raises BadFliparoo unless the set of node operators is exactly
    ``{OpType.Mult}``; on success ``operator`` is set to ``OpType.Mult``.
    """

    def __init__(self, chain: List[CodeOpNode], graph: FormulaGraph):
        super().__init__(chain, graph)
        used_operators = {member.op.operator for member in self.nodes}
        if used_operators != {OpType.Mult}:
            raise BadFliparoo
        self.operator = OpType.Mult
[docs]
class AddSubFliparoo(Fliparoo):
    """
    A fliparoo in which every node is an addition or a subtraction.

    Raises BadFliparoo if any node uses another operator. ``operator`` keeps the
    base-class value None, because the chain may mix Add and Sub.
    """

    def __init__(self, chain: List[CodeOpNode], graph: FormulaGraph):
        super().__init__(chain, graph)
        permitted = {OpType.Add, OpType.Sub}
        used_operators = {member.op.operator for member in self.nodes}
        if not used_operators <= permitted:
            raise BadFliparoo
[docs]
class AddFliparoo(Fliparoo):
    """
    A fliparoo in which every node is an addition.

    Raises BadFliparoo unless the set of node operators is exactly
    ``{OpType.Add}``; on success ``operator`` is set to ``OpType.Add``.
    """

    def __init__(self, chain: List[CodeOpNode], graph: FormulaGraph):
        super().__init__(chain, graph)
        used_operators = {member.op.operator for member in self.nodes}
        if used_operators != {OpType.Add}:
            raise BadFliparoo
        self.operator = OpType.Add
[docs]
class BadFliparoo(Exception):
    """Raised when a chain of nodes does not satisfy the conditions of a fliparoo (type)."""
[docs]
def find_fliparoos(
    graph: FormulaGraph, fliparoo_type: Optional[Type[Fliparoo]] = None
) -> List[Fliparoo]:
    """
    Find the maximal fliparoos in ``graph``.

    For every path from ``graph.find_all_paths()`` (minus its first element,
    the input variable), take the fliparoo found by ``largest_fliparoo``.
    Duplicates are dropped. Any fliparoo that ``is_subfliparoo`` reports as
    part of a longer, already-kept one is also dropped.
    """
    candidates: List[Fliparoo] = []
    for path in graph.find_all_paths():
        # path[0] is an input variable, not an operation node
        found = largest_fliparoo(path[1:], graph, fliparoo_type)  # type: ignore
        if found and found not in candidates:
            candidates.append(found)

    # Longest first, so each candidate is only tested against longer kept ones.
    candidates.sort(key=len, reverse=True)
    maximal: List[Fliparoo] = []
    for candidate in candidates:
        if is_subfliparoo(candidate, maximal):
            continue
        maximal.append(candidate)
    return maximal
[docs]
def is_subfliparoo(fliparoo: Fliparoo, longest_fliparoos: List[Fliparoo]) -> bool:
    """
    Return True if ``fliparoo`` is a contiguous part of any fliparoo in ``longest_fliparoos``.

    Every window of length ``len(fliparoo)`` of every other fliparoo is compared
    via ``slice`` and ``==``. This includes the last window, so suffixes (and a
    fliparoo equal to one in the list) count as parts.
    """
    size = len(fliparoo)
    for other_fliparoo in longest_fliparoos:
        # Window starts run 0..len(other)-size inclusive; the "+ 1" keeps the
        # final window (a suffix) in the search.
        for start in range(len(other_fliparoo) - size + 1):
            if other_fliparoo.slice(start, start + size) == fliparoo:
                return True
    return False
[docs]
def largest_fliparoo(
    chain: List[CodeOpNode],
    graph: FormulaGraph,
    fliparoo_type: Optional[Type[Fliparoo]] = None,
) -> Optional[Fliparoo]:
    """
    Find a fliparoo in ``chain``, preferring earlier starts and then longer sub-chains.

    Sub-chains ``chain[start:stop]`` of length >= 2 are tried with ``start``
    ascending and, for each start, ``stop`` descending. The first one that
    constructs as a MulFliparoo or else an AddSubFliparoo is returned.

    If ``fliparoo_type`` is given, a sub-chain is considered only if it also
    constructs as that type. The returned object is still a MulFliparoo or an
    AddSubFliparoo. Returns None if nothing qualifies.
    """
    size = len(chain)
    for start in range(size - 1):
        for stop in range(size, start + 1, -1):
            candidate = chain[start:stop]
            if fliparoo_type:
                try:
                    fliparoo_type(candidate, graph)
                except BadFliparoo:
                    continue
            for kind in (MulFliparoo, AddSubFliparoo):
                try:
                    return kind(candidate, graph)
                except BadFliparoo:
                    pass
    return None
[docs]
class SignedNode:
    """
    A summand of a signed sum such as X1-X2+X3+X4-X5...

    Wraps a graph node together with a sign (+1 initially). Used when building
    +/- fliparoos.
    """

    node: CodeOpNode  # the wrapped graph node
    sign: int  # 1 or -1; repr() raises KeyError for any other value

    def __init__(self, node: CodeOpNode):
        self.sign = 1
        self.node = node

    def __repr__(self):
        prefix = {1: "+", -1: "-"}[self.sign]
        return prefix + repr(self.node)
[docs]
class SignedSubGraph:
    """
    A working set of SignedNode summands plus the FormulaGraph (``supergraph``) owning their nodes.

    ``nodes`` is stored as given (not copied), so add/remove mutate that list.
    """

    def __init__(self, nodes: List[SignedNode], graph: FormulaGraph):
        self.supergraph = graph
        self.nodes = nodes

    def add_node(self, node: SignedNode):
        """Append ``node`` to the summands."""
        self.nodes.append(node)

    def remove_node(self, node: SignedNode):
        """Remove ``node`` from the summands (ValueError if absent, from ``list.remove``)."""
        self.nodes.remove(node)

    def change_signs(self):
        """Negate the sign of every summand in place."""
        for summand in self.nodes:
            summand.sign = -summand.sign

    def __getitem__(self, i):
        return self.nodes[i]

    @property
    def components(self):
        """Number of summands currently in the subgraph."""
        return len(self.nodes)

    def deepcopy(self):
        """Return a copy over ``supergraph.deepcopy()`` with every summand mirrored into it."""
        copied_graph = self.supergraph.deepcopy()
        mirrored = [
            mirror_node(summand, self.supergraph, copied_graph)
            for summand in self.nodes
        ]
        return SignedSubGraph(mirrored, copied_graph)
[docs]
def mirror_node(node, graph, graph_copy):
    """
    Find the node of ``graph_copy`` corresponding to ``node`` of ``graph``.

    The lookup is by position (``graph.node_index``), so this assumes that
    ``graph_copy`` lists its nodes in the same order as ``graph``. For a
    SignedNode a new SignedNode with the same sign is returned.

    :raises NotImplementedError: if ``node`` is neither a SignedNode nor a Node.
    """
    if isinstance(node, SignedNode):
        counterpart = graph_copy.nodes[graph.node_index(node.node)]
        mirrored = SignedNode(counterpart)
        mirrored.sign = node.sign
        return mirrored
    if not isinstance(node, Node):
        raise NotImplementedError
    return graph_copy.nodes[graph.node_index(node)]
[docs]
class DummyNode(Node):
    """
    Placeholder node holding the outgoing connections of a fliparoo's last node
    while the fliparoo is being rebuilt; it has no label and no result.
    """

    def __repr__(self):
        return "Dummy node"

    @property
    def result(self):
        """Always None: a placeholder computes nothing."""
        return None

    @property
    def label(self):
        """Always None: a placeholder has no label."""
        return None
[docs]
def generate_fliparood_graphs(fliparoo: Fliparoo) -> Iterator[FormulaGraph]:
    """
    Yield graphs in which ``fliparoo`` is rebuilt from its signed inputs in every pairing order.

    Works on ``fliparoo.deepcopy()``. The last node's outgoing connections are
    parked on a DummyNode, and the signed inputs are collected. Pairs of signed
    nodes are then merged repeatedly until one remains, branching over every
    pair at each step.

    Results whose final sign is -1 are skipped. For the others, the final node
    is rewritten to assign to the original last node's result name and takes
    over the parked connections. Each resulting graph is updated and yielded.

    :raises ValueError: if a finished combination does not have exactly one component.
    """
    work = fliparoo.deepcopy()
    output_name = work.last.result
    disconnect_fliparoo_outputs(work)
    initial = extract_fliparoo_signed_inputs(work)

    # Each round merges one pair, so components - 1 rounds leave a single node.
    frontier = [initial]
    for _ in range(initial.components - 1):
        next_frontier = []
        for partial in frontier:
            next_frontier.extend(combine_all_pairs_signed_nodes(partial, work))
        frontier = next_frontier

    for complete in frontier:
        graph = complete.supergraph
        if complete.components != 1:
            raise ValueError
        root = complete.nodes[0]
        if root.sign != 1:
            continue
        final_node: CodeOpNode = root.node
        opstr = f"{output_name} = {final_node.op.left}{final_node.optype.op_str}{final_node.op.right}"
        final_node.op = CodeOp(parse(opstr))
        reconnect_fliparoo_outputs(graph, final_node)
        graph.update()
        yield graph
[docs]
def disconnect_fliparoo_outputs(fliparoo: Fliparoo):
    """
    Move the outgoing connections of the fliparoo's last node onto a new DummyNode.

    The placeholder is added to ``fliparoo.graph``, so its position can be
    restored later by ``reconnect_fliparoo_outputs``.
    """
    placeholder = DummyNode()
    owning_graph = fliparoo.graph
    owning_graph.add_node(placeholder)
    fliparoo.last.reconnect_outgoing_nodes(placeholder)
[docs]
def reconnect_fliparoo_outputs(graph: FormulaGraph, last_node: Node):
    """
    Hand the DummyNode's outgoing connections to ``last_node`` and remove the DummyNode.

    :raises ValueError: if ``graph`` contains no DummyNode, or if a DummyNode
        is still present after the first one has been removed.
    """
    dummy = next((n for n in graph.nodes if isinstance(n, DummyNode)), None)
    if dummy is None:
        # Don't let StopIteration escape: when called from a generator (like
        # generate_fliparood_graphs) it would turn into a RuntimeError (PEP 479).
        raise ValueError("No DummyNode placeholder found in the graph")
    dummy.reconnect_outgoing_nodes(last_node)
    graph.remove_node(dummy)
    if any(isinstance(n, DummyNode) for n in graph.nodes):
        raise ValueError("More than one DummyNode placeholder found in the graph")
[docs]
def combine_all_pairs_signed_nodes(
    signed_subgraph: SignedSubGraph, fliparoo: Fliparoo
) -> List[SignedSubGraph]:
    """
    For every index pair i < j, merge summands i and j in a fresh deep copy of ``signed_subgraph``.

    Returns the n*(n-1)/2 copies in (i, j) lexicographic order. The merge is
    applied to the copies via ``combine_signed_nodes``.
    """
    size = signed_subgraph.components

    def merged_copy(left_idx: int, right_idx: int) -> SignedSubGraph:
        # Work on a copy so each pair choice yields an independent subgraph.
        candidate = signed_subgraph.deepcopy()
        left, right = candidate[left_idx], candidate[right_idx]
        combine_signed_nodes(candidate, left, right, fliparoo)
        return candidate

    return [
        merged_copy(left_idx, right_idx)
        for left_idx in range(size)
        for right_idx in range(left_idx + 1, size)
    ]
[docs]
def combine_signed_nodes(
    subgraph: SignedSubGraph,
    left_signed_node: SignedNode,
    right_signed_node: SignedNode,
    fliparoo: Fliparoo,
):
    """
    Replace two summands of ``subgraph`` by one new node combining them.

    A new CodeOpNode ``left <op> right`` is created and wired after both nodes.
    It is added to ``subgraph.supergraph`` and replaces the two summands in
    ``subgraph`` as a SignedNode.

    For +/- fliparoos (AddSubFliparoo or AddFliparoo), the operator, operand
    order and sign are chosen so the new summand equals s0*left + s1*right:

    ====  ====  ==============  ====
    s0    s1    new node        sign
    ====  ====  ==============  ====
    +1    +1    left + right    +1
    +1    -1    left - right    +1
    -1    +1    right - left    +1
    -1    -1    left + right    -1
    ====  ====  ==============  ====

    Any other fliparoo is combined by multiplication, with sign +1 and input
    signs ignored.
    """
    left_node, right_node = left_signed_node.node, right_signed_node.node
    sign = 1
    operator = OpType.Mult
    # AddFliparoo is not a subclass of AddSubFliparoo, but is also a +/- fliparoo.
    if isinstance(fliparoo, (AddSubFliparoo, AddFliparoo)):
        s0, s1 = left_signed_node.sign, right_signed_node.sign
        if s0 == 1:
            operator = OpType.Add if s1 == 1 else OpType.Sub
        if s0 == -1 and s1 == 1:
            operator = OpType.Sub
            left_node, right_node = right_node, left_node
        # propagate the sign
        if s0 == -1 and s1 == -1:
            operator = OpType.Add
            sign = -1
    # The result name is built from object ids, presumably to make it unique
    # within the graph copy (NOTE(review): id(operator)/id(sign) are ids of
    # shared enum/int objects, so only the node ids vary).
    new_code_node = CodeOpNode.from_str(
        f"Fliparoo{id(left_node)}_{id(operator)}_{id(sign)}_{id(right_node)}",
        left_node.result,
        operator,
        right_node.result,
    )
    new_code_node.incoming_nodes = [left_node, right_node]
    left_node.outgoing_nodes.append(new_code_node)
    right_node.outgoing_nodes.append(new_code_node)
    subgraph.supergraph.add_node(new_code_node)

    new_signed_node = SignedNode(new_code_node)
    new_signed_node.sign = sign
    subgraph.remove_node(left_signed_node)
    subgraph.remove_node(right_signed_node)
    subgraph.add_node(new_signed_node)
[docs]
def recursive_fliparoo(formula: Formula, depth: int = 2) -> Set[Formula]:
    """
    Apply fliparoos repeatedly for ``depth`` rounds and collect every formula seen.

    Level 0 is ``{formula}``. Level n+1 holds every formula produced by
    ``generate_fliparood_formulas`` from a formula of level n. The union of
    all levels, including the original formula, is returned.
    """
    levels: List[Set[Formula]] = [{formula}]
    level = 0
    while level < depth:
        # rename ivs before the first fliparoo only
        rename = level == 0
        next_level: Set[Formula] = set()
        for source in levels[level]:
            next_level.update(generate_fliparood_formulas(source, rename))
        level += 1
        levels.append(next_level)
    return set().union(*levels)