# Source code for procfunc.codegen.codegen

import dataclasses
import enum
import inspect
import itertools
import logging
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Any, Callable, Generator

import numpy as np

import procfunc as pf
from procfunc import compute_graph as cg
from procfunc.codegen import identifiers
from procfunc.codegen.repr import repr_type, repr_value
from procfunc.compute_graph.operators_info import (
    FUNCTIONS_TO_OPERATORS,
    OPERATOR_TEMPLATES,
    OperatorType,
)
from procfunc.util import pytree

# Module-level logger; codegen emits debug/warning diagnostics through it.
logger = logging.getLogger(__name__)

# One level of indentation used when emitting generated Python source.
INDENT = "    "


def indent_lines(lines: list[str], indent: str = INDENT) -> list[str]:
    """Return a new list in which every entry of ``lines`` is prefixed by ``indent``.

    The input list is not modified; empty strings also receive the prefix.
    """
    prefixed: list[str] = []
    for text in lines:
        prefixed.append(indent + text)
    return prefixed


def _repr_inp(
    arg: Any,
    scope_expressions: dict[int, str | list[str]],
    extra_parens: bool = False,
) -> str:
    if isinstance(arg, cg.Node):
        if id(arg) not in scope_expressions:
            raise ValueError(
                f"Scope expressions {scope_expressions} did not contain {arg=} possibly due to bad visit ordering"
            )
        expr = scope_expressions[id(arg)]
    else:
        expr = repr_value(arg)

    if isinstance(expr, list):
        if len(expr) > 1:
            raise ValueError(
                "Inlined values should not resolve to more than one line in current implementation, "
                f"got {expr=} for {arg=}"
            )
        expr = expr[0]
    assert isinstance(expr, str)

    if " " in expr and extra_parens and expr[0] != "(" and expr[-1] != ")":
        return f"({expr})"
    else:
        return expr


def _kwarg_matches_default(sig: inspect.Signature, key: str, value: Any) -> bool:
    """Report whether passing ``key=value`` merely restates the default in ``sig``.

    Graph values (nodes and proxies) never count as defaults, and neither do
    keys that are absent from the signature or that name a parameter without a
    default. A comparison that raises (for example the ambiguous truth value of
    a numpy array comparison) is treated as "not equal".
    """
    if isinstance(value, (cg.Node, cg.Proxy)):
        return False

    parameter = sig.parameters.get(key)
    default = inspect.Parameter.empty if parameter is None else parameter.default
    if default is inspect.Parameter.empty:
        return False

    try:
        matches = value == default
        return bool(matches)
    except Exception:
        return False


def _repr_args(
    func: Callable[..., Any] | None,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    scope_expressions: dict[int, str | list[str]],
) -> list[str]:
    """
    Create string for arg and kwarg def for function inputs.

    Args:
        func: the callee. When its signature can be introspected it is used to
            drop kwargs that restate a default, to order kwargs by declaration
            order, and to pass a lone first-parameter kwarg positionally.
        args: positional inputs (nodes or plain values, possibly nested).
        kwargs: keyword inputs (nodes or plain values, possibly nested).
        scope_expressions: maps ``id(node)`` to the name/expression for it.

    Returns:
        Rendered positional arguments followed by ``key=value`` strings.
    """

    try:
        sig = inspect.signature(func) if func is not None else None
    except (TypeError, ValueError):
        # not introspectable: keep kwargs exactly as given, in their own order
        sig = None

    if sig is not None:
        kwargs = {
            k: v for k, v in kwargs.items() if not _kwarg_matches_default(sig, k, v)
        }

    # Common special case: a call whose only input is passed by keyword to the
    # function's first parameter reads better as a positional argument. Only do
    # this when that parameter can actually be supplied positionally - a
    # keyword-only parameter (``def f(*, x)``) must stay ``x=...``.
    if len(args) == 0 and len(kwargs) == 1 and sig is not None and sig.parameters:
        (only_key,) = kwargs
        first_param = next(iter(sig.parameters.values()))
        if only_key == first_param.name and first_param.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            args = (kwargs[only_key],)
            kwargs = {}

    argreprs = pytree.PyTree(args).map(lambda x: _repr_inp(x, scope_expressions))
    argreprs = [
        pytree.repr_tree_to_str(v, type_namer=repr_type)
        for v in argreprs.unflatten_one_level()
    ]

    kwargreprs = (
        pytree.PyTree(kwargs)
        .map(lambda x: _repr_inp(x, scope_expressions))
        .unflatten_one_level()
    )
    kwargreprs = {
        k: pytree.repr_tree_to_str(v, type_namer=repr_type)
        for k, v in kwargreprs.items()
    }

    # use func sig to sort kwargs
    if sig is not None:
        kwargkeys = list(sig.parameters.keys())
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
        if not has_var_keyword:
            assert set(kwargreprs.keys()).issubset(set(kwargkeys)), (
                f"{kwargreprs.keys()=} {kwargkeys=}"
            )
        else:
            # **kwargs catch-alls: declared params first, then extras as given
            kwargkeys = kwargkeys + [k for k in kwargreprs.keys() if k not in kwargkeys]
    else:
        kwargkeys = list(kwargs.keys())

    kwarglist = [f"{k}={kwargreprs[k]}" for k in kwargkeys if k in kwargreprs]

    return argreprs + kwarglist


def _repr_function_call(
    node: cg.FunctionCallNode | cg.MethodCallNode | cg.SubgraphCallNode,
    scope_expressions: dict[int, str | list[str]],
    line_limit: int = 80,
    func_str: str | None = None,
) -> list[str]:
    match node:
        case cg.FunctionCallNode():
            func = node.func
            func_str = func_str or scope_expressions[id(func)]
        case cg.MethodCallNode(args=(target, *_), method_name=method_name):
            if not isinstance(target, cg.Node):
                raise ValueError(f"Method call {node=} has non-node target {target=}")
            func = None
            func_str = f"{_repr_inp(target, scope_expressions)}.{method_name}"
        case cg.SubgraphCallNode(subgraph=subgraph):
            func = None
            func_str = scope_expressions.get(id(subgraph))
            if func_str is None:
                raise ValueError(
                    f"Scope expressions did not contain definition for {subgraph=}"
                )
            assert isinstance(func_str, str), func_str
        case _:
            raise TypeError(f"Unsupported {node=}")

    args = node.args[1:] if isinstance(node, cg.MethodCallNode) else node.args
    arg_reprs = _repr_args(func, args, node.kwargs, scope_expressions)  # type: ignore

    if len(arg_reprs) == 0:
        return [f"{func_str}()"]

    total_len = len(func_str) + sum(len(arg) for arg in arg_reprs)
    multiline = total_len > line_limit

    if len(arg_reprs) > 1 and multiline:
        arg_reprs = [line + "," for line in arg_reprs]

    if multiline:
        return [f"{func_str}("] + indent_lines(arg_reprs) + [")"]
    else:
        return [f"{func_str}({', '.join(arg_reprs)})"]


def _operator_call_operands(
    node: cg.FunctionCallNode,
    template: str,
) -> list | None:
    """Return the operand values for rendering ``node`` through ``template``, or None.

    The template has one ``{}`` slot per operand, filled from the first
    parameters of the function's signature. An infix form cannot express any
    further argument, so such arguments are only acceptable when they just
    restate the signature default; otherwise this returns None and the caller
    should emit a regular named call. None is also returned when the signature
    cannot be inspected or bound, or an operand is not supplied.
    """
    slot_count = template.count("{}")
    try:
        signature = inspect.signature(node.func)
        bound = signature.bind(*node.args, **node.kwargs)
    except (TypeError, ValueError):
        return None

    supplied = bound.arguments
    operand_names = list(signature.parameters)[:slot_count]
    if not all(name in supplied for name in operand_names):
        return None

    extras = (
        (name, value) for name, value in supplied.items() if name not in operand_names
    )
    if any(not _kwarg_matches_default(signature, name, value) for name, value in extras):
        return None

    return [supplied[name] for name in operand_names]


def _repr_operator_call(
    operands: list,
    template: str,
    scope_expressions: dict[int, str | list[str]],
) -> list[str]:
    """Fill an operator template (e.g. ``"{} + {}"``) with rendered operands.

    Operands are rendered with ``extra_parens=True`` so that compound operand
    expressions are parenthesized. Returns a single-element list.
    """
    rendered = tuple(
        _repr_inp(operand, scope_expressions, extra_parens=True)
        for operand in operands
    )
    return [template.format(*rendered)]


def _codegen_for_node(
    node: cg.Node,
    scope_expressions: dict[int, str | list[str]],
) -> list[str]:
    match node:
        case cg.FunctionCallNode(func=func):
            funcres = scope_expressions[id(func)]
            if isinstance(funcres, list):
                raise ValueError(
                    f"{node} resolved to {funcres} but functions should always resolve to names, not expressions"
                )
            elif funcres == OperatorType.NOOP:
                return []  # no code needed
            elif "{}" in funcres:
                operands = _operator_call_operands(node, funcres)
                if operands is None:
                    # args the infix form cannot express: emit a named call.
                    # scope_expressions holds the operator template, so re-derive
                    # the callsite name (import already present via
                    # default_func_resolution_map)
                    _, callsite = _resolve_func(func)
                    return _repr_function_call(
                        node, scope_expressions, func_str=callsite
                    )
                return _repr_operator_call(operands, funcres, scope_expressions)
            else:
                return _repr_function_call(node, scope_expressions)
        case cg.MethodCallNode() if node.method_name == "__getitem__":
            callee_expr = _repr_inp(node.args[0], scope_expressions)
            idx_expr = _repr_inp(node.args[1], scope_expressions)
            return [f"{callee_expr}[{idx_expr}]"]
        case cg.MethodCallNode():
            return _repr_function_call(node, scope_expressions)
        case cg.SubgraphCallNode():
            return _repr_function_call(node, scope_expressions)
        case cg.GetAttributeNode(args=(source,), attribute_name=attribute_name):
            arg_expr = scope_expressions[id(source)]
            if isinstance(arg_expr, list) and len(arg_expr) == 1:
                arg_expr = arg_expr[0]
            if not isinstance(arg_expr, str):
                raise ValueError(
                    f"Attribute access {attribute_name!r} on {source!r} resolved to {arg_expr} but should be a string"
                )
            if " " in arg_expr:
                raise ValueError(
                    f"f{_codegen_for_node.__name__} got would attempt to create getattr expression "
                    f"{arg_expr!r}.{attribute_name} due to space in {arg_expr=} "
                    f"for {id(node)=} {node=} {id(source)=} {source=}"
                )
            return [f"{arg_expr}.{attribute_name}"]
        case cg.ConstantNode(value=value):
            return [repr_value(value)]
        case _:
            raise TypeError(f"Unsupported {node=}")


def _codegen_graph_inputs(
    graph: cg.ComputeGraph,
    node_names: dict[int, str],
    typename: str | None,
    func_name: str | None = None,
) -> list[str]:
    """Render the ``def`` header for the function generated from ``graph``.

    Each input becomes one indented parameter line, annotated when the node's
    ``known_value_type`` metadata is set and given ``= <default>`` when its
    ``default_value`` kwarg is not None. ``typename`` (if given) becomes the
    return annotation; ``func_name`` defaults to ``graph.name``.

    Raises:
        ValueError: if an input node has no entry in ``node_names``.
    """
    # Stable sort: inputs without a default come first, as Python requires,
    # and otherwise keep their original order.
    ordered = sorted(
        graph.inputs.values(),
        key=lambda inp: inp.kwargs.get("default_value", None) is not None,
    )

    func_name = func_name or graph.name

    if logger.isEnabledFor(logging.DEBUG):
        argnames = [node_names.get(id(inp)) for inp in ordered]
        logger.debug(f"Codegen inputs for {func_name} {argnames=}")

    if not ordered:
        return [f"def {func_name}():"]

    param_lines: list[str] = []
    for inp in ordered:
        if id(inp) not in node_names:
            raise ValueError(f"Node {inp} has no name in {node_names}")
        label = f"{node_names[id(inp)]}"

        value_type = inp.metadata.get("known_value_type", None)
        if value_type is not None:
            label = f"{label}: {repr_type(value_type)}"

        default = inp.kwargs.get("default_value")
        if default is not None:
            label += f" = {repr_value(default)}"

        param_lines.append(f"{label},")

    closing = "):" if typename is None else f") -> {typename}:"
    return [f"def {func_name}(", *indent_lines(param_lines), closing]


def _codegen_namedtuple_def(outputs: pytree.PyTree):
    """Emit a ``NamedTuple`` class definition mirroring the fields of ``outputs``.

    Each non-None output becomes an annotated field, typed from the node's
    ``known_value_type`` metadata when present and ``Any`` otherwise. The class
    is named after the tree's top-level type.
    """
    tuple_type = outputs.toplevel_type()

    def _field_annotation(value_node) -> str:
        known = value_node.metadata.get("known_value_type", None)
        return "Any" if known is None else repr_type(known)

    fields = [
        f"{field_name}: {_field_annotation(value_node)}"
        for field_name, value_node in outputs.items()
        if value_node is not None
    ]
    header = f"class {tuple_type.__name__}(NamedTuple):"
    return [header, *indent_lines(fields)]


def _codegen_for_outputs(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
) -> tuple[str | None, list[str], list[str]]:
    """Render the return annotation, output type definition, and return statement.

    Returns:
        ``(type_name, type_def_lines, return_lines)``. ``type_name`` is None when
        there are no outputs or a single output has no known value type.
        ``type_def_lines`` holds a NamedTuple class definition when the graph
        returns a namedtuple type that is neither a ``pf`` attribute, a builtin,
        nor already in ``scope_expressions``; in that case the type is recorded
        in ``scope_expressions`` (mutated in place) so it is not redefined.

    Raises:
        ValueError: for output container types that cannot be rendered.
    """
    if len(graph.outputs) == 0:
        # Sink graphs (no return values) still need a body statement so the
        # generated function parses; emit `pass` when nothing else fills it.
        return None, [], ["pass"]
    if len(graph.outputs) == 1:
        # iter() makes this work whether .values() yields an iterator or a view
        single_output = next(iter(graph.outputs.values()))
        # an output slot may be None (no producing node), which has no metadata
        vt = (
            None
            if single_output is None
            else single_output.metadata.get("known_value_type", None)
        )
        type_name = repr_type(vt) if vt is not None else None
        return type_name, [], [f"return {_repr_inp(single_output, scope_expressions)}"]

    graph_output_type = graph.outputs.toplevel_type()
    type_name = repr_type(graph_output_type)

    is_pf_type = hasattr(pf, graph_output_type.__name__)
    if is_pf_type:
        type_def = []
    elif graph_output_type.__module__ == "builtins":
        type_def = []
    elif id(graph_output_type) in scope_expressions:
        assert scope_expressions[id(graph_output_type)] == type_name
        logger.debug(f"Skipping redefinition of {graph_output_type}")
        type_def = []
    elif pytree.is_type_namedtuple(graph_output_type):
        type_def = _codegen_namedtuple_def(graph.outputs)
        scope_expressions[id(graph_output_type)] = type_name
    else:
        raise ValueError(f"Unhandled graph output type: {graph_output_type}")

    reprs_tree = graph.outputs.map(lambda node: _repr_inp(node, scope_expressions))
    return_lines = [
        f"return {pytree.repr_tree_to_str(reprs_tree, type_namer=repr_type)}"
    ]

    return type_name, type_def, return_lines


def _check_graph_input_names(
    graph: cg.ComputeGraph,
    scope_names: dict[int, str],
):
    """Choose the identifier each input of ``graph`` gets in the generated signature.

    Returns a dict mapping ``id(input_node)`` to its identifier. An input whose
    name is also a value in ``scope_names`` is renamed with a ``_val`` suffix,
    and a warning is logged.

    Raises:
        ValueError: on duplicate input names, or when a chosen identifier is not
            a valid snake_case identifier.
    """
    input_names = {id(node): name for name, node in graph.inputs.items()}
    chosen = list(input_names.values())
    if len(set(chosen)) != len(chosen):
        raise ValueError(
            f"Input names for {graph.name} had duplicate values. {input_names.values()=}"
        )

    overlap = set(input_names.values()) & set(scope_names.values())
    # iterate a snapshot; each entry is only ever rewritten during its own step
    for k, v in list(input_names.items()):
        if v not in overlap:
            continue
        newname = v + "_val"
        assert newname not in input_names.values()
        input_names[k] = newname
        logger.warning(
            f"Renaming input {k=} of {graph.name=} from {v} to {newname} to avoid "
            f"collision, since {v} is also the name of a util function"
        )

    for orig_name, node in graph.inputs.items():
        identifier = input_names[id(node)]
        if identifiers.is_valid_snake_identifier(identifier):
            continue
        raise ValueError(
            f"{graph.name=} had input {orig_name=} {node=} which recieved invalid identifier {identifier=}"
        )

    return input_names


def _codegen_graph_decorator(graph: cg.ComputeGraph) -> list[str]:
    """Return the decorator lines placed above the function generated for ``graph``.

    Graphs flagged with truthy ``is_node_function`` metadata get
    ``@pf.nodes.node_function``; all others get no decorator.
    """
    is_node_function = bool(graph.metadata.get("is_node_function"))
    return ["@pf.nodes.node_function"] if is_node_function else []


def _should_fold_node(
    node: cg.Node,
    parent: cg.Node | None,
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
    fold_map: dict[int, bool],
) -> bool:
    """Decide whether ``node`` is inlined into its consumer instead of named.

    Rules, in priority order:
    - attribute access and ``astype`` / ``__getitem__`` method calls: inline;
    - nodes consumed by an attribute access, or used more than once: name;
    - nodes visited without a parent: name;
    - operator-template function calls: inline unless the parent is itself
      inlined, so inlined operator expressions are never nested in one another;
    - everything else: name.
    """
    inline_always = isinstance(node, cg.GetAttributeNode) or (
        isinstance(node, cg.MethodCallNode)
        and node.method_name in ("astype", "__getitem__")
    )
    if inline_always:
        return True

    consumers = usages.get(id(node), [])
    if any(isinstance(u, cg.GetAttributeNode) for u in consumers) or len(consumers) > 1:
        return False

    if parent is None:
        return False

    if not isinstance(node, cg.FunctionCallNode):
        return False
    if "{}" not in scope_expressions[id(node.func)]:
        return False
    return not fold_map.get(id(parent), False)


def _expression_fold_map(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
) -> dict[int, bool]:
    """Map node ids to "inline this node" decisions for ``graph``.

    Outputs are decided first and never overridden: they are inlined only when
    None, a constant, or an attribute access. Every other node visited by
    ``cg.traverse_breadth_first`` is decided by ``_should_fold_node``, which may
    consult the decisions already recorded (e.g. for the node's parent).
    """
    fold_map: dict[int, bool] = {
        id(out): out is None or isinstance(out, (cg.ConstantNode, cg.GetAttributeNode))
        for out in graph.outputs.values()
    }

    for parent, node in cg.traverse_breadth_first(graph, yield_parent=True):
        if id(node) not in fold_map:  # output decisions are final
            fold_map[id(node)] = _should_fold_node(
                node, parent, scope_expressions, usages, fold_map
            )

    return fold_map


def traverse_chunks(
    graph: cg.ComputeGraph,
    pred: Callable[[cg.Node, list[cg.Node]], bool],
) -> Generator[list[cg.Node], None, None]:
    """Partition the nodes of ``graph`` into chunks and yield them one list at a time.

    Nodes are seeded in ``cg.traverse_breadth_first`` order. From each unvisited
    seed, argument nodes are pulled in depth-first while ``pred(arg, chunk)``
    accepts them. ``chunk`` is the list of nodes gathered so far for the
    current chunk, so predicates can check e.g. that all users of ``arg`` are
    already in it. Every node appears in exactly one chunk.
    """
    visited: set[int] = set()

    def _greedy_singleuses(node: cg.Node, chunk: list[cg.Node]) -> None:
        if id(node) in visited:
            return
        visited.add(id(node))
        # record membership before recursing so `pred` sees the chunk so far
        chunk.append(node)

        for arg in itertools.chain(node.args, node.kwargs.values()):
            if not isinstance(arg, cg.Node) or id(arg) in visited:
                continue
            if not pred(arg, chunk):
                continue
            _greedy_singleuses(arg, chunk)

    for node in cg.traverse_breadth_first(graph):
        if id(node) in visited:
            continue
        chunk: list[cg.Node] = []
        _greedy_singleuses(node, chunk)
        yield chunk


def _code_paragraphing_predicate(
    node: cg.Node,
    chunk: list[cg.Node],
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
) -> bool:
    """Predicate for ``traverse_chunks``: may ``node`` join the current ``chunk``?

    Only function calls qualify. A call is rejected when its callee expression
    contains both ``.math.`` and an operator ``{}`` slot. Otherwise it is
    accepted if it has exactly one user, or if every one of its users is
    already in the chunk.
    """
    if not isinstance(node, cg.FunctionCallNode):
        return False

    target_expr = scope_expressions[id(node.func)]
    # NOTE(review): operator templates rarely contain ".math.", so this filter
    # looks like it almost never fires - confirm the intended condition.
    if ".math." in target_expr and "{}" in target_expr:
        return False

    uses = usages[id(node)]
    if len(uses) == 1:
        return True
    chunk_ids = {id(member) for member in chunk}
    return all(id(use) in chunk_ids for use in uses)


def _codegen_for_assignment(
    assign_varname: str,
    node_code: list[str] | str,
    node: cg.Node,
    add_line_comments: bool,
) -> list[str]:
    """Prefix ``node_code`` with ``<assign_varname> = `` and return the new lines.

    Args:
        assign_varname: target variable; must be a valid snake_case identifier.
        node_code: the expression, as a string or a non-empty list of lines
            (the assignment is placed before the first line).
        node: the node being assigned; its ``str()`` is used for the comment.
        add_line_comments: append ``  # <node>`` to the last line.

    Returns:
        A new list of lines. A list ``node_code`` from the caller is not modified.
    """
    assert isinstance(assign_varname, str)
    assert identifiers.is_valid_snake_identifier(assign_varname)

    if isinstance(node_code, list):
        lines = [f"{assign_varname} = " + node_code[0], *node_code[1:]]
    else:
        lines = [f"{assign_varname} = {node_code}"]
    if add_line_comments:
        # newlines in the node's repr would break out of the single-line comment
        lines[-1] += f"  # {str(node).replace(chr(10), ' ')}"

    return lines


def _expressions_scope_for_graph(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
) -> tuple[dict[int, str | list[str]], dict[int, bool]]:
    """Build the id -> name/expression table used to emit code for ``graph``.

    The table starts as a copy of ``scope_expressions`` (the input is not
    modified), gains the graph's input identifiers, and finally the variable
    names chosen for nodes that are not inlined. Also returns the fold map
    deciding which nodes are inlined.

    Raises:
        ValueError: if chosen node names are duplicated or collide with names
            already in scope.
    """
    # when we want to refer to a value, what string should we insert?
    # - for most nodes: refer to a variable name
    # - for inlined expressions: emplace a expression string
    expressions: dict[int, str | list[str]] = dict(scope_expressions)
    expressions.update(_check_graph_input_names(graph, scope_expressions))

    usages = cg.usages_per_node(graph)
    fold_map = _expression_fold_map(graph, expressions, usages=usages)

    node_names = identifiers.nodenames_from_fixed_and_infill(
        graph,
        fold_map=fold_map,
        scope_expressions=expressions,
    )
    duplicates = identifiers.duplicate_names(node_names)
    if duplicates:
        raise ValueError(f"Duplicate node names: {duplicates}")

    intersection = set(expressions.values()) & set(node_names.values())
    if intersection:
        raise ValueError(f"Scope and node names had overlap: {intersection=}")

    expressions.update(node_names)
    return expressions, fold_map


def _codegen_for_graph(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str],
    as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
    func_name: str | None = None,
) -> list[str]:
    """Generate the source lines for a single compute graph.

    Args:
        graph: the graph to emit.
        scope_expressions: names already in scope (functions, subgraphs,
            types). Not modified.
        as_maincall: emit the body under ``if __name__ == '__main__':``
            instead of as a function; the graph must then have no inputs.
        add_version_comment: start the body with a "generated by" comment.
        add_line_comments: append each node's repr as a trailing comment.
        func_name: name of the generated function (defaults to graph.name).

    Returns:
        The generated lines: either the ``__main__`` block, or the output type
        definition (if any), a blank line, decorators, the ``def`` header, the
        indented body and the return statement.
    """
    code_lines: list[str] = []

    if add_version_comment:
        code_lines.append(f"# Code generated by procfunc v{pf.__version__}")

    expressions, fold_map = _expressions_scope_for_graph(graph, scope_expressions)

    # Nodes consumed by a MutatedArgumentNode are in-place mutator calls: emit
    # them as bare statements (no assignment).
    mutator_call_ids = {
        id(node.args[1])
        for node in cg.traverse_depth_first(graph)
        if isinstance(node, cg.MutatedArgumentNode)
    }

    last_varname: str = ""
    for node in cg.traverse_depth_first(graph):
        if isinstance(node, cg.InputPlaceholderNode):
            continue  # arguments are defined in _codegen_graph_inputs
        if isinstance(node, cg.MutatedArgumentNode):
            # alias to the original node, since mutation is in-place
            original_node = node.args[0]
            expressions[id(node)] = expressions[id(original_node)]
            continue

        # rendering only reads the table, so no per-node copy is needed
        node_code = _codegen_for_node(node, expressions)

        if fold_map[id(node)]:
            assert id(node) not in expressions, f"{node=} {expressions[id(node)]=}"
            expressions[id(node)] = node_code
            continue

        if id(node) in mutator_call_ids:
            code_lines.extend(node_code if isinstance(node_code, list) else [node_code])
            continue

        varname = expressions[id(node)]
        # Separate "paragraphs" of variables sharing a name prefix with a blank
        # line placed *before* the first statement of each new group.
        if last_varname and last_varname.split("_")[0] != varname.split("_")[0]:
            code_lines.append("")
        code_lines.extend(
            _codegen_for_assignment(varname, node_code, node, add_line_comments)
        )
        last_varname = varname

    if as_maincall:
        assert len(graph.inputs) == 0, graph.inputs
        # an `if` block containing only comments/blank lines would not parse
        has_statement = any(
            line.strip() and not line.lstrip().startswith("#") for line in code_lines
        )
        body = code_lines if has_statement else [*code_lines, "pass"]
        return ["if __name__ == '__main__':"] + indent_lines(body)

    typename, typedef, return_lines = _codegen_for_outputs(graph, expressions)

    return (
        typedef
        + [""]
        + _codegen_graph_decorator(graph)
        + _codegen_graph_inputs(graph, expressions, typename, func_name=func_name)
        + indent_lines(code_lines)
        + indent_lines(return_lines)
    )


def _resolve_func(func: Any) -> tuple[str | None, str]:
    """
    Returns:
        tuple[str, str]: import string, function callsite string
    """

    if isinstance(func, np.ufunc):
        return "import numpy as np", f"np.{func.__name__}"

    module = getattr(func, "__module__", None)

    if module is None:
        raise NotImplementedError(f"Unsupported function: {func}")
    elif module == "builtins":
        return None, func.__name__
    elif module.startswith("procfunc."):
        callsite = "pf." + module[len("procfunc.") :] + "." + func.__name__
        importstring = "import procfunc as pf"
        return importstring, callsite
    elif module.startswith("infinigen_v2."):
        parent, _, mod_name = module.rpartition(".")
        importstring = f"from {parent} import {mod_name}"
        callsite = f"{mod_name}.{func.__name__}"
        return importstring, callsite
    else:
        callsite = f"{module}.{func.__name__}"
        importstring = f"import {module}"
        return importstring, callsite


def default_func_resolution_map(
    toplevel_graph: cg.ComputeGraph,
    skip_funcs: set | None = None,
) -> tuple[dict[Any, str | OperatorType], list[str]]:
    """Resolve every function called in ``toplevel_graph`` and its nested graphs.

    Functions listed in ``FUNCTIONS_TO_OPERATORS`` map to their ``OperatorType``.
    All others map to the call-site string from ``_resolve_func``, and its
    import statement is collected. Functions in ``skip_funcs`` are left out.

    Returns:
        ``(func_resolution, import_lines)``; ``import_lines`` always includes
        ``"import procfunc as pf"`` and contains no duplicates.
    """
    resolution = {}
    imports = {"import procfunc as pf"}

    def _called_functions():
        for graph in cg.traverse_nested_graphs(toplevel_graph):
            assert isinstance(graph, cg.ComputeGraph), graph
            for node in cg.traverse_depth_first(graph):
                if isinstance(node, cg.FunctionCallNode):
                    yield node.func

    for func in _called_functions():
        if skip_funcs is not None and func in skip_funcs:
            continue
        if func in FUNCTIONS_TO_OPERATORS:
            resolution[func] = FUNCTIONS_TO_OPERATORS[func]
            continue
        importstring, callsite = _resolve_func(func)
        resolution[func] = callsite
        if importstring is not None:
            imports.add(importstring)

    return resolution, list(imports)
def _topo_sort_subgraphs(graph: cg.ComputeGraph) -> list[cg.ComputeGraph]:
    """DFS post-order traversal: dependencies before dependents."""
    visited = set()
    result = []

    def visit(g: cg.ComputeGraph):
        if id(g) in visited:
            return
        visited.add(id(g))
        for node in cg.traverse_depth_first(g):
            if isinstance(node, cg.SubgraphCallNode):
                visit(node.subgraph)
        result.append(g)

    visit(graph)
    return result


def graphs_to_python_functions(
    graph: cg.ComputeGraph,
    func_resolution: dict[Any, str],
    toplevel_as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
) -> OrderedDict[str, list[str]]:
    """Generate source lines for ``graph`` and every nested subgraph.

    Subgraph names are cleaned of non-descriptive suffixes (note: this renames
    the graphs in place) and deduplicated. Graphs are emitted dependencies
    first. ``func_resolution`` values that are ``OperatorType`` are replaced by
    their operator templates. numpy's print linewidth is raised while
    generating (so array reprs stay on one line) and restored afterwards.

    Returns:
        Ordered mapping from generated function name to its source lines.
    """
    np_linewidth = np.get_printoptions()["linewidth"]
    np.set_printoptions(linewidth=100000)
    try:
        targets = _topo_sort_subgraphs(graph)

        def _clean_graph_name(name: str) -> str:
            for suffix in identifiers.NONDESCRIPTIVE_NODE_NAME_PARTS:
                if name.endswith("_" + suffix):
                    name = name[: -(len(suffix) + 1)]
            return name

        for subgraph in cg.traverse_nested_graphs(graph):
            subgraph.name = _clean_graph_name(subgraph.name)

        subgraph_names = {
            id(subgraph): subgraph.name
            for subgraph in cg.traverse_nested_graphs(graph)
        }
        subgraph_names = identifiers.dedup_names_with_suffix(
            subgraph_names, separator="_"
        )

        scope_expressions = subgraph_names.copy()
        for k, v in func_resolution.items():
            if isinstance(v, OperatorType):
                scope_expressions[id(k)] = OPERATOR_TEMPLATES[v]
            else:
                scope_expressions[id(k)] = v

        lines_for_modules = []
        for subgraph in targets:
            func_name = subgraph_names[id(subgraph)]
            result = _codegen_for_graph(
                subgraph,
                scope_expressions=scope_expressions.copy(),
                as_maincall=(subgraph is graph and toplevel_as_maincall),
                add_version_comment=add_version_comment,
                add_line_comments=add_line_comments,
                func_name=func_name,
            )
            lines_for_modules.append((func_name, result))
    finally:
        np.set_printoptions(linewidth=np_linewidth)
    return OrderedDict(lines_for_modules)


def _define_multiuse_return_types(
    graph: cg.ComputeGraph,
    func_resolution: dict,
) -> list[str]:
    """Emit NamedTuple definitions for return types shared by several graphs.

    Namedtuple output types (excluding ``pf`` attributes) returned by more than
    one distinct graph are defined once, up front, and registered in
    ``func_resolution`` (mutated in place) under their class name.
    """
    counts, graphs_by_type, seen = defaultdict(int), {}, set()
    for subgraph in cg.traverse_nested_graphs(graph):
        rettype = subgraph.outputs.toplevel_type()
        if (
            id(subgraph) in seen
            or rettype is None
            or hasattr(pf, rettype.__name__)
            or not pytree.is_type_namedtuple(rettype)
        ):
            continue
        seen.add(id(subgraph))
        counts[rettype] += 1
        graphs_by_type[rettype] = subgraph

    multiuse = {
        rettype: subgraph
        for rettype, subgraph in graphs_by_type.items()
        if counts[rettype] > 1
    }

    lines = []
    for rettype, subgraph in multiuse.items():
        lines.extend(_codegen_namedtuple_def(subgraph.outputs))
        lines.append("")
        func_resolution[rettype] = rettype.__name__
    return lines


def _collect_graph_value_imports(graph: cg.ComputeGraph) -> list[str]:
    """Collect import lines needed by literal values embedded in the graphs.

    Enum members and dataclass instances need their type imported, and
    ``Path`` values need ``pathlib.Path``. Values are searched recursively
    through dataclass fields and list/tuple/set/dict containers, across node
    args, kwargs and constant-node values of ``graph`` and its nested graphs.
    """
    import_lines = set()

    def _collect_from_value(v):
        if isinstance(v, cg.Node):
            return
        if isinstance(v, enum.Enum):
            t = type(v)
            if t.__module__ != "builtins":
                import_lines.add(f"from {t.__module__} import {t.__name__}")
        elif isinstance(v, Path):
            import_lines.add("from pathlib import Path")
        elif dataclasses.is_dataclass(v) and not isinstance(v, type):
            t = type(v)
            if t.__module__ != "builtins":
                import_lines.add(f"from {t.__module__} import {t.__name__}")
            for f in dataclasses.fields(v):
                _collect_from_value(getattr(v, f.name))
        elif isinstance(v, (list, tuple, set, frozenset)):
            # containers render their elements inline, so their imports count too
            for item in v:
                _collect_from_value(item)
        elif isinstance(v, dict):
            for key, item in v.items():
                _collect_from_value(key)
                _collect_from_value(item)

    for subgraph in cg.traverse_nested_graphs(graph):
        for node in cg.traverse_depth_first(subgraph):
            if isinstance(node, cg.ConstantNode):
                # constant nodes are rendered from their `.value` via repr_value
                _collect_from_value(node.value)
            for arg in itertools.chain(node.args, node.kwargs.values()):
                _collect_from_value(arg)
    return list(import_lines)
def to_python(
    graph: cg.ComputeGraph,
    func_resolution: dict[Any, str | OperatorType] | None = None,
    import_lines: list[str] | None = None,
    toplevel_as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
) -> str:
    """Render ``graph`` and all nested subgraphs as a standalone Python module.

    Args:
        graph: the top-level graph.
        func_resolution: mapping from called functions to call-site strings or
            ``OperatorType``s. When None it is derived with
            ``default_func_resolution_map`` and any ``import_lines`` given are
            added to the derived imports. The caller's dict is not modified.
        import_lines: extra import statements; required when
            ``func_resolution`` is given.
        toplevel_as_maincall: emit the top-level graph as a ``__main__`` block.
        add_version_comment: add a "generated by" comment to each body.
        add_line_comments: annotate each assignment with its node's repr.

    Returns:
        The generated module source.
    """
    code_lines = [
        # `Any` is used by generated NamedTuple fields with no known value type
        "from typing import Annotated, Any, NamedTuple",
        "import numpy as np",
        "import bpy",
        "from procfunc.nodes import types as t",
        "from procfunc.nodes.types import ProcNode, SocketOrVal",
    ]

    if func_resolution is None:
        func_resolution, default_imports = default_func_resolution_map(graph)
        import_lines = [*default_imports, *(import_lines or [])]
    else:
        assert import_lines is not None
        # copy: shared return types are registered into this dict below
        func_resolution = dict(func_resolution)

    all_imports = set(import_lines) | set(_collect_graph_value_imports(graph))
    code_lines.extend(sorted(all_imports))
    code_lines.append("")

    code_lines.extend(_define_multiuse_return_types(graph, func_resolution))

    lines_for_modules = graphs_to_python_functions(
        graph,
        func_resolution,
        add_version_comment=add_version_comment,
        add_line_comments=add_line_comments,
        toplevel_as_maincall=toplevel_as_maincall,
    )
    for module_lines in lines_for_modules.values():
        code_lines.extend(module_lines)
        code_lines.append("")

    return "\n".join(code_lines)