Source code for procfunc.compute_graph.compute_graph

import itertools
import logging
from dataclasses import dataclass
from typing import Any

from procfunc.util.pytree import PyTree

from .node import Node

logger = logging.getLogger(__name__)


def _evaluate_node(node: Node) -> Any:
    if node.result is not None:
        return node.result

    for arg in itertools.chain(node.args, node.kwargs.values()):
        if isinstance(arg, Node):
            arg.result = _evaluate_node(arg)

    arg_eval, kwarg_eval = node.inputs.map(
        lambda x: x.result if isinstance(x, Node) else x
    ).obj()

    # match node:
    #     case cg.FunctionCallNode:
    #         return node.func(*arg_vals, **kwarg_vals)
    #     case cg.PlaceholderNode:
    #         raise NotImplementedError(
    #             f"Placeholder {node!r} or '<unnamed>'} should not be evaluated - "
    #             "its .result should be populated in advance"
    #         )
    #     case cg.MutatedArgumentNode:
    #         # Evaluate the mutation call first (for side effects), then return the original object
    #         mutation_call_node = node.args[1]
    #         _evaluate_node(mutation_call_node)
    #         return (
    #             node.args[0].result if isinstance(node.args[0], Node) else node.args[0]
    #         )
    #     case _:
    #         raise NotImplementedError(f"Unsupported node operation: {node.kind}")


def _clear_node_results(node: Node):
    node.result = None
    for arg in itertools.chain(node.args, node.kwargs.values()):
        if isinstance(arg, Node) and arg.result is not None:
            _clear_node_results(arg)


[docs] @dataclass class ComputeGraph: inputs: PyTree[Any, Node] outputs: PyTree[Any, Node] name: str metadata: dict[str, Any] def __post_init__(self): # input_names = set(self.inputs.names()) # if len(input_names) != len(set(input_names)): # raise ValueError(f"Input names had duplicates: {input_names}") # output_names = set(self.outputs.names()) # if len(output_names) != len(set(output_names)): # raise ValueError(f"Output names had duplicates: {output_names}") pass def __repr__(self): return f"{self.__class__.__name__}({self.name!r})"
[docs] def clear_values(self): for node in self.outputs.values(): _clear_node_results(node)
def __call__( self, *args, allow_clear: bool = True, **kwargs, ): """ Execute the compute graph. If this graph came from a tracer, this should be exactly equivelant to executing the original python function. """ raise NotImplementedError("Not implemented") if len(args) != len(self.inputs): raise ValueError(f"Expected {len(self.inputs)} arguments, got {len(args)}") if allow_clear: self.clear_values() extra_kwargs = self.kwarg_nodes.keys() - kwargs.keys() if extra_kwargs: raise ValueError( f"{self.__class__.__name__} {self.name!r} got unexpected keyword arguments: {extra_kwargs}" ) missing_kwargs = kwargs.keys() - self.kwarg_nodes.keys() if missing_kwargs: raise ValueError( f"{self.__class__.__name__} {self.name!r} had missing keyword arguments: {missing_kwargs}" ) for arg, arg_node in zip(args, self.arg_nodes): arg_node.result = arg for k, v in kwargs.items(): self.kwarg_nodes[k].result = v for node in self.outputs.values(): _evaluate_node(node) return {name: node.result for name, node in self.outputs.items()}