Source code for procfunc.transpiler.bpy_to_computegraph

import inspect
import logging
from collections import namedtuple
from dataclasses import dataclass, field
from typing import (
    Any,
    Literal,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

import bpy
import numpy as np
import pandas as pd

import procfunc as pf
from procfunc import compute_graph as cg
from procfunc import types as t
from procfunc.codegen import identifiers
from procfunc.nodes import NODES_MANIFEST
from procfunc.nodes import types as nt
from procfunc.nodes.execute.util import get_active_sockets, normalize_socket_type
from procfunc.nodes.util import bpy_node_info
from procfunc.nodes.util import bpy_node_info as bni
from procfunc.ops import OPS_MANIFEST
from procfunc.transpiler.parse_default_values import (
    SUBCOMPONENT_TYPES,
    normalize_default_value,
)
from procfunc.transpiler.parse_special_cases import SPECIAL_CASE_NODES
from procfunc.util import bpy_info, log, manifest, pytree

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Manifests re-indexed by their "bpy_name" column so that all candidate rows for
# a Blender node identifier can be fetched with `.loc[bl_idname]`
# (see _find_manifest_func).
_NODES_MANIFEST_INDEXED = NODES_MANIFEST.set_index("bpy_name")
_OPS_MANIFEST_INDEXED = OPS_MANIFEST.set_index("bpy_name")

# Node attributes whose values are used to choose between several manifest rows
# that share the same bpy_name (see _node_to_spec).
MODE_ATTRS = [
    "mode",
    "data_type",
    "operation",
    "rotation_type",
    "feature",
    "distribute_method",
]
# Attributes that are popped from the generated call's kwargs
# (see _remove_banned_attrs).
IGNORE_ATTRS = ["color_mapping", "texture_mapping", "active_item", "capture_items"]


class InvalidNodeGraph(Exception):
    """Exception that carries the bpy nodes involved alongside its message.

    Attributes:
        nodes: the list of nodes handed over by whoever raised the error.
    """

    def __init__(self, message: str, nodes: list[bpy.types.Node]):
        super(InvalidNodeGraph, self).__init__(message)
        self.nodes = nodes


@dataclass
class ParseMemo:
    """Memoization tables shared by the parse_* functions of this module.

    Node-tree-scoped keys use ``node_tree.session_uid`` (an int) rather than
    the tree's name, matching how parse_node, parse_link and parse_node_tree
    build their keys.
    """

    # (node_tree.session_uid, node.name) -> node returned by parse_node.
    nodes: dict[tuple[int, str], cg.Node] = field(default_factory=dict)

    # (node_tree.session_uid, from_node.name, from_socket.identifier) -> node
    # standing for that output socket. Stored by parse_link *before* any
    # per-link ``astype`` wrapper is added, and pre-filled with input
    # placeholders by _create_and_memoize_input_placeholders.
    links: dict[tuple[int, str, str], cg.Node] = field(default_factory=dict)

    # node_tree.session_uid -> the (graph, name -> node) pair that
    # parse_node_tree returns for that tree.
    compute_graphs: dict[int, tuple[cg.ComputeGraph, dict[str, cg.Node]]] = field(
        default_factory=dict
    )

    # NOTE(review): populated outside the code visible here; the annotation
    # suggests keys of (asset type, asset.name) — confirm against the writer.
    assets: dict[tuple[type, str], cg.ComputeGraph] = field(default_factory=dict)


def _find_node_blidname(
    node_tree: bpy.types.NodeTree, bl_idname: str
) -> list[bpy.types.Node]:
    """Collect the nodes of ``node_tree`` whose ``bl_idname`` equals ``bl_idname``.

    The result preserves the iteration order of ``node_tree.nodes`` and is
    empty when nothing matches.
    """
    matching = []
    for candidate in node_tree.nodes:
        if candidate.bl_idname == bl_idname:
            matching.append(candidate)
    return matching


def parse_texture(tex: t.Texture, memo: ParseMemo) -> cg.Node:
    """Placeholder for texture transpilation; it never returns.

    Asserts that ``tex.use_nodes`` is set, then raises ``NotImplementedError``.
    """
    assert tex.use_nodes
    # TODO: the intended implementation would go through
    # parse_node_tree(tex.node_tree, memo).
    raise NotImplementedError("Texture not implemented")


def _target_attrs(node: bpy.types.Node) -> dict[str, Any]:
    """Snapshot the node-specific, non-callable public attributes of ``node``.

    Names starting with ``_`` and names in ``UNIVERSAL_ATTR_NAMES`` are left
    out. Names in ``SPECIAL_CASE_ATTR_NAMES`` are recorded with a ``None``
    placeholder instead of being read from the node.
    """
    special_names = bpy_node_info.SPECIAL_CASE_ATTR_NAMES
    universal_names = bpy_node_info.UNIVERSAL_ATTR_NAMES

    collected: dict[str, Any] = {}
    for name in dir(node):
        if name.startswith("_"):
            continue
        if name in special_names:
            collected[name] = None
            continue
        if name in universal_names:
            continue

        value = getattr(node, name)
        # Bound methods live next to data properties on some nodes
        # (e.g. ShaderNodeTexPointDensity.cache_point_density()); skip them.
        if not callable(value):
            collected[name] = value

    return collected


def _is_empty_enum(node: bpy.types.Node, attr: str) -> bool:
    """Report whether ``attr`` is declared as an ENUM RNA property on ``node``.

    Emptiness itself is not tested here: the caller in this module
    (_keep_attr) combines this with its own ``v == ""`` check.
    """
    rna_prop = node.bl_rna.properties.get(attr)
    if rna_prop is None:
        return False
    return rna_prop.type == "ENUM"


def _bpy_node_defaults(
    node_tree: bpy.types.NodeTree,
    node: bpy.types.Node,
    attr_keys: list[str],
) -> dict[str, Any]:
    """Read Blender's default values for ``attr_keys`` on a node like ``node``.

    A throwaway node with the same ``bl_idname`` is created in ``node_tree``,
    configured with the source node's ``node_tree`` (for node groups),
    ``data_type`` and ``operation``, inspected, and removed again.

    Args:
        node_tree: tree in which the temporary node is created and removed.
        node: source node whose type and mode settings are mirrored.
        attr_keys: attribute names to read; names the temporary node does not
            have are omitted from the result.

    Returns:
        Mapping from attribute name to the temporary node's value.
    """
    temp_default_node = node_tree.nodes.new(node.bl_idname)
    # try/finally so the temporary node never stays behind in the user's tree
    # if mirroring the mode attributes or reading a value raises.
    try:
        if node.bl_idname.endswith("NodeGroup"):
            temp_default_node.node_tree = node.node_tree
        # data_type must precede operation: FunctionNodeCompare.operation enum
        # values (e.g. BRIGHTER/DARKER) are only valid for certain data_types.
        if hasattr(node, "data_type"):
            temp_default_node.data_type = node.data_type
        if hasattr(node, "operation"):
            temp_default_node.operation = node.operation

        attr_defaults = {}
        for k in attr_keys:
            if not hasattr(temp_default_node, k):
                continue
            val = getattr(temp_default_node, k)
            # Copy mathutils types before removing the node to avoid dangling
            # pointer segfaults. ID datablocks (Scene/Object/Material/...) are
            # persistent — copying them would duplicate the asset and break
            # equality comparison against the source node's attr.
            if hasattr(val, "copy") and not isinstance(val, bpy.types.ID):
                val = val.copy()
            attr_defaults[k] = val
    finally:
        node_tree.nodes.remove(temp_default_node)

    return attr_defaults


def _remove_banned_attrs(
    attrs: dict[str, Any],
    blender_attr_vals: dict[str, Any],
):
    """Pop every ``IGNORE_ATTRS`` key out of ``attrs`` (mutated in place).

    A warning is logged when a removed value is not ``None`` and differs from
    the entry of the same name in ``blender_attr_vals``. ``capture_items`` and
    ``active_item`` are removed without that comparison.
    """
    never_compared = ["capture_items", "active_item"]
    for name in IGNORE_ATTRS:
        removed = attrs.pop(name, None)
        if removed is None or name in never_compared:
            continue
        if removed != blender_attr_vals[name]:
            logger.warning(
                f"Ignoring {name}={removed} which had been changed from its default value {blender_attr_vals[name]!r}"
            )


def _parse_getattr(
    res: cg.Node,
    link: bpy.types.NodeLink,
):
    """Wrap ``res`` in a GetAttributeNode that selects the link's output socket.

    The attribute name starts as the source socket's name, is remapped through
    the manifest row's ``output_names_map`` when that is a dict, and is then
    converted with ``identifiers.bpy_name_to_pythonid``.
    """
    upstream = link.from_node
    attribute = link.from_socket.name

    spec, _ = _node_to_spec(upstream.bl_idname, _target_attrs(upstream))
    if spec is not None:
        names_map = spec["output_names_map"]
        if isinstance(names_map, dict):
            attribute = names_map.get(attribute, attribute)

    return cg.GetAttributeNode(
        attribute_name=identifiers.bpy_name_to_pythonid(attribute),
        source=res,
    )


def _create_link_impl_node(
    node_tree: bpy.types.NodeTree,
    link: bpy.types.NodeLink,
    memo: ParseMemo,
) -> cg.Node:
    """Build the node for the source side of ``link`` (no memo lookup here).

    * Group input/output sources are rejected with ``ValueError``; they are
      expected to be present in ``memo.links`` already.
    * Reroutes are followed through their single input link; a reroute with
      no input raises ``ValueError``.
    * ``ShaderNodeSeparateXYZ`` becomes an attribute access on its input
      vector instead of a function call.
    * Any other node goes through parse_node, with an attribute access added
      when the node has more than one active output socket.
    """
    upstream = link.from_node
    upstream_kind = upstream.bl_idname

    if upstream_kind in ["NodeGroupInput", "NodeGroupOutput"]:
        key = (id(link.from_node), link.from_socket.identifier)
        raise ValueError(
            f"{parse_link.__name__} {node_tree.name=} {key} {link.from_node.name} {link.from_socket.name} -> {link.to_node.name} {link.to_socket.name} "
            f"has {link.from_node.bl_idname=}, which should have been avoided from parsing via {len(memo.links)=}"
        )

    if upstream_kind == "NodeReroute":
        incoming = upstream.inputs[0].links
        if len(incoming) == 0:
            raise ValueError(
                f"Node {link.from_node.bl_idname} {link.from_node.name=} in {node_tree.name=} has no inputs"
            )
        # pass through to_socket to avoid potentially having wrong type inference
        passthrough = parse_link(node_tree, incoming[0], memo)
        assert passthrough is not None, link
        return passthrough

    if upstream_kind == "ShaderNodeSeparateXYZ":
        inp_vec = upstream.inputs[0]
        if len(inp_vec.links) != 0:
            # Skip over the SeparateXYZ node itself: the ProcNode getattr()
            # does its job, so no explicit function call is generated.
            source = parse_link(node_tree, inp_vec.links[0], memo)
        else:
            # WARN: creates multiple constants without memoizing
            source = cg.FunctionCallNode(
                func=pf.nodes.func.constant,
                args=(normalize_default_value(inp_vec.default_value, inp_vec.type),),
                kwargs={},
            )
        return cg.GetAttributeNode(
            attribute_name=link.from_socket.name.lower(),
            source=source,
        )

    res = parse_node(node_tree, upstream, memo)
    assert res is not None, link

    outsockets = get_active_sockets(upstream.outputs)
    assert len(outsockets) > 0
    if len(outsockets) > 1:
        # Several outputs: pick the one this link leaves from.
        res = _parse_getattr(res, link)

    assert res is not None, link
    return res


def parse_link(
    node_tree: bpy.types.NodeTree,
    link: bpy.types.NodeLink,
    memo: ParseMemo,
) -> cg.Node:
    """
    Return the cg.Node feeding ``link``, memoized per source socket.

    The node for the source socket is looked up in / stored into
    ``memo.links``. When the two ends of the link have different socket types,
    the result is additionally wrapped in an explicit ``astype`` method call;
    that wrapper is per-link and is not stored in the memo.
    """
    key = (node_tree.session_uid, link.from_node.name, link.from_socket.identifier)

    res = memo.links.get(key)
    if not res:
        res = _create_link_impl_node(node_tree, link, memo)
        memo.links[key] = res

    if link.from_socket.type == link.to_socket.type:
        assert res is not None, link
        return res

    # make all implicit blender typeconversions into explicit .astype calls
    to_socket_type = bpy_node_info.SocketType(
        normalize_socket_type(link.to_socket.bl_idname)
    )
    to_py_type = bpy_node_info.SOCKET_TYPE_TO_PYTHON_TYPE[to_socket_type]
    assert to_py_type is not None, to_socket_type

    logger.debug(
        f"Adding explicit astype({to_py_type}) for {link.from_node.name}.{link.from_socket.name} -> "
        f"{link.to_node.name}.{link.to_socket.name}"
    )

    converted = cg.MethodCallNode(
        callee=res,
        method_name="astype",
        args=(),
        kwargs={"dtype": to_py_type},
    )
    assert converted is not None, link
    return converted


def _create_link_input(
    node_tree: bpy.types.NodeTree,
    socket: bpy.types.NodeSocket,
    memo: ParseMemo,
    is_toplevel: bool,
    func_default: Any | None = None,
) -> cg.Node | list[cg.Node] | None:
    """Produce the argument to pass for one input ``socket``.

    Resolution order:

    1. Several incoming links -> list of parsed links.
    2. One incoming link -> the parsed link.
    3. No ``default_value`` attribute -> ``None``.
    4. Material / Object / Collection defaults -> a subgraph call for the
       material, or a ``t.MeshObject`` / ``t.Collection`` wrapper.
    5. ``hide_value`` sockets -> ``None``.
    6. Otherwise the normalized default value, or ``None`` when it equals the
       (normalized) ``func_default`` of the bound function.

    Note that cases 4 and 6 return values that are not ``cg.Node`` instances,
    despite the return annotation.

    Args:
        node_tree: tree owning the socket's node; forwarded to parse_link.
        socket: the input socket being converted.
        memo: shared parse memo.
        is_toplevel: when False, a warning is logged for defaults whose type
            is in ``SUBCOMPONENT_TYPES``.
        func_default: default of the matching function parameter, if any.
    """
    if len(socket.links) > 1:
        return [parse_link(node_tree, l, memo) for l in socket.links]

    if len(socket.links) == 1:
        return parse_link(node_tree, socket.links[0], memo)

    if not hasattr(socket, "default_value"):
        return None

    if isinstance(socket.default_value, SUBCOMPONENT_TYPES) and not is_toplevel:
        # Log one readable string; the message used to be passed as a
        # (message, [socket]) tuple, which logged the tuple's repr.
        logger.warning(
            f"Transpiler recommends against {type(socket.default_value)} as default_value "
            f"but got {socket.default_value.name=} {is_toplevel=} on socket {socket.name!r}; "
            "please use a typed socket connected from the toplevel nodegroup inputs instead. "
            "This makes your subcomponent user-configurable"
        )

    if isinstance(socket.default_value, bpy.types.Material):
        mat_graph = parse_material(socket.default_value, memo)
        return cg.SubgraphCallNode(subgraph=mat_graph, args=(), kwargs={})

    if isinstance(socket.default_value, bpy.types.Object):
        return t.MeshObject(socket.default_value)

    if isinstance(socket.default_value, bpy.types.Collection):
        return t.Collection(socket.default_value)

    if socket.hide_value:
        # hidden-value sockets are implicit fields (e.g. extrude_mesh Offset =
        # extrude along normal); their stored default is meaningless, so omit it
        # and let the binding default express the implicit behavior
        return None

    res = normalize_default_value(socket.default_value, socket.type)

    if func_default is not None:
        repr_func_default = normalize_default_value(func_default, socket.type)
        # float sockets store single precision, so a python-double default like
        # 0.001 reads back inexactly - compare in float32 space
        if isinstance(res, float) and isinstance(repr_func_default, float):
            if np.float32(res) == np.float32(repr_func_default):
                return None
        elif isinstance(res, np.ndarray) or isinstance(repr_func_default, np.ndarray):
            if np.array_equal(res, repr_func_default):
                return None
        elif res == repr_func_default:
            return None

    return res


def _create_inputs(
    node_tree: bpy.types.NodeTree,
    node: bpy.types.Node,
    memo: ParseMemo,
    func_defaults: dict[str, Any],
    names: dict[str, str] | None = None,
    is_toplevel: bool = False,
) -> dict[str, cg.Node]:
    """Build the keyword arguments contributed by the input sockets of ``node``.

    Args:
        node_tree: tree that owns ``node``.
        node: node whose inputs are converted.
        memo: shared parse memo.
        func_defaults: parameter name -> default of the target callable; used
            both to elide arguments equal to the default and to detect
            parameters that have no default.
        names: socket identifier -> display name. When omitted, every enabled
            socket with a non-empty name is used, de-duplicated with suffixes.
        is_toplevel: forwarded to _create_link_input.

    Returns:
        Python identifier -> argument. A socket yielding nothing is omitted
        when the parameter has a default, and emitted as ``None`` otherwise.
    """
    if names is None:
        enabled_named = {}
        for candidate in node.inputs.values():
            if candidate.enabled and candidate.name != "":
                enabled_named[candidate.identifier] = candidate.name
        names = identifiers.dedup_names_with_suffix(
            enabled_named, first_use_suffix=True
        )

    inputs = {}
    for identifier, display_name in names.items():
        socket = None
        for candidate in node.inputs.values():
            if candidate.identifier == identifier:
                socket = candidate
                break
        assert socket is not None

        name = identifiers.bpy_name_to_pythonid(display_name)
        func_default_kwarg = func_defaults.get(name, None)
        value = _create_link_input(
            node_tree,
            socket,
            memo,
            is_toplevel,
            func_default=func_default_kwarg,
        )

        if value is not None:
            inputs[name] = value
        elif name not in func_defaults:
            # binding declares this input as required (no default), but the
            # source has nothing to supply for it — emit explicit None so
            # the transpiled call still satisfies the signature
            inputs[name] = None
        else:
            logger.debug(
                f"Skipping argument for {name=} {socket.node.bl_idname=} {func_default_kwarg=}"
            )

    return inputs


def _find_manifest_func(
    bpy_name: str,
    mode_vals: dict[str, Any],
    manifest_indexed: pd.DataFrame,
) -> dict | None:
    if bpy_name not in manifest_indexed.index:
        return None

    candidates = manifest_indexed.loc[bpy_name]
    if isinstance(candidates, pd.Series):
        candidates = pd.DataFrame([candidates])

    exploded = candidates["bpy_mode_args"].fillna({}).apply(pd.Series)

    # Each row matches a (mode_attr, val) constraint if it specifies that
    # value, or if the row leaves the attr null (= unconstrained / any).
    # When multiple rows match, prefer the one with the most specific
    # (non-null) overlap with the requested mode_vals — that's the row
    # whose author intended to handle this exact case.
    mask = pd.Series([True] * len(candidates), index=candidates.index)
    specificity = pd.Series([0] * len(candidates), index=candidates.index)
    for mode_attr, val in mode_vals.items():
        if val is None:
            continue
        if mode_attr not in exploded.columns:
            continue
        column = exploded[mode_attr]
        explicit = column == val
        match_mask = explicit | column.isna()
        mask &= match_mask
        specificity = specificity + explicit.astype(int)

    if mask.sum() == 0:
        raise ValueError(
            f"{bpy_name=} had {len(candidates)=}, but filtering for {mode_vals=} eliminated them all"
        )

    if mask.sum() > 1:
        max_spec = specificity[mask].max()
        mask = mask & (specificity == max_spec)

    if mask.sum() > 1:
        options_for_modevals = {
            k: list(exploded[k].unique()) if k in exploded.columns else None
            for k in mode_vals.keys()
        }
        raise ValueError(
            f"Found {mask.sum()} nodes with {bpy_name=} {mode_vals=} in manifest, expected exactly 1. "
            f"Options for {mode_vals.keys()=} are {options_for_modevals}"
        )

    matches = candidates.loc[mask]

    if len(matches) == 1:
        return matches.iloc[0].to_dict()
    else:
        return None


def _node_to_spec(
    bl_idname: str,
    attrs: dict[str, Any],
) -> tuple[dict[str, Any] | None, dict]:
    """Resolve the nodes-manifest row for a node type and its attribute values.

    Returns:
        ``(row, attrs)`` where ``row`` is the manifest row as a dict, or
        ``None`` when ``bl_idname`` denotes a node-group type. ``attrs`` is
        handed back unchanged.

    Raises:
        ValueError: if the manifest has no row for the node, or the row's
            ``name`` is ``"LATER"`` or ``"DECLINE"``.
    """
    if bpy_info.NodeGroupType.from_str(bl_idname) is not None:
        return None, attrs

    # Only the mode-like attributes take part in selecting the row.
    mode_attr_vals = {}
    for mode_key in MODE_ATTRS:
        if mode_key in attrs:
            mode_attr_vals[mode_key] = attrs[mode_key]

    func_row = _find_manifest_func(bl_idname, mode_attr_vals, _NODES_MANIFEST_INDEXED)

    if func_row is None:
        raise ValueError(f"Node {bl_idname} {mode_attr_vals=} had no manifest row")
    if func_row["name"] in ["LATER", "DECLINE"]:
        raise ValueError(f"Node {bl_idname} {mode_attr_vals=} had {func_row['name']=}")

    return func_row, attrs


def _map_inputs_with_arg_map(
    inputs: dict[str, Any],
    arg_names_map: dict[str, str | None],
) -> dict[str, Any]:
    """Rename input keys per arg_names_map. A target value of None drops
    the input entirely (used when the bpy node exposes a socket the binding
    intentionally ignores, e.g. CombineColor's Alpha for combine_rgb)."""
    mapped_inputs = {}
    for k, v in inputs.items():
        if k in arg_names_map:
            new_k = arg_names_map[k]
            if new_k is None:
                continue
            k = new_k
        mapped_inputs[k] = v

    return mapped_inputs


def _keep_attr(
    node: bpy.types.Node,
    k: str,
    v: Any,
    param: str | None,
    func_defaults: dict[str, Any],
    attr_defaults: dict[str, Any],
) -> bool:
    """Decide whether attribute ``k`` (value ``v``) must be emitted explicitly.

    The value is compared with the procfunc default when ``param`` names a
    parameter that has one (those intentionally diverge from bpy's), and with
    the bpy default from ``attr_defaults`` otherwise. Exceptions: ``data_type``
    on ``GeometryNodeInputNamedAttribute`` is always kept, and an ENUM
    attribute whose value is ``""`` is always dropped.
    """
    if k == "data_type" and node.bl_idname == "GeometryNodeInputNamedAttribute":
        return True

    if v == "" and _is_empty_enum(node, k):
        # state-gated enum with no valid member: nothing to set
        return False

    if param not in func_defaults:
        return v != attr_defaults[k]

    procfunc_default = func_defaults[param]
    if v == procfunc_default and v != attr_defaults.get(k, v):
        logger.debug(
            f"Stripping attr {k!r} ({param!r}) on {node.bl_idname}: source "
            f"value {v!r} equals procfunc default but differs from bpy "
            f"default {attr_defaults[k]!r}"
        )
    return v != procfunc_default


def parse_standard_node(
    node_tree: bpy.types.NodeTree,
    node: bpy.types.Node,
    memo: ParseMemo,
) -> cg.Node:
    """Translate a non-group bpy node into a call of its bound procfunc function.

    Steps, in order:

    1. Resolve the manifest row for the node (_node_to_spec) and import the
       function it names.
    2. Keep only the node attributes that _keep_attr reports as differing from
       the relevant default, then rename them through ``arg_names_map``.
    3. Drop the mode attributes that selected the manifest row unless the
       function takes them as parameters; drop ``data_type`` except on
       ``GeometryNodeInputNamedAttribute``.
    4. Build the socket-derived inputs, rename them, and validate their names.
    5. Create the ``cg.FunctionCallNode``, let a ``SPECIAL_CASE_NODES`` handler
       post-process it, and strip the ``IGNORE_ATTRS``.
    6. Verify that no keyword is unknown to the function's signature.

    Returns:
        The function-call node, or whatever the special-case handler returned.

    Raises:
        ValueError: for a ``NodeGroupInput`` node, an input name that is not a
            valid identifier, a name used by both an attribute and an input,
            or keywords that the function signature does not accept.
    """
    # note: read/write result into memo happens at parse_node level, not here

    if node.bl_idname == "NodeGroupInput":
        raise ValueError(
            f"NodeGroupInput {node} {id(node)=} should have been pre-memo'd in parse_node_tree"
        )

    attrs = _target_attrs(node)
    # NOTE(review): _node_to_spec returns a None spec for node-group types; this
    # function subscripts the spec, so it relies on callers routing groups elsewhere.
    func_spec, attrs = _node_to_spec(node.bl_idname, attrs)

    # Manifest names use the "pf." alias; expand it to the importable package path.
    func = manifest.import_item_iterative(func_spec["name"].replace("pf.", "procfunc."))
    func_sig = inspect.signature(func)
    arg_names_map = func_spec.get("arg_names_map")

    # Parameters of the bound function that declare a default value.
    func_defaults = {
        param.name: param.default
        for param in func_sig.parameters.values()
        if param.default is not param.empty
    }

    # Filter on the original bpy attribute names; the matching parameter name
    # (after arg_names_map) is passed along for the procfunc-default lookup.
    attr_defaults = _bpy_node_defaults(node_tree, node, list(attrs.keys()))
    attrs = {
        k: v
        for k, v in attrs.items()
        if _keep_attr(
            node, k, v, (arg_names_map or {}).get(k, k), func_defaults, attr_defaults
        )
    }

    # Rename surviving attributes; a mapping to None drops the attribute.
    if arg_names_map is not None:
        attrs = {
            arg_names_map.get(k, k): v
            for k, v in attrs.items()
            if arg_names_map.get(k, k) is not None
        }

    # we only want to remove MODE_ATTRS which were actually used to resolve the function
    #   (since presumably the restriction implied by these is already enforced by the new function signature)
    resolve_mode_args = func_spec.get("bpy_mode_args")
    if resolve_mode_args is not None:
        for k, v in resolve_mode_args.items():
            if k in attrs and k not in func_sig.parameters.keys():
                attrs.pop(k)

    # we will assume the data_types in an input .blend can always be inferred.
    #   it is the job of the .astype() insertion to preserve enough info for this
    if "data_type" in attrs and node.bl_idname != "GeometryNodeInputNamedAttribute":
        attrs.pop("data_type")

    # normalize the data_type spelling to the canonical NodeDataType
    if "data_type" in attrs:
        attrs["data_type"] = bpy_node_info.datatype_from_bpy_str(attrs["data_type"])

    inputs = _create_inputs(node_tree, node, memo, func_defaults=func_defaults)
    arg_names_map = func_spec["arg_names_map"]
    if arg_names_map is not None:
        inputs = _map_inputs_with_arg_map(inputs, arg_names_map)

    # Validate keys after arg_names_map has had a chance to rename socket-derived
    # names that aren't valid Python identifiers (e.g. IndexSwitch's '0', '1').
    for name in inputs.keys():
        if not identifiers.is_valid_snake_identifier(name):
            raise ValueError(
                f"Input name {name!r} is not a valid identifier. "
                f"{node.bl_idname=}, {node.inputs.keys()=}"
            )

    # An attribute and a socket input must never target the same keyword.
    if overlap := set(attrs.keys()).intersection(set(inputs.keys())):
        raise ValueError(
            f"Node {node.bl_idname} had keys {overlap=} between {attrs.keys()=} and {inputs.keys()=}, which is invalid"
        )

    # SPECIAL_CASE_ATTR_NAMES enter attrs as None placeholders. Handlers that
    # need them read from the bpy node directly, so the placeholder is never
    # useful to the function call — drop before constructing the cg_node.
    placeholder_attrs = {**attrs, **inputs}
    for k in bpy_node_info.SPECIAL_CASE_ATTR_NAMES:
        # `...` as the .get() fallback distinguishes "absent" from "present and None".
        if placeholder_attrs.get(k, ...) is None:
            placeholder_attrs.pop(k, None)

    cg_node = cg.FunctionCallNode(
        func=func,
        args=(),
        kwargs=placeholder_attrs,
    )

    # Keep a handle on the pre-handler node: the signature check below inspects
    # its kwargs rather than those of whatever the handler returns.
    cg_node_orig = cg_node
    if handler := SPECIAL_CASE_NODES.get(node.bl_idname):
        cg_node = handler(node, cg_node)

    # Mutates cg_node.kwargs in place, removing the IGNORE_ATTRS entries.
    _remove_banned_attrs(cg_node.kwargs, attr_defaults)

    signature = inspect.signature(func)

    # Check if the function accepts **kwargs (VAR_KEYWORD)
    has_var_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values()
    )

    # Keywords unknown to the signature are only an error when the function has
    # no **kwargs catch-all to absorb them.
    excess_kwargs = set(cg_node_orig.kwargs.keys()) - set(signature.parameters.keys())
    if not has_var_keyword and excess_kwargs:
        node_mode = getattr(node, "mode", None)
        node_operation = getattr(node, "operation", None)
        node_data_type = getattr(node, "data_type", None)
        raise ValueError(
            f"Codegen would attempt to call {func.__name__=} with {excess_kwargs} "
            f"but these attributes do not exist in the procfunc signature, which had {list(signature.parameters.keys())} "
            f"source node had {node.bl_idname} {node.inputs.keys()=} {node_mode=} {node_operation=} {node_data_type=} "
            "Please contact the developers."
        )

    return cg_node


def parse_nodegroup_call(
    node_tree: bpy.types.NodeTree,
    node: bpy.types.Node,
    memo: ParseMemo,
) -> cg.Node:
    """Translate a node-group instance into a call of the group's compute graph.

    The group's own tree is parsed via parse_node_tree, and the defaults of
    the resulting graph inputs let arguments equal to their default be left
    out of the call.
    """
    assert hasattr(node, "node_tree"), f"Node {node.bl_idname} has no node_tree"

    # identifier -> display name for every enabled input socket, then
    # qualified with panel names by the identifiers helper.
    enabled_names = {}
    for socket in node.inputs.values():
        if socket.enabled:
            enabled_names[socket.identifier] = socket.name
    input_names = identifiers.apply_panel_names_to_input_names(
        node.node_tree, enabled_names, only_dedup=False
    )

    with log.add_exception_context_msg(f"While processing {node.name=}:"):
        graph, _ = parse_node_tree(node.node_tree, memo)

    func_defaults = {}
    for input_name, placeholder in graph.inputs.items():
        func_defaults[input_name] = placeholder.kwargs.get("default_value", None)

    call_kwargs = _create_inputs(
        node_tree,
        node,
        memo,
        func_defaults=func_defaults,
        is_toplevel=False,
        names=input_names,
    )

    return cg.SubgraphCallNode(subgraph=graph, args=(), kwargs=call_kwargs)


def _parse_constant_node(node: bpy.types.Node) -> cg.Node:
    """Wrap the value held by a constant-style node in a ``cg.ConstantNode``.

    ``bpy_node_info.CONSTANT_NODES`` names the attribute holding the value;
    the marker ``"DEFAULT_VALUE"`` means the value is the default of the
    node's first output socket instead.
    """
    attr_name = bpy_node_info.CONSTANT_NODES[node.bl_idname]
    val = (
        node.outputs[0].default_value
        if attr_name == "DEFAULT_VALUE"
        else getattr(node, attr_name)
    )
    return cg.ConstantNode(value=val)


def parse_node(
    node_tree: bpy.types.NodeTree,
    node: bpy.types.Node,
    memo: ParseMemo,
) -> cg.Node:
    """Parse ``node`` once per ``(node_tree.session_uid, node.name)`` and memoize it.

    Node-group instances go to parse_nodegroup_call and everything else to
    parse_standard_node. A non-empty node label is recorded as the
    ``varname`` metadata entry of the result.

    Raises:
        ValueError: if ``node`` is a ``NodeReroute``.
    """
    memo_key = (node_tree.session_uid, node.name)
    cached = memo.nodes.get(memo_key)
    if cached:
        return cached

    if bpy_info.NodeGroupType.from_str(node.bl_idname) is not None:
        parser = parse_nodegroup_call
    elif node.bl_idname == "NodeReroute":
        raise ValueError(
            f"Node {node.bl_idname} {node.name=} is a NodeReroute, which should have been folded away"
        )
    else:
        parser = parse_standard_node
    res = parser(node_tree, node, memo)

    if node.label != "":
        res.metadata["varname"] = identifiers.bpy_name_to_pythonid(node.label)

    memo.nodes[memo_key] = res
    return res


def _find_output_node(
    node_tree: bpy.types.NodeTree,
):
    """Return the single output node of ``node_tree``.

    Candidates are ``NodeGroupOutput`` nodes plus nodes of the main output
    type registered for the tree's type in ``bpy_info``.

    Raises:
        ValueError: if there is no candidate, or more than one.
    """
    node_tree_type = bpy_info.NodeTreeType(node_tree.bl_idname)
    ng_type = bpy_info.NODETREE_TO_NODEGROUP[node_tree_type]
    main_output_node_type = bpy_info.NODETREE_TYPE_TO_MAIN_OUTPUT[node_tree_type]

    output_nodes = _find_node_blidname(
        node_tree, "NodeGroupOutput"
    ) + _find_node_blidname(node_tree, main_output_node_type)

    if len(output_nodes) > 1:
        raise ValueError(f"Found mutltiple {output_nodes=} for {node_tree=}")

    if len(output_nodes) == 0:
        # List the node types that are present to make the failure actionable.
        idnames = {present.bl_idname for present in node_tree.nodes}
        raise ValueError(
            f"No {main_output_node_type=} found "
            f"for {node_tree.bl_idname} of type {ng_type} with {idnames}"
        )

    (only_output,) = output_nodes
    return only_output


def _name_from_socket_and_panels(
    socket: bpy.types.NodeSocket,
    panels: list[bpy.types.NodeSocket],
) -> str:
    """Derive the Python identifier for ``socket``, prefixed by its panel if any.

    The first panel whose ``interface_items`` contains an item with the
    socket's identifier contributes ``"<panel name>_"`` as a prefix; without
    such a panel only the socket name is used.
    """
    wanted = socket.identifier
    for panel in panels:
        if any(item.identifier == wanted for item in panel.interface_items.values()):
            return identifiers.bpy_name_to_pythonid(panel.name + "_" + socket.name)
    return identifiers.bpy_name_to_pythonid(socket.name)


def _infer_geometry_type(node: cg.Node, _depth: int = 0) -> type | None:
    """Infer the concrete type wrapped by the ProcNode that ``node`` produces.

    For a function call the function's return annotation is inspected; when it
    is not a concrete ``ProcNode[...]``, the call's node arguments (including
    nodes inside list arguments) are searched depth-first. For an attribute
    access, the matching field annotation of the source function's return
    type is used. Gives up with ``None`` beyond a recursion depth of 10, when
    type hints cannot be resolved, or for any other kind of node.
    """
    if _depth > 10:
        return None

    if isinstance(node, cg.FunctionCallNode):
        try:
            hints = get_type_hints(node.func)
        except Exception:
            return None
        return_type = hints.get("return")
        if return_type is None:
            return None

        concrete = _extract_procnode_inner_type(return_type)
        if concrete is not None:
            return concrete

        # Return type is generic (TypeVar) — look through the arguments.
        for operand in [*node.args, *node.kwargs.values()]:
            children = operand if isinstance(operand, list) else [operand]
            for child in children:
                if not isinstance(child, cg.Node):
                    continue
                found = _infer_geometry_type(child, _depth + 1)
                if found is not None:
                    return found
        return None

    if isinstance(node, cg.GetAttributeNode):
        source = node.args[0]
        if not isinstance(source, cg.FunctionCallNode):
            return None
        try:
            hints = get_type_hints(source.func)
        except Exception:
            return None
        return_type = hints.get("return")
        if return_type is None or not hasattr(return_type, "__annotations__"):
            return None
        field_type = return_type.__annotations__.get(node.attribute_name)
        if field_type is None:
            return None
        return _extract_procnode_inner_type(field_type)

    return None


def _extract_procnode_inner_type(t: type) -> type | None:
    """Return ``X`` for an annotation of the form ``ProcNode[X]``.

    Yields ``None`` when ``t`` is not a parametrized ``nt.ProcNode``, has no
    type arguments, or its first argument is an unbound ``TypeVar``.
    """
    if get_origin(t) is not nt.ProcNode:
        return None
    type_args = get_args(t)
    if not type_args or isinstance(type_args[0], TypeVar):
        return None
    return type_args[0]


def _socket_to_pf_type(
    socket: bpy.types.NodeSocket,
    is_output: bool,
    interface: Any | None = None,  # unsure type
    use_socket_bounds: bool = False,
    use_specialized_sockets: bool = False,
) -> type:
    """
    Map a blender socket to the procfunc type used to annotate it.

    Args:
        socket: The blender socket to describe
        is_output: Selects ``nt.ProcNode[...]`` (True) over
            ``nt.SocketOrVal[...]`` (False) as the wrapper
        interface: The nodegroup interface of the socket, if one exists
        use_socket_bounds: Whether to use the bounds of the socket's interface.
            Disabled by default since these are often accidentally / imprecisely filled by implementers
        use_specialized_sockets: Whether to derive bounds from specialized
            socket idnames (only consulted when interface bounds are not used)

    Returns:
        ``nt.ProcNode`` when the socket type has no python type, otherwise the
        wrapper selected by ``is_output`` parametrized with that python type

    Raises:
        NotImplementedError: if any bound was determined, since range
            annotations are not supported
    """
    type_str = normalize_socket_type(socket.bl_idname)
    st = bpy_node_info.SocketType(type_str)
    py_type = bpy_node_info.SOCKET_TYPE_TO_PYTHON_TYPE[st]

    lower = None
    upper = None

    if use_socket_bounds and interface is not None:
        # Bounds at or beyond +-1000 are treated as "no bound".
        if hasattr(interface, "min_value") and interface.min_value > -1000:
            lower = interface.min_value
        if hasattr(interface, "max_value") and interface.max_value < 1000:
            upper = interface.max_value
    elif use_specialized_sockets and type_str != socket.bl_idname:
        specialized = socket.bl_idname
        if specialized == "NodeSocketFloatFactor":
            lower, upper = 0.0, 1.0
        elif specialized == "NodeSocketVectorEuler":
            lower = (0.0, 0.0, 0.0)
            upper = (3.141592, 3.141592, 3.141592)
        else:
            logger.info(
                f"No implemented bounds annot for special socket {socket.bl_idname}"
            )

    if lower is not None or upper is not None:
        raise NotImplementedError("Range annotation current not supported")

    if py_type is None:
        return nt.ProcNode
    wrapper = nt.ProcNode if is_output else nt.SocketOrVal
    return wrapper[py_type]


def _placeholder_for_graph_input(
    socket: bpy.types.NodeSocket,
    varname: str,
    node_tree: bpy.types.NodeTree,
) -> cg.Node:
    """Create the ``cg.InputPlaceholderNode`` standing for one graph input.

    Side effect: if the interface item found under ``socket.name`` has
    ``hide_value`` set, it is reset to ``False`` (with a warning).

    The default comes from the interface item, falling back to the socket,
    and is normalized for the socket type.

    Raises:
        ValueError: if a float or vector socket ends up without a default.
    """
    interface = node_tree.interface.items_tree[socket.name]
    if interface.hide_value:
        logger.warning(
            f"{node_tree.name=} {socket.name=} has hide_value=True, "
            "overwriting it to False or else current transpiler implementationl will break"
        )
        interface.hide_value = False

    inner_type = _socket_to_pf_type(socket, is_output=False, interface=interface)

    raw_default = getattr(interface, "default_value", None)
    if raw_default is None:
        raw_default = getattr(socket, "default_value", None)
    default_value = normalize_default_value(raw_default, socket.type)

    norm_soc = normalize_socket_type(socket.bl_idname)
    requires_default = norm_soc in ["NodeSocketFloat", "NodeSocketVector"]
    if default_value is None and requires_default:
        raise ValueError(f"{socket.name=} has no default_value and is a {norm_soc=}")

    placeholder = cg.InputPlaceholderNode(
        name=varname,
        default_value=default_value,
        metadata={"known_value_type": inner_type, "varname": varname},
    )
    if default_value is not None:
        # Also exposed through kwargs, where parse_nodegroup_call reads it.
        placeholder.kwargs["default_value"] = default_value

    return placeholder


def _create_and_memoize_input_placeholders(
    node_tree: bpy.types.NodeTree,
    input_nodes: list[bpy.types.Node],
    memo: ParseMemo,
) -> tuple[dict[str, cg.Node], dict[str, cg.Node]]:
    """
    Create one placeholder per distinct group-input socket and pre-register it
    in memo.links for every active output socket of every input node.

    Pre-filling the memo means later link parsing never recurses into the input
    nodes themselves; their outputs are represented by function arguments.

    Returns (placeholders keyed by varname, placeholders keyed by socket
    identifier). Raises ValueError if two different socket identifiers map to
    the same varname.
    """

    interface_panels = []
    for item in node_tree.interface.items_tree.values():
        if item.item_type == "PANEL":
            interface_panels.append(item)

    placeholders: dict[str, cg.Node] = {}
    by_identifier: dict[str, cg.Node] = {}

    for input_node in input_nodes:
        for socket in get_active_sockets(input_node.outputs):
            key = (node_tree.session_uid, input_node.name, socket.identifier)

            # Several input nodes can expose the same identifier: build the
            # placeholder only the first time it is seen, then share it.
            if socket.identifier not in by_identifier:
                varname = _name_from_socket_and_panels(socket, interface_panels)
                fresh_node = _placeholder_for_graph_input(socket, varname, node_tree)

                if varname in placeholders:
                    raise ValueError(
                        f"Duplicate varname {varname} for {key=} in {input_node.outputs.keys()=} with existing {placeholders.keys()=}"
                    )

                by_identifier[socket.identifier] = fresh_node
                placeholders[varname] = fresh_node

            memo.links[key] = by_identifier[socket.identifier]

    return placeholders, by_identifier


def parse_node_tree(
    node_tree: bpy.types.NodeTree,
    memo: ParseMemo,
) -> tuple[cg.ComputeGraph, dict[str, cg.Node]]:
    """
    Convert a Blender node tree into a ComputeGraph, memoized on the tree's
    session_uid.

    Returns the graph together with the mapping from group-input socket
    identifier to the placeholder node created for it.

    Note: recursive over nodes and node_trees which is not ideal. TODO convert to stack breadth-first

    Raises ValueError if the derived graph name is not a valid snake-case
    identifier, or if a group-output socket has more than one incoming link.
    """
    assert node_tree.name != "Shader Nodetree", (
        "nodetree name must be nondefault as we use it as a hash key"
    )

    memo_key = node_tree.session_uid
    if res := memo.compute_graphs.get(memo_key):
        return res

    # Strip the decoration only where it was detected: as a prefix / suffix.
    # str.replace would also delete the same text from the middle of a name.
    cg_name = node_tree.name.removeprefix("nodegroup_")
    # " (no gc)" comes from v1 to_nodegroup singleton=True
    cg_name = cg_name.removesuffix(" (no gc)")

    cg_name = identifiers.bpy_name_to_pythonid(cg_name)

    # Move a leading all-digit part to the end of the name (presumably so the
    # resulting identifier does not start with a digit).
    name_parts = cg_name.split("_")
    if name_parts[0].isdigit():
        cg_name = "_".join(name_parts[1:]) + "_" + name_parts[0]

    if not identifiers.is_valid_snake_identifier(cg_name):
        raise ValueError(f"Invalid cg_name {cg_name} for {node_tree.name}")

    input_nodes = _find_node_blidname(node_tree, "NodeGroupInput")
    output_node = _find_output_node(node_tree)

    inputs, id_to_node = _create_and_memoize_input_placeholders(
        node_tree, input_nodes, memo
    )

    outputs = {}
    for output_name, output_result_socket in output_node.inputs.items():
        if output_name == "":
            continue  # nodegroups seem to have an empty socket with identifier __extend__, skip it

        if len(output_result_socket.links) == 0:
            continue
        if len(output_result_socket.links) > 1:
            raise ValueError(
                f"Node {node_tree.bl_idname} has multiple inputs for {output_name=} {output_result_socket.identifier=} "
                "Multi-output sockets not supported on nodegroup. please contact the developers."
            )

        output_name = identifiers.bpy_name_to_pythonid(output_name)
        proc_node = parse_link(node_tree, output_result_socket.links[0], memo)

        # Only fill in a value type when the parsed node does not carry one yet.
        if proc_node.metadata.get("known_value_type") is None:
            inferred = _infer_geometry_type(proc_node)
            if inferred is not None:
                vt = nt.ProcNode[inferred]
                logger.debug(
                    f"Inferred known_value_type={vt} for {proc_node=} from function signature"
                )
            else:
                vt = _socket_to_pf_type(
                    output_result_socket,
                    is_output=True,
                )
                logger.debug(
                    f"Setting known_value_type={vt} for {proc_node=} for {output_result_socket=}"
                )
            proc_node.metadata["known_value_type"] = vt

        # A group input wired straight to a group output loses its default.
        if isinstance(proc_node, cg.InputPlaceholderNode):
            proc_node.default_value = None

        outputs[output_name] = proc_node

    if len(outputs) == 0:
        # Sink-style trees (e.g. compositor with only Composite/Viewer, no
        # NodeGroupOutput links) produce a graph with no outputs. Use an empty
        # dict so PyTree flattens to len==0 (vs PyTree(None) which has len==1).
        output = {}
    elif len(outputs) > 1:
        assert " " not in cg_name, f"cg_name {cg_name} contains spaces"
        # remove .001 suffixes
        typename = identifiers.snake_to_pascal(cg_name).rsplit(".", 1)[0] + "Result"
        output_type = namedtuple(typename, outputs.keys())
        output = output_type(**outputs)
    else:
        # Exactly one output: expose the node itself rather than a container.
        output = list(outputs.values())[0]

    compute_graph = cg.ComputeGraph(
        inputs=pytree.PyTree(inputs),
        outputs=pytree.PyTree(output),
        name=cg_name,
        metadata={
            "is_node_function": True,  # causes codegen to apply decorator
        },
    )

    logger.debug(
        f"Parsed node_tree {cg_name} with {len(inputs.keys())} inputs, {len(outputs.keys())} outputs "
        f"and {len(list(cg.traverse_depth_first(compute_graph)))} nodes"
    )

    memo.compute_graphs[memo_key] = (compute_graph, id_to_node)
    return compute_graph, id_to_node


def _parse_geomod_input( mod: bpy.types.Modifier, name: str, node_curr: cg.Node, memo: ParseMemo, ) -> cg.Node | None: socket = mod.node_group.interface.items_tree[name] if socket.socket_type == bni.SocketType.GEOMETRY.value: return node_curr value = mod[socket.identifier] if isinstance(value, bpy.types.Material): mat_graph = parse_material(value, memo) return cg.SubgraphCallNode(subgraph=mat_graph, args=(), kwargs={}) if isinstance(value, bpy.types.Object): return t.MeshObject(value) if isinstance(value, bpy.types.Collection): return t.Collection(value) datatype = bni.SOCKET_CLASS_TO_DATATYPE[socket.socket_type] dtype = bni.DATATYPE_TO_SOCKET_DTYPE[datatype].value return normalize_default_value(value, dtype) def parse_geo_modifier( obj: bpy.types.Object, node_curr: cg.Node, mod: bpy.types.Modifier, memo: ParseMemo, ) -> cg.Node: # TODO: need to find and memoize the input geometries and input attribute assignments. logger.info(f"Parsing geometry node modifier {mod.node_group.name}") graph, id_to_node = parse_node_tree(mod.node_group, memo) inputs = {} for name, soc in mod.node_group.interface.items_tree.items(): if soc.in_out != "INPUT": continue if soc.identifier not in id_to_node: # Interface declares this input but nothing inside the nodetree reads it # (e.g. NodeGroupInput's Geometry output is unlinked). Skip — the subgraph # has no placeholder for it, so passing a kwarg would be a type error. 
logger.debug( f"Skipping unused modifier input {soc.name=} {soc.identifier=}" ) continue parsed_name = id_to_node[soc.identifier].metadata.get("varname", None) assert parsed_name is not None, id_to_node[soc.identifier] inputs[parsed_name] = _parse_geomod_input(mod, name, node_curr, memo) node_curr = cg.SubgraphCallNode(subgraph=graph, args=(), kwargs=inputs) # ignore unlinked group output sockets: they carry no data output_nodes = [n for n in mod.node_group.nodes if n.bl_idname == "NodeGroupOutput"] active_output = next( (n for n in output_nodes if n.is_active_output), output_nodes[0] ) connected_output_ids = {s.identifier for s in active_output.inputs if s.links} output_socs = [ soc for soc in mod.node_group.interface.items_tree.values() if soc.in_out == "OUTPUT" and soc.identifier in connected_output_ids ] geo_output_keys = [ soc.name.lower() for soc in output_socs if soc.socket_type == bni.SocketType.GEOMETRY.value ] attribute_output_keys = [ soc.name.lower() for soc in output_socs if soc.socket_type != bni.SocketType.GEOMETRY.value ] geo_output_getattrs = { k: cg.GetAttributeNode(source=node_curr, attribute_name=k) for k in geo_output_keys } attribute_output_getattrs = { k: cg.GetAttributeNode(source=node_curr, attribute_name=k) for k in attribute_output_keys } match len(geo_output_keys), len(attribute_output_keys): case 1, 0: return cg.FunctionCallNode( pf.nodes.to_mesh_object, args=(node_curr,), kwargs={} ) case 1, _: return cg.FunctionCallNode( pf.nodes.to_mesh_object_with_attributes, kwargs={**geo_output_getattrs, **attribute_output_getattrs}, ) case _, _: return cg.FunctionCallNode( pf.nodes.to_objects_multi, args=(geo_output_getattrs, attribute_output_getattrs), kwargs={}, )
def parse_modifier(
    obj: bpy.types.Object,
    node_curr: cg.Node,
    mod: bpy.types.Modifier,
    memo: ParseMemo,
) -> cg.Node:
    """
    Parse one modifier of obj into a node applied on top of node_curr.

    Geometry-nodes modifiers (type "NODES") are delegated to parse_geo_modifier.
    Any other modifier is looked up in the ops manifest by type/operation and
    becomes a function call whose keyword arguments are the signature
    parameters that also exist as attributes on the modifier.

    Raises ValueError if the modifier has no manifest entry.
    """
    if mod.type == "NODES":
        return parse_geo_modifier(obj, node_curr, mod, memo)

    mode_vals = {"type": mod.type, "operation": getattr(mod, "operation", None)}
    func_row = _find_manifest_func(
        "bpy.ops.object.modifier_add", mode_vals, _OPS_MANIFEST_INDEXED
    )
    if func_row is None:
        raise ValueError(f"Modifier {mod.type} {mode_vals=} not found in manifest")

    # Manifest names use the "pf." alias; expand it to the importable package.
    qualified_name = func_row["name"].replace("pf.", "procfunc.")
    func = manifest.import_item_iterative(qualified_name)

    inputs = {}
    for param in inspect.signature(func).parameters:
        if param == "mutates_obj" or not hasattr(mod, param):
            continue
        inputs[param] = getattr(mod, param)

    mutator_call = cg.FunctionCallNode(func=func, args=(node_curr,), kwargs=inputs)
    return cg.MutatedArgumentNode(
        original_node=node_curr, mutator_call_node=mutator_call
    )


def _replace_vector_inpnodes_as_arg(
    node_tree: bpy.types.NodeTree,
    memo: ParseMemo,
) -> cg.Node:
    """
    Substitute a single "vector" input placeholder for every
    ShaderNodeTexCoord / ShaderNodeNewGeometry node in node_tree.

    The placeholder is registered in memo.nodes for each such node and in
    memo.links for each link leaving one of their outputs. Returns the
    placeholder.
    """
    coord_nodes = _find_node_blidname(node_tree, "ShaderNodeTexCoord")
    coord_nodes += _find_node_blidname(node_tree, "ShaderNodeNewGeometry")

    placeholder = cg.InputPlaceholderNode(
        name="vector",
        default_value=None,
        metadata={"known_value_type": pf.ProcNode[pf.Vector], "varname": "vector"},
    )

    tree_uid = node_tree.session_uid
    for coord_node in coord_nodes:
        memo.nodes[(tree_uid, coord_node.name)] = placeholder

        for out_socket in coord_node.outputs.values():
            for link in out_socket.links:
                link_key = (
                    tree_uid,
                    link.from_node.name,
                    link.from_socket.identifier,
                )
                memo.links[link_key] = placeholder

    return placeholder


_MATERIAL_OUTPUT_SOCKETS = ["Surface", "Displacement", "Volume"]


def parse_material(
    mat: bpy.types.Material, memo: ParseMemo, coord_inp_as_arg: bool = False
) -> cg.ComputeGraph:
    """
    Parse a material's node tree into a ComputeGraph, memoized in memo.assets
    under (type(mat), mat.name).

    The graph outputs a t.Material built from the Surface, Displacement and
    Volume inputs of the material output node; unlinked inputs become a
    ConstantNode holding None. With coord_inp_as_arg=True the texture-coordinate
    / geometry input nodes are replaced by a "vector" graph input.
    """
    memo_key = (type(mat), mat.name)
    cached = memo.assets.get(memo_key)
    if cached:
        return cached

    node_tree = mat.node_tree

    inputs_dict = {}
    if coord_inp_as_arg:
        inputs_dict["vector"] = _replace_vector_inpnodes_as_arg(node_tree, memo)

    # Exactly one material output node is expected; unpacking fails otherwise.
    [output_node] = _find_node_blidname(node_tree, "ShaderNodeOutputMaterial")

    outputs_dict = {}
    for key in _MATERIAL_OUTPUT_SOCKETS:
        expect_type = pf.Vector if key == "Displacement" else pf.Shader
        out_socket = output_node.inputs[key]

        if not out_socket.is_linked:
            res = cg.ConstantNode(value=None)
            res.metadata["known_value_type"] = Union[pf.ProcNode[expect_type], None]
        else:
            res = parse_link(node_tree, out_socket.links[0], memo)
            res.metadata["known_value_type"] = pf.ProcNode[expect_type]

        outputs_dict[key.lower()] = res

    graph = cg.ComputeGraph(
        inputs=pytree.PyTree(inputs_dict),
        outputs=pytree.PyTree(t.Material(**outputs_dict)),
        name=identifiers.bpy_name_to_pythonid(mat.name),
        metadata={},  # TODO
    )

    memo.assets[memo_key] = graph
    return graph


def parse_primitive(obj: t.Object) -> cg.Node:
    """
    Return a call node for pf.ops.primitives.mesh_monkey.

    obj is currently not consulted.
    """
    monkey_call = cg.FunctionCallNode(
        pf.ops.primitives.mesh_monkey,
        args=(),
        kwargs={},
    )
    return monkey_call


[docs] def parse_object( obj: bpy.types.Object, memo: ParseMemo, object_mode: Literal["monkey", "active", "named"] = "monkey", include_set_material: bool = True, ) -> cg.ComputeGraph: memo_key = (type(obj), obj.name) if obj_node := memo.assets.get(memo_key): return obj_node # TODO assert starting from single vertex? match object_mode: case "monkey": node_curr = cg.FunctionCallNode( pf.ops.primitives.mesh_monkey, args=(), kwargs={} ) case "active": node_curr = cg.ConstantNode( value=cg.LiteralConstant("pf.MeshObject(bpy.context.active_object)") ) case "named": node_curr = cg.ConstantNode( value=cg.LiteralConstant( f"pf.MeshObject(bpy.data.objects[{obj.name!r}])" ) ) case _: raise ValueError(f"Invalid object mode: {object_mode}") coord = cg.FunctionCallNode(pf.nodes.shader.coord, args=(), kwargs={}) coord = cg.GetAttributeNode(source=coord, attribute_name="generated") if include_set_material: for mat in obj.material_slots: mat_graph = parse_material(mat.material, memo) mat_kwargs = ( {"vector": coord} if "vector" in mat_graph.inputs.obj().keys() else {} ) mat_call = cg.SubgraphCallNode( subgraph=mat_graph, args=(), kwargs=mat_kwargs ) node_curr = cg.FunctionCallNode( pf.ops.object.set_material, args=(node_curr,), kwargs={"material": mat_call}, ) for mod in obj.modifiers: node_curr = parse_modifier(obj, node_curr, mod, memo) name = "object_" + identifiers.bpy_name_to_pythonid(obj.name) + "_generate" graph = cg.ComputeGraph( inputs=pytree.PyTree({}), outputs=pytree.PyTree({"result": node_curr}), name=name, metadata={"func": parse_object, "object": obj.name}, ) memo.assets[memo_key] = graph return graph
def parse_scene(
    scene: bpy.types.Scene,
    memo: ParseMemo,
) -> cg.Node:
    """
    Placeholder for scene-level parsing; always raises NotImplementedError.
    """
    message = "Scene not implemented"
    raise NotImplementedError(message)