Source code for procfunc.transforms.extract_materials
"""
Transform to extract material SubgraphCallNodes from geometry node functions,
making them pure by moving material calls to the caller.
"""
import logging
import procfunc as pf
from procfunc import compute_graph as cg
from procfunc.nodes import types as nt
from procfunc.util import pytree
# One logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
def _is_material_subgraph(subgraph: cg.ComputeGraph) -> bool:
    """Tell whether ``subgraph`` produces a material.

    Returns True exactly when the top-level type of ``subgraph.outputs`` is
    ``pf.Material``; the comparison is by identity, so subclasses do not match.
    """
    output_type = subgraph.outputs.toplevel_type()
    return output_type is pf.Material
def _sanitize_name(name: str) -> str:
    """Derive the graph-input name used for the material called ``name``.

    The result is ``"material_"`` followed by the lower-cased ``name`` in which
    every character that may not continue a Python identifier (``.``, spaces,
    ``-``, brackets, ...) is replaced by ``"_"``. Letters, digits, underscores and
    other identifier characters (including non-ASCII letters) are kept, so the
    result is always a valid identifier. It is also recorded as the input's
    ``varname`` metadata, which presumably must be one.

    Distinct names may map to the same result (e.g. ``"Mat.A"`` and ``"Mat A"``).
    """
    lowered = name.lower()
    # ``"_" + ch`` is an identifier exactly when ``ch`` may continue one.
    safe = "".join(ch if f"_{ch}".isidentifier() else "_" for ch in lowered)
    return f"material_{safe}"
def _add_input(graph: cg.ComputeGraph, name: str) -> cg.InputPlaceholderNode:
    """Create a material-typed input placeholder and register it on ``graph``.

    The placeholder has no default value. Its metadata records
    ``nt.ProcNode[pf.Material]`` as the known value type and ``name`` as the
    variable name. It is stored under ``name`` in the graph's inputs, replacing
    any entry already under that key, and ``graph.inputs`` is then rebuilt as a
    fresh ``pytree.PyTree``.

    Returns:
        The newly created placeholder node.
    """
    placeholder_meta = {
        "known_value_type": nt.ProcNode[pf.Material],
        "varname": name,
    }
    placeholder = cg.InputPlaceholderNode(
        name=name, default_value=None, metadata=placeholder_meta
    )
    updated_inputs = graph.inputs.obj()
    updated_inputs[name] = placeholder
    graph.inputs = pytree.PyTree(updated_inputs)
    return placeholder
def _build_parent_map(
    top_graph: cg.ComputeGraph,
) -> dict[int, tuple[cg.ComputeGraph, cg.SubgraphCallNode]]:
    """Map ``id(subgraph)`` to a ``(calling graph, call node)`` pair.

    Every graph yielded by ``cg.traverse_nested_graphs(top_graph, ...)`` is
    scanned with ``cg.traverse_depth_first``, and each ``SubgraphCallNode`` found
    is recorded under the id of the subgraph it calls. When a subgraph is called
    from several places, only the last call site encountered is kept.
    """
    return {
        id(call.subgraph): (owner, call)
        for _, owner in cg.traverse_nested_graphs(top_graph, yield_call_nodes=True)
        for call in cg.traverse_depth_first(owner)
        if isinstance(call, cg.SubgraphCallNode)
    }
def _replace_node_in_graph(
    graph: cg.ComputeGraph,
    old_node: cg.Node,
    new_node: cg.Node,
) -> None:
    """Point every reference to ``old_node`` inside ``graph`` at ``new_node``.

    Visits each node yielded by ``cg.traverse_depth_first(graph)`` and rewrites,
    in place, the entries of its positional ``args`` and keyword ``kwargs`` that
    are ``old_node``. Matching is by identity (``is``) throughout.

    NOTE(review): only direct ``args``/``kwargs`` entries are rewritten. A
    reference nested inside a container argument, or held directly by
    ``graph.outputs``, is left untouched; confirm that cannot happen here.
    """
    for node in cg.traverse_depth_first(graph):
        # Detect a hit by identity. Comparing a rebuilt tuple with ``!=`` would
        # fall through to the nodes' own ``__eq__``, which is not the question
        # "is this the node being replaced".
        if any(arg is old_node for arg in node.args):
            node.args = tuple(
                new_node if arg is old_node else arg for arg in node.args
            )
        # Snapshot the items so updating values cannot disturb the iteration.
        for key, val in list(node.kwargs.items()):
            if val is old_node:
                node.kwargs[key] = new_node
[docs]
def extract_materials_from_graph(
    top_graph: cg.ComputeGraph,
) -> dict[str, cg.SubgraphCallNode]:
    """Hoist material subgraph calls out of node-function graphs.

    Every graph nested under ``top_graph`` whose metadata sets
    ``is_node_function`` is scanned for ``SubgraphCallNode``s whose subgraph
    outputs a ``pf.Material``. Each such call is replaced inside that graph by an
    input placeholder named after the material. That input is then threaded up
    through *every* caller of the graph:

    - node-function callers get (or reuse) a same-named input of their own and
      are threaded further up;
    - the first non-node-function caller on each path receives a
      ``MethodCallNode`` invoking ``item`` on the material call.

    Inputs are keyed by the sanitized material name, so calls whose materials
    sanitize to the same name share a single input per graph. The last such call
    is the one wired in at non-node-function callers and reported in the result.

    Args:
        top_graph: Root graph; it and its nested graphs are modified in place.

    Returns:
        Mapping from generated input name to the extracted material call node.
    """
    callers = _collect_callers(top_graph)
    extracted_materials: dict[str, cg.SubgraphCallNode] = {}

    # Snapshot the graph list up front: the loop rewires nodes and call sites,
    # and the caller map above was also built from the unmodified structure.
    for graph in list(cg.traverse_nested_graphs(top_graph)):
        if not _is_node_function(graph):
            continue

        material_calls = [
            node
            for node in cg.traverse_depth_first(graph)
            if isinstance(node, cg.SubgraphCallNode)
            and _is_material_subgraph(node.subgraph)
        ]

        for mat_call in material_calls:
            input_name = _sanitize_name(mat_call.subgraph.name)
            # Reuse a same-named input if one exists: replacing it would orphan
            # nodes that were already rewired to the earlier placeholder.
            inp = _get_or_add_input(graph, input_name)
            _replace_node_in_graph(graph, mat_call, inp)
            _thread_input_to_callers(graph, input_name, mat_call, callers)
            extracted_materials[input_name] = mat_call
            logger.debug(
                "Extracted material '%s' from %s", mat_call.subgraph.name, graph.name
            )

    return extracted_materials


def _is_node_function(graph: cg.ComputeGraph) -> bool:
    """Return whether ``graph``'s metadata marks it as a node function."""
    return bool(graph.metadata.get("is_node_function", False))


def _collect_callers(
    top_graph: cg.ComputeGraph,
) -> dict[int, list[tuple[cg.ComputeGraph, cg.SubgraphCallNode]]]:
    """Map ``id(subgraph)`` to every ``(calling graph, call node)`` pair."""
    sites: dict[int, dict[int, tuple[cg.ComputeGraph, cg.SubgraphCallNode]]] = {}
    for _, graph in cg.traverse_nested_graphs(top_graph, yield_call_nodes=True):
        for node in cg.traverse_depth_first(graph):
            if isinstance(node, cg.SubgraphCallNode):
                # Keyed by id(node) so a graph yielded more than once does not
                # register the same call site twice.
                sites.setdefault(id(node.subgraph), {})[id(node)] = (graph, node)
    return {key: list(by_node.values()) for key, by_node in sites.items()}


def _get_or_add_input(graph: cg.ComputeGraph, name: str) -> cg.InputPlaceholderNode:
    """Return ``graph``'s input placeholder ``name``, creating it if missing."""
    # NOTE(review): assumes ``graph.inputs.obj()`` is a dict (``_add_input``
    # item-assigns into it).
    existing = graph.inputs.obj().get(name)
    if isinstance(existing, cg.InputPlaceholderNode):
        return existing
    return _add_input(graph, name)


def _thread_input_to_callers(
    graph: cg.ComputeGraph,
    input_name: str,
    mat_call: cg.SubgraphCallNode,
    callers: dict[int, list[tuple[cg.ComputeGraph, cg.SubgraphCallNode]]],
) -> None:
    """Supply ``input_name`` to ``graph`` from every call chain above it.

    Node-function callers pass down a same-named input of their own (created or
    reused) and are processed in turn. Non-node-function callers pass
    ``mat_call.item()`` directly. A visited set guards against cycles.
    """
    pending = [graph]
    visited: set[int] = set()
    while pending:
        current = pending.pop()
        if id(current) in visited:
            continue
        visited.add(id(current))
        for parent_graph, call_node in callers.get(id(current), ()):
            if not _is_node_function(parent_graph):
                # The material call itself moves to this caller; the last call
                # with this name wins, matching the key in the returned mapping.
                call_node.kwargs[input_name] = cg.MethodCallNode(
                    mat_call, "item", args=(), kwargs={}
                )
                continue
            if input_name not in call_node.kwargs:
                call_node.kwargs[input_name] = _get_or_add_input(
                    parent_graph, input_name
                )
            pending.append(parent_graph)
[docs]
def extract_materials_from_graphs(
    graphs: list[cg.ComputeGraph],
) -> list[cg.ComputeGraph]:
    """Apply ``extract_materials_from_graph`` to each graph in ``graphs``.

    The graphs are modified in place. For every graph that had materials
    extracted, an INFO record lists the generated input names.

    Returns:
        The same ``graphs`` list object that was passed in.
    """
    for each_graph in graphs:
        extracted = extract_materials_from_graph(each_graph)
        if not extracted:
            continue
        logger.info(
            f"Extracted materials from {each_graph.name}: {list(extracted.keys())}"
        )
    return graphs