import copy
import logging
from collections import defaultdict, deque
from typing import Any, Callable, Generator, Literal, TypeVar
from procfunc.util import pytree
from .compute_graph import ComputeGraph
from .node import Node, SubgraphCallNode
logger = logging.getLogger(__name__)
T = TypeVar("T")
[docs]
class LiteralConstant:
[docs]
def __init__(self, value: Any):
self.value = value
def __repr__(self) -> Any:
return self.value
[docs]
def traverse_breadth_first(
graph: ComputeGraph,
yield_parent: bool = False,
yield_name: bool = False,
yield_consts: bool = False,
) -> Generator[Any, None, None]:
"""
Traverse all nodes in the compute graph.
Args:
graph: The compute graph to traverse
yield_parent: If True, yield (parent, child), with output nodes having parent=None
yield_name: If True, yield (name, child) or (name, parent, child) if yield_parent is also True
yield_consts: If True, yield child arguments of nodes even if they are not Nodes
"""
visited = set()
frontier = deque((None, name, node) for name, node in graph.outputs.items())
# logger.debug(f"{traverse_breadth_first.__name__} {graph.name} {len(frontier)=}")
def res(parent, name, child):
res = (child,)
if yield_parent:
res = (parent,) + res
if yield_name:
res = (name,) + res
return res[0] if len(res) == 1 else tuple(res)
while len(frontier) > 0:
parent, name, node = frontier.popleft()
if yield_consts and not isinstance(node, Node):
yield res(parent, name, node)
continue
if node is None:
continue
if id(node) in visited:
continue
visited.add(id(node))
yield res(parent, name, node)
children = list(pytree.PyTree(node.args).items()) + list(
pytree.PyTree(node.kwargs).items()
)
for key, arg in children:
if not yield_consts and not isinstance(arg, Node):
continue
if id(arg) in visited:
continue
frontier.append((node, key, arg))
def _traverse_depth_first_node(
node: Node,
visited: set[int],
parent: Node | None,
name: str,
order: Literal["preorder", "postorder"],
yield_parent: bool,
yield_name: bool,
yield_consts: bool,
) -> Generator[Any, None, None]:
def res(parent, name, child):
res = (child,)
if yield_parent:
res = (parent,) + res
if yield_name:
res = (name,) + res
return res[0] if len(res) == 1 else tuple(res)
assert isinstance(node, Node), node
if id(node) in visited:
return
visited.add(id(node))
if order == "preorder":
yield res(parent, name, node)
children = list(pytree.PyTree(node.args).items()) + list(
pytree.PyTree(node.kwargs).items()
)
for key, arg in children:
if not isinstance(arg, Node):
if yield_consts:
yield res(node, key, arg)
continue
yield from _traverse_depth_first_node(
node=arg,
visited=visited,
parent=node,
name=key,
order=order,
yield_parent=yield_parent,
yield_name=yield_name,
yield_consts=yield_consts,
)
if order == "postorder":
yield res(parent, name, node)
[docs]
def traverse_depth_first_node(
node: Node,
yield_consts: bool = False,
order: Literal["preorder", "postorder"] = "postorder",
) -> Generator[Any, None, None]:
return _traverse_depth_first_node(
node=node,
visited=set(),
parent=None,
name="",
order=order,
yield_parent=False,
yield_name=False,
yield_consts=yield_consts,
)
[docs]
def traverse_depth_first(
graph: ComputeGraph,
yield_parent: bool = False,
yield_name: bool = False,
yield_consts: bool = False,
order: Literal["preorder", "postorder"] = "postorder",
) -> Generator[Any, None, None]:
visited = set()
for name, node in graph.outputs.items():
if node is None:
continue
yield from _traverse_depth_first_node(
node, visited, None, name, order, yield_parent, yield_name, yield_consts
)
[docs]
def traverse_nested_graphs(
graph: ComputeGraph,
yield_call_nodes: bool = False,
) -> Generator[tuple[Node | None, ComputeGraph], None, None]:
visited = set()
frontier = deque([(None, graph)])
while len(frontier) > 0:
node, graph = frontier.popleft()
if id(graph) in visited:
continue
visited.add(id(graph))
if yield_call_nodes:
yield node, graph
else:
yield graph
frontier.extend(
(node, node.subgraph)
for node in traverse_depth_first(graph)
if isinstance(node, SubgraphCallNode)
)
[docs]
def usages_per_node(
graph: ComputeGraph,
) -> dict[int, list[Node]]:
usages = defaultdict(list)
for node in traverse_depth_first(graph):
argtree = pytree.PyTree((node.args, node.kwargs))
for arg in argtree.values():
if isinstance(arg, Node):
usages[id(arg)].append(node)
return dict(usages)
[docs]
def graph_nodes_equal(graph1: ComputeGraph, graph2: ComputeGraph) -> bool:
nodes1 = list(traverse_depth_first(graph1))
nodes2 = list(traverse_depth_first(graph2))
if len(nodes1) != len(nodes2):
return False
for node1, node2 in zip(nodes1, nodes2):
if type(node1) is not type(node2):
return False
if isinstance(node1, SubgraphCallNode):
if not graph_nodes_equal(node1.subgraph, node2.subgraph):
return False
elif node1.args != node2.args or node1.kwargs != node2.kwargs:
return False
return True