import dataclasses
import enum
import inspect
import itertools
import logging
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Any, Callable, Generator, Union, get_args, get_origin
import numpy as np
import procfunc as pf
from procfunc import compute_graph as cg
from procfunc.compute_graph.operators_info import (
FUNCTIONS_TO_OPERATORS,
OPERATOR_TEMPLATES,
OperatorType,
)
from procfunc.nodes import types as nt
from procfunc.transpiler import identifiers
from procfunc.util import pytree
logger = logging.getLogger(__name__)
INDENT = " "
def indent_lines(lines: list[str], indent: str = INDENT) -> list[str]:
return [indent + line for line in lines]
def _repr_type(x: Any) -> str:
# TODO: make the user pass in special resolutions for types, or else we will just do verbose types
if isinstance(x, str):
return x
if x.__name__ == "NoneType":
return "None"
origin = get_origin(x)
args = get_args(x)
if x.__name__ == "ProcNode":
if len(args) == 1:
return f"pf.ProcNode[{_repr_type(args[0])}]"
elif len(args) == 0:
return "pf.ProcNode"
else:
raise ValueError(f"Unsupported ProcNode type: {x} {args=}")
if hasattr(pf, x.__name__):
if len(args):
raise ValueError(f"procfunc type had unhandled annotations: {x} {args=}")
return f"pf.{x.__name__}"
if x.__module__ == "builtins":
return x.__name__
origin = get_origin(x)
args = get_args(x)
if origin is Union:
args_0 = get_args(args[0])
if get_origin(args[0]) is nt.ProcNode and args_0[0] is args[1]:
return f"t.SocketOrVal[{_repr_type(args_0[0])}]"
else:
return " | ".join([_repr_type(a) for a in args])
if getattr(x, "__module__", None) == "procfunc.nodes.types":
return f"t.{x.__name__}"
return x.__name__
def _repr_value(value: Any) -> str:
if hasattr(value, "__wrapped__"):
value = value.__wrapped__
if isinstance(value, cg.Proxy):
logger.warning(
f"Proxy object {value} should never appear as a raw value in codegen - "
f"its underlying node {value.node} was not resolved to a variable"
)
if isinstance(value, nt.ProcNode):
logger.warning(
f"Procnode object {value} should never be treated as a raw value in codegen"
)
if isinstance(value, np.random.Generator):
return "np.random.default_rng()"
elif isinstance(value, type):
return _repr_type(value)
elif isinstance(value, np.ndarray):
nprepr = repr(value).replace("\n", "")
return f"np.{nprepr}"
elif isinstance(value, np.dtype):
return f"np.dtype('{value}')"
elif isinstance(value, (pf.Color, pf.Vector, pf.Euler, pf.Quaternion, pf.Matrix)):
x = tuple(round(x, 6) for x in value)
return f"pf.{value.__class__.__name__}({x})"
elif isinstance(value, enum.Enum):
return f"{type(value).__name__}.{value.name}"
elif isinstance(value, Path):
return f"Path({str(value)!r})"
elif dataclasses.is_dataclass(value) and not isinstance(value, type):
args_str = ", ".join(
f"{f.name}={_repr_value(getattr(value, f.name))}"
for f in dataclasses.fields(value)
)
return f"{type(value).__name__}({args_str})"
elif isinstance(value, list):
return f"[{', '.join([_repr_value(x) for x in value])}]"
else:
return repr(value)
def _repr_inp(
arg: Any,
scope_expressions: dict[int, str | list[str]],
extra_parens: bool = False,
) -> str:
if isinstance(arg, cg.Node):
if id(arg) not in scope_expressions:
raise ValueError(
f"Scope expressions {scope_expressions} did not contain {arg=} possibly due to bad visit ordering"
)
expr = scope_expressions[id(arg)]
else:
expr = _repr_value(arg)
if isinstance(expr, list):
if len(expr) > 1:
raise ValueError(
"Inlined values should not resolve to more than one line in current implementation, "
f"got {expr=} for {arg=}"
)
expr = expr[0]
assert isinstance(expr, str)
if " " in expr and extra_parens and expr[0] != "(" and expr[-1] != ")":
return f"({expr})"
else:
return expr
def _kwarg_matches_default(sig: inspect.Signature, key: str, value: Any) -> bool:
if isinstance(value, (cg.Node, cg.Proxy)):
return False
param = sig.parameters.get(key)
if param is None or param.default is inspect.Parameter.empty:
return False
try:
return bool(value == param.default)
except Exception:
return False
def _repr_args(
func: Callable[..., Any] | None,
args: tuple[Any, ...],
kwargs: dict[str, Any],
scope_expressions: dict[int, str | list[str]],
) -> list[str]:
"""
Create string for arg and kwarg def for function inputs
"""
try:
sig = inspect.signature(func) if func is not None else None
except ValueError:
sig = None
if sig is not None:
kwargs = {
k: v for k, v in kwargs.items() if not _kwarg_matches_default(sig, k, v)
}
# common specialcase: nodes with a single output which would unnecessarily be a kwarg can just be a positional arg instead
if len(args) == 0 and len(kwargs) == 1 and sig is not None:
if next(iter(kwargs)) == next(iter(sig.parameters)):
args = (kwargs[next(iter(kwargs))],)
kwargs = {}
argreprs = pytree.PyTree(args).map(lambda x: _repr_inp(x, scope_expressions))
argreprs = [
pytree.repr_tree_to_str(v, type_namer=_repr_type)
for v in argreprs.unflatten_one_level()
]
kwargreprs = (
pytree.PyTree(kwargs)
.map(lambda x: _repr_inp(x, scope_expressions))
.unflatten_one_level()
)
kwargreprs = {
k: pytree.repr_tree_to_str(v, type_namer=_repr_type)
for k, v in kwargreprs.items()
}
# use func sig to sort kwargs
if sig is not None:
kwargkeys = list(sig.parameters.keys())
has_var_keyword = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_keyword:
assert set(kwargreprs.keys()).issubset(set(kwargkeys)), (
f"{kwargreprs.keys()=} {kwargkeys=}"
)
else:
kwargkeys = kwargkeys + [k for k in kwargreprs.keys() if k not in kwargkeys]
else:
kwargkeys = list(kwargs.keys())
kwarglist = [f"{k}={kwargreprs[k]}" for k in kwargkeys if k in kwargreprs]
return argreprs + kwarglist
def _repr_function_call(
node: cg.FunctionCallNode | cg.MethodCallNode | cg.SubgraphCallNode,
scope_expressions: dict[int, str | list[str]],
line_limit: int = 80,
) -> list[str]:
match node:
case cg.FunctionCallNode():
func = node.func
func_str = scope_expressions[id(func)]
case cg.MethodCallNode(args=(target, *_), method_name=method_name):
if not isinstance(target, cg.Node):
raise ValueError(f"Method call {node=} has non-node target {target=}")
func = None
func_str = f"{_repr_inp(target, scope_expressions)}.{method_name}"
case cg.SubgraphCallNode(subgraph=subgraph):
func = None
func_str = scope_expressions.get(id(subgraph))
if func_str is None:
raise ValueError(
f"Scope expressions did not contain definition for {subgraph=}"
)
assert isinstance(func_str, str), func_str
case _:
raise TypeError(f"Unsupported {node=}")
args = node.args[1:] if isinstance(node, cg.MethodCallNode) else node.args
arg_reprs = _repr_args(func, args, node.kwargs, scope_expressions) # type: ignore
if len(arg_reprs) == 0:
return [f"{func_str}()"]
total_len = len(func_str) + sum(len(arg) for arg in arg_reprs)
multiline = total_len > line_limit
if len(arg_reprs) > 1 and multiline:
arg_reprs = [line + "," for line in arg_reprs]
if multiline:
return [f"{func_str}("] + indent_lines(arg_reprs) + [")"]
else:
return [f"{func_str}({', '.join(arg_reprs)})"]
def _repr_operator_call(
node: cg.FunctionCallNode,
scope_expressions: dict[int, str | list[str]],
) -> list[str]:
assert isinstance(node, cg.FunctionCallNode), node
# Support both positional args and kwargs for operator templates
all_args = [
_repr_inp(v, scope_expressions, extra_parens=True) for v in node.args
] + [
_repr_inp(v, scope_expressions, extra_parens=True) for v in node.kwargs.values()
]
operator_template = scope_expressions[id(node.func)]
assert isinstance(operator_template, str), operator_template
return [operator_template.format(*all_args)]
def _codegen_for_node(
node: cg.Node,
scope_expressions: dict[int, str | list[str]],
) -> list[str]:
match node:
case cg.FunctionCallNode(func=func):
funcres = scope_expressions[id(func)]
if isinstance(funcres, list):
raise ValueError(
f"{node} resolved to {funcres} but functions should always resolve to names, not expressions"
)
elif funcres == OperatorType.NOOP:
return [] # no code needed
elif "{}" in funcres:
return _repr_operator_call(node, scope_expressions)
else:
return _repr_function_call(node, scope_expressions)
case cg.MethodCallNode() if node.method_name == "__getitem__":
callee_expr = _repr_inp(node.args[0], scope_expressions)
idx_expr = _repr_inp(node.args[1], scope_expressions)
return [f"{callee_expr}[{idx_expr}]"]
case cg.MethodCallNode():
return _repr_function_call(node, scope_expressions)
case cg.SubgraphCallNode():
return _repr_function_call(node, scope_expressions)
case cg.GetAttributeNode(args=(source,), attribute_name=attribute_name):
arg_expr = scope_expressions[id(source)]
if isinstance(arg_expr, list) and len(arg_expr) == 1:
arg_expr = arg_expr[0]
if not isinstance(arg_expr, str):
raise ValueError(
f"Attribute access {attribute_name!r} on {source!r} resolved to {arg_expr} but should be a string"
)
if " " in arg_expr:
raise ValueError(
f"f{_codegen_for_node.__name__} got would attempt to create getattr expression "
f"{arg_expr!r}.{attribute_name} due to space in {arg_expr=} "
f"for {id(node)=} {node=} {id(source)=} {source=}"
)
return [f"{arg_expr}.{attribute_name}"]
case cg.ConstantNode(value=value):
return [_repr_value(value)]
case _:
raise TypeError(f"Unsupported {node=}")
def _codegen_graph_inputs(
graph: cg.ComputeGraph,
node_names: dict[int, str],
typename: str | None,
func_name: str | None = None,
) -> list[str]:
args = sorted(
list(graph.inputs.values()),
key=lambda x: x.kwargs.get("default_value", None) is not None,
)
func_name = func_name or graph.name
if logger.isEnabledFor(logging.DEBUG):
argnames = [node_names.get(id(node)) for node in args]
logger.debug(f"Codegen inputs for {func_name} {argnames=}")
if len(args) == 0:
return [f"def {func_name}():"]
args_lines = []
for node in args:
if id(node) not in node_names:
raise ValueError(f"Node {node} has no name in {node_names}")
name = node_names[id(node)]
known_value_type = node.metadata.get("known_value_type", None)
line = (
f"{name}: {_repr_type(known_value_type)}"
if known_value_type is not None
else f"{name}"
)
if (default := node.kwargs.get("default_value")) is not None:
line += f" = {_repr_value(default)}"
args_lines.append(line + ",")
end_statement = "):" if typename is None else f") -> {typename}: "
return [f"def {func_name}("] + indent_lines(args_lines) + [end_statement]
def _codegen_namedtuple_def(outputs: pytree.PyTree):
tupletype = outputs.toplevel_type()
type_lines = []
for name, node in outputs.items():
if node is None:
continue
vt = node.metadata.get("known_value_type", None)
if vt is None:
type_lines.append(f"{name}: Any")
else:
type_lines.append(f"{name}: {_repr_type(vt)}")
return [f"class {tupletype.__name__}(NamedTuple):"] + indent_lines(type_lines)
def _codegen_for_outputs(
graph: cg.ComputeGraph,
scope_expressions: dict[int, str | list[str]],
) -> tuple[str | None, list[str], list[str]]:
if len(graph.outputs) == 0:
return None, [], []
if len(graph.outputs) == 1:
single_output = next(graph.outputs.values())
vt = single_output.metadata.get("known_value_type", None)
type_name = _repr_type(vt) if vt is not None else None
return type_name, [], [f"return {_repr_inp(single_output, scope_expressions)}"]
graph_output_type = graph.outputs.toplevel_type()
type_name = _repr_type(graph_output_type)
is_pf_type = hasattr(pf, graph_output_type.__name__)
if is_pf_type:
type_def = []
elif graph_output_type.__module__ == "builtins":
type_def = []
elif id(graph_output_type) in scope_expressions:
assert scope_expressions[id(graph_output_type)] == type_name
logger.debug(f"Skipping redefinition of {graph_output_type}")
type_def = []
elif pytree.is_type_namedtuple(graph_output_type):
type_def = _codegen_namedtuple_def(graph.outputs)
scope_expressions[id(graph_output_type)] = type_name
else:
raise ValueError(f"Unhandled graph output type: {graph_output_type}")
reprs_tree = graph.outputs.map(lambda node: _repr_inp(node, scope_expressions))
return_lines = [
f"return {pytree.repr_tree_to_str(reprs_tree, type_namer=_repr_type)}"
]
return type_name, type_def, return_lines
def _check_graph_input_names(
graph: cg.ComputeGraph,
scope_names: dict[int, str],
):
input_names = {id(node): name for name, node in graph.inputs.items()}
if len(input_names.values()) != len(set(input_names.values())):
raise ValueError(
f"Input names for {graph.name} had duplicate values. {input_names.values()=}"
)
overlap = set(input_names.values()).intersection(set(scope_names.values()))
for k, v in input_names.items():
if v not in overlap:
continue
newname = v + "_val"
assert newname not in input_names.values()
input_names[k] = newname
logger.warning(
f"Renaming input {k=} of {graph.name=} from {v} to {newname} to avoid "
f"collision, since {v} is also the name of a util function"
)
for orig_name, node in graph.inputs.items():
identifier = input_names[id(node)]
if not identifiers.is_valid_snake_identifier(identifier):
raise ValueError(
f"{graph.name=} had input {orig_name=} {node=} which recieved invalid identifier {identifier=}"
)
return input_names
def _codegen_graph_decorator(graph: cg.ComputeGraph) -> list[str]:
if graph.metadata.get("is_node_function"):
return ["@pf.nodes.node_function"]
return []
def _should_fold_node(
node: cg.Node,
parent: cg.Node | None,
scope_expressions: dict[int, str | list[str]],
usages: dict[int, list[cg.Node]],
fold_map: dict[int, bool],
) -> bool:
if isinstance(node, cg.MethodCallNode) and node.method_name in (
"astype",
"__getitem__",
):
return True
if isinstance(node, cg.GetAttributeNode):
return True
if any(isinstance(u, cg.GetAttributeNode) for u in usages.get(id(node), [])):
return False
if len(usages.get(id(node), [])) > 1:
return False
if parent is None:
return False
if (
isinstance(node, cg.FunctionCallNode)
and "{}" in scope_expressions[id(node.func)]
):
return not fold_map.get(id(parent), False)
return False
def _expression_fold_map(
graph: cg.ComputeGraph,
scope_expressions: dict[int, str | list[str]],
usages: dict[int, list[cg.Node]],
) -> dict[int, bool]:
fold_map: dict[int, bool] = {}
for output in graph.outputs.values():
fold_map[id(output)] = (
output is None
or isinstance(output, cg.ConstantNode)
or isinstance(output, cg.GetAttributeNode)
)
for parent, node in cg.traverse_breadth_first(graph, yield_parent=True):
if id(node) in fold_map:
continue # dont overwrite output settings
should_fold = _should_fold_node(
node, parent, scope_expressions, usages, fold_map
)
fold_map[id(node)] = should_fold
return fold_map
def traverse_chunks(
graph: cg.ComputeGraph,
pred: Callable[[cg.Node, list[cg.Node]], bool],
) -> Generator[list[cg.Node], None, None]:
visited = set()
def _greedy_singleuses(node: cg.Node, chunk: list[cg.Node]):
if id(node) in visited:
return
visited.add(id(node))
yield node
for arg in itertools.chain(node.args, node.kwargs.values()):
if not isinstance(arg, cg.Node):
continue
if id(arg) in visited:
continue
if not pred(arg, chunk):
continue
yield from _greedy_singleuses(arg, chunk)
for node in cg.traverse_breadth_first(graph):
if id(node) in visited:
continue
chunk = []
yield list(_greedy_singleuses(node, chunk))
def _code_paragraphing_predicate(
node: cg.Node,
chunk: list[cg.Node],
scope_expressions: dict[int, str | list[str]],
usages: dict[int, list[cg.Node]],
) -> bool:
if not isinstance(node, cg.FunctionCallNode):
return False
target_expr = scope_expressions[id(node.func)]
if not (".math." not in target_expr or "{}" not in target_expr):
return False
uses = usages[id(node)]
if len(uses) == 1:
return True
elif not all(any(id(u) == id(v) for v in chunk) for u in uses):
return False
return True
def _codegen_for_assignment(
assign_varname: str,
node_code: list[str] | str,
add_line_comments: bool,
) -> list[str]:
assert isinstance(assign_varname, str)
assert identifiers.is_valid_snake_identifier(assign_varname)
if isinstance(node_code, list):
node_code[0] = f"{assign_varname} = " + node_code[0]
else:
node_code = [f"{assign_varname} = {node_code}"]
if add_line_comments:
node_code[0] += f" # {node}" # noqa: F821
return node_code
def _expressions_scope_for_graph(
graph: cg.ComputeGraph,
scope_expressions: dict[int, str | list[str]],
) -> tuple[dict[int, str | list[str]], dict[int, bool]]:
expressions: dict[int, str | list[str]] = {
**scope_expressions.copy(),
**_check_graph_input_names(graph, scope_expressions),
}
# when we want to refer to a value, what string should we insert?
# - for most nodes: refer to a variable name
# - for inlined expressions: emplace a expression string
usages = cg.usages_per_node(graph)
fold_map = _expression_fold_map(graph, expressions, usages=usages)
node_names = identifiers.nodenames_from_fixed_and_infill(
graph,
fold_map=fold_map,
scope_expressions=expressions,
)
if duplicates := identifiers.duplicate_names(node_names):
raise ValueError(f"Duplicate node names: {duplicates}")
if intersection := set(expressions.values()).intersection(set(node_names.values())):
raise ValueError(f"Scope and node names had overlap: {intersection=}")
expressions.update(node_names)
return expressions, fold_map
def _codegen_for_graph(
graph: cg.ComputeGraph,
scope_expressions: dict[int, str],
as_maincall: bool = True,
add_version_comment: bool = True,
add_line_comments: bool = False,
func_name: str | None = None,
) -> list[str]:
code_lines: list[str] = []
if add_version_comment:
code_lines.append(f"# Code generated by procfunc v{pf.__version__}")
expressions, fold_map = _expressions_scope_for_graph(graph, scope_expressions)
_input_ids = set(id(node) for node in graph.inputs.values()) # noqa: F841
last_varname: str = ""
# Collect mutator call nodes so they emit as bare statements (no assignment)
mutator_call_ids = set()
for node in cg.traverse_depth_first(graph):
if isinstance(node, cg.MutatedArgumentNode):
mutator_call_ids.add(id(node.args[1]))
for node in cg.traverse_depth_first(graph):
if isinstance(node, cg.InputPlaceholderNode):
continue # arguments are defined in _codegen_graph_inputs
if isinstance(node, cg.MutatedArgumentNode):
# alias to the original node, since mutation is in-place
original_node = node.args[0]
expressions[id(node)] = expressions[id(original_node)]
continue
node_code = _codegen_for_node(node, expressions.copy())
if fold_map[id(node)]:
assert id(node) not in expressions, f"{node=} {expressions[id(node)]=}"
expressions[id(node)] = node_code
continue
if id(node) in mutator_call_ids:
code_lines.extend(node_code if isinstance(node_code, list) else [node_code])
continue
varname = expressions[id(node)]
node_code = _codegen_for_assignment(varname, node_code, add_line_comments)
code_lines.extend(node_code)
if last_varname.split("_")[0] != varname.split("_")[0]:
code_lines.append("")
last_varname = varname
if as_maincall:
assert len(graph.inputs) == 0, graph.inputs
return ["if __name__ == '__main__':"] + indent_lines(code_lines)
typename, typedef, return_lines = _codegen_for_outputs(graph, expressions)
return (
typedef
+ [""]
+ _codegen_graph_decorator(graph)
+ _codegen_graph_inputs(graph, expressions, typename, func_name=func_name)
+ indent_lines(code_lines)
+ indent_lines(return_lines)
)
def _resolve_func(func: Any) -> tuple[str | None, str]:
"""
Returns:
tuple[str, str]: import string, function callsite string
"""
if isinstance(func, np.ufunc):
return "import numpy as np", f"np.{func.__name__}"
module = getattr(func, "__module__", None)
if module is None:
raise NotImplementedError(f"Unsupported function: {func}")
elif module == "builtins":
return None, func.__name__
elif module.startswith("procfunc."):
callsite = "pf." + module[len("procfunc.") :] + "." + func.__name__
importstring = "import procfunc as pf"
return importstring, callsite
elif module.startswith("infinigen_v2."):
parent, _, mod_name = module.rpartition(".")
importstring = f"from {parent} import {mod_name}"
callsite = f"{mod_name}.{func.__name__}"
return importstring, callsite
else:
callsite = f"{module}.{func.__name__}"
importstring = f"import {module}"
return importstring, callsite
[docs]
def default_func_resolution_map(
toplevel_graph: cg.ComputeGraph,
skip_funcs: set | None = None,
) -> tuple[dict[Any, str | OperatorType], list[str]]:
func_resolution = {}
import_lines = set()
for graph in cg.traverse_nested_graphs(toplevel_graph):
assert isinstance(graph, cg.ComputeGraph), graph
for node in cg.traverse_depth_first(graph):
if not isinstance(node, cg.FunctionCallNode):
continue
if skip_funcs is not None and node.func in skip_funcs:
continue
if node.func in FUNCTIONS_TO_OPERATORS:
func_resolution[node.func] = FUNCTIONS_TO_OPERATORS[node.func]
continue
importstring, callsite = _resolve_func(node.func)
func_resolution[node.func] = callsite
if importstring is not None:
import_lines.add(importstring)
return func_resolution, list(import_lines)
def _topo_sort_subgraphs(graph: cg.ComputeGraph) -> list[cg.ComputeGraph]:
"""DFS post-order traversal: dependencies before dependents."""
visited = set()
result = []
def visit(g: cg.ComputeGraph):
if id(g) in visited:
return
visited.add(id(g))
for node in cg.traverse_depth_first(g):
if isinstance(node, cg.SubgraphCallNode):
visit(node.subgraph)
result.append(g)
visit(graph)
return result
def graphs_to_python_functions(
graph: cg.ComputeGraph,
func_resolution: dict[Any, str],
toplevel_as_maincall: bool = True,
add_version_comment: bool = True,
add_line_comments: bool = False,
) -> OrderedDict[str, list[str]]:
np_linewidth = np.get_printoptions()["linewidth"]
np.set_printoptions(linewidth=100000)
targets = _topo_sort_subgraphs(graph)
def _clean_graph_name(name: str) -> str:
for suffix in identifiers.NONDESCRIPTIVE_NODE_NAME_PARTS:
if name.endswith("_" + suffix):
name = name[: -(len(suffix) + 1)]
return name
for subgraph in cg.traverse_nested_graphs(graph):
subgraph.name = _clean_graph_name(subgraph.name)
subgraph_names = {
id(subgraph): subgraph.name for subgraph in cg.traverse_nested_graphs(graph)
}
subgraph_names = identifiers.dedup_names_with_suffix(subgraph_names, separator="_")
scope_expressions = subgraph_names.copy()
for k, v in func_resolution.items():
if isinstance(v, OperatorType):
scope_expressions[id(k)] = OPERATOR_TEMPLATES[v]
else:
scope_expressions[id(k)] = v
lines_for_modules = []
for subgraph in targets:
func_name = subgraph_names[id(subgraph)]
result = _codegen_for_graph(
subgraph,
scope_expressions=scope_expressions.copy(),
as_maincall=(subgraph is graph and toplevel_as_maincall),
add_version_comment=add_version_comment,
add_line_comments=add_line_comments,
func_name=func_name,
)
lines_for_modules.append((subgraph_names[id(subgraph)], result))
np.set_printoptions(linewidth=np_linewidth)
return OrderedDict(lines_for_modules)
def _define_multiuse_return_types(
graph: cg.ComputeGraph,
func_resolution: dict,
) -> list[str]:
counts, graphs_by_type, seen = defaultdict(int), {}, set()
for subgraph in cg.traverse_nested_graphs(graph):
rettype = subgraph.outputs.toplevel_type()
if (
id(subgraph) in seen
or rettype is None
or hasattr(pf, rettype.__name__)
or not pytree.is_type_namedtuple(rettype)
):
continue
seen.add(id(subgraph))
counts[rettype] += 1
graphs_by_type[rettype] = subgraph
# logger.debug(f"Found {rettype=} for {subgraph.name} {counts[rettype]=}")
multiuse = {
rettype: subgraph
for rettype, subgraph in graphs_by_type.items()
if counts[rettype] > 1
}
lines = []
for rettype, subgraph in multiuse.items():
lines.extend(_codegen_namedtuple_def(subgraph.outputs))
lines.append("")
func_resolution[rettype] = rettype.__name__
return lines
def _collect_graph_value_imports(graph: cg.ComputeGraph) -> list[str]:
import_lines = set()
def _collect_from_value(v):
if isinstance(v, cg.Node):
return
if isinstance(v, enum.Enum):
t = type(v)
if t.__module__ != "builtins":
import_lines.add(f"from {t.__module__} import {t.__name__}")
elif isinstance(v, Path):
import_lines.add("from pathlib import Path")
elif dataclasses.is_dataclass(v) and not isinstance(v, type):
t = type(v)
if t.__module__ != "builtins":
import_lines.add(f"from {t.__module__} import {t.__name__}")
for f in dataclasses.fields(v):
_collect_from_value(getattr(v, f.name))
elif isinstance(v, list):
for item in v:
_collect_from_value(item)
for subgraph in cg.traverse_nested_graphs(graph):
for node in cg.traverse_depth_first(subgraph):
for arg in itertools.chain(node.args, node.kwargs.values()):
_collect_from_value(arg)
return list(import_lines)
[docs]
def to_python(
graph: cg.ComputeGraph,
func_resolution: dict[Any, str | OperatorType] | None = None,
import_lines: list[str] | None = None,
toplevel_as_maincall: bool = True,
add_version_comment: bool = True,
add_line_comments: bool = False,
) -> str:
code_lines = []
code_lines.append("from typing import NamedTuple, Annotated")
code_lines.append("import numpy as np")
code_lines.append("import bpy")
# code_lines.append("import logging; logging.basicConfig(level=logging.DEBUG)")
code_lines.append("from procfunc.nodes import types as t")
code_lines.append("from procfunc.nodes.types import ProcNode, SocketOrVal")
if func_resolution is None:
func_resolution, import_lines = default_func_resolution_map(graph)
else:
assert import_lines is not None
all_imports = set(import_lines) | set(_collect_graph_value_imports(graph))
code_lines.extend(sorted(all_imports))
code_lines.append("")
code_lines.extend(_define_multiuse_return_types(graph, func_resolution))
lines_for_modules = graphs_to_python_functions(
graph,
func_resolution,
add_version_comment=add_version_comment,
add_line_comments=add_line_comments,
toplevel_as_maincall=toplevel_as_maincall,
)
for module_name, module_lines in lines_for_modules.items():
code_lines.extend(module_lines)
code_lines.append("")
return "\n".join(code_lines)