import logging
from collections import defaultdict
from typing import Any, Callable
from procfunc import compute_graph as cg
from procfunc import types as t
from procfunc.nodes import types as nt
from procfunc.nodes.shader import coord, geometry
from procfunc.util import pytree
logger = logging.getLogger(__name__)
[docs]
def remove_v1_name_from_graph(
    _call_node: cg.Node, graph: cg.ComputeGraph
) -> cg.ComputeGraph:
    """Strip the legacy ``nodegroup_`` / ``shader_`` prefixes from ``graph.name``.

    The prefixes are removed in that order, so a name such as
    ``nodegroup_shader_x`` becomes ``x``. Only a *leading* occurrence is
    removed; the same text appearing later in the name is left alone.

    Args:
        _call_node: Unused; accepted so the function matches the
            ``(call_node, graph)`` signature of the other passes in this file.
        graph: Graph whose ``name`` attribute is rewritten in place.

    Returns:
        The same ``graph`` object.
    """
    # str.removeprefix only touches the start of the string, unlike
    # str.replace, which would also delete occurrences in the middle of the name.
    name = graph.name
    stripped = name.removeprefix("nodegroup_").removeprefix("shader_")
    if stripped != name:
        graph.name = stripped
    return graph
[docs]
def eliminate_duplicate_subgraphs(
    graphs: list[cg.ComputeGraph],
) -> list[cg.ComputeGraph]:
    """Make call nodes share one subgraph object per group of equal subgraphs.

    Two subgraphs are treated as duplicates when ``cg.graph_nodes_equal``
    says so. The first one encountered becomes the canonical instance; if a
    later duplicate has a shorter name, the canonical instance adopts it.

    Args:
        graphs: Top-level graphs whose nested subgraphs are deduplicated.
            Call nodes inside them are mutated in place.

    Returns:
        The same ``graphs`` list that was passed in.
    """
    canonical: list[cg.ComputeGraph] = []
    dropped: list[cg.ComputeGraph] = []
    # id(duplicate subgraph) -> canonical subgraph that should replace it
    substitute_for: dict[int, cg.ComputeGraph] = {}

    # Pass 1: decide, for every nested subgraph, whether it is new or a duplicate.
    for root in graphs:
        nested = list(cg.traverse_nested_graphs(root, yield_call_nodes=True))
        # Walk the traversal result back to front.
        for _call_node, candidate in nested[::-1]:
            twin = None
            for known in canonical:
                if cg.graph_nodes_equal(candidate, known):
                    twin = known
                    break

            if twin is None:
                canonical.append(candidate)
                continue

            if len(candidate.name) < len(twin.name):
                twin.name = candidate.name
            substitute_for[id(candidate)] = twin
            # Holding a reference also keeps the id() key above valid.
            dropped.append(candidate)

    # Pass 2: repoint every call node, at any nesting depth, that still
    # references a subgraph recorded as a duplicate.
    for root in graphs:
        for nested_graph in cg.traverse_nested_graphs(root):
            for node in cg.traverse_depth_first(nested_graph):
                if not isinstance(node, cg.SubgraphCallNode):
                    continue
                if id(node.subgraph) in substitute_for:
                    node.subgraph = substitute_for[id(node.subgraph)]

    logger.debug(f"Eliminated duplicated subgraphs {[g.name for g in dropped]}")
    return graphs
[docs]
def eliminate_duplicate_result_types(
    graphs: list[cg.ComputeGraph],
    uses_threshold: int = 1,
) -> list[cg.ComputeGraph]:
    """Share one namedtuple result type among subgraphs with equal field names.

    Nested subgraphs whose top-level output type is a namedtuple are grouped
    by their ``_fields``. In every group with more than ``uses_threshold``
    members, each member after the first has ``outputs.spec.container`` set
    to the first member's top-level output type.

    Args:
        graphs: Top-level graphs to scan; nested subgraphs are mutated in place.
        uses_threshold: Groups with at most this many members are left alone.

    Returns:
        The same ``graphs`` list that was passed in.
    """
    users_by_type: dict[type, list[cg.ComputeGraph]] = defaultdict(list)

    for top in graphs:
        for nested in cg.traverse_nested_graphs(top):
            out_type = nested.outputs.toplevel_type()
            if out_type is None:
                continue
            if not pytree.is_type_namedtuple(out_type):
                continue

            # Reuse an already-seen type with the same field names as the
            # grouping key; otherwise this type starts a new group.
            field_names = list(out_type._fields)
            group_key = next(
                (seen for seen in users_by_type if list(seen._fields) == field_names),
                out_type,
            )
            users_by_type[group_key].append(nested)

    for members in users_by_type.values():
        if len(members) > uses_threshold:
            shared_type = members[0].outputs.toplevel_type()
            for nested in members[1:]:
                nested.outputs.spec.container = shared_type

    return graphs
[docs]
def fill_graph_defaults_with_call_node(
    call_node: cg.SubgraphCallNode,
    graph: cg.ComputeGraph,
) -> cg.ComputeGraph:
    """Copy constant keyword arguments of a call into the graph's input defaults.

    Nothing is changed when ``call_node`` is ``None`` or when any graph input
    already has a non-zero float ``default_value``.

    Args:
        call_node: The call whose ``kwargs`` supply the values, or ``None``.
        graph: Graph whose input nodes receive ``kwargs["default_value"]``.

    Returns:
        The same ``graph`` object.
    """
    if call_node is None:
        return graph

    # Leave the graph alone if it already carries a non-zero float default.
    for existing in graph.inputs.values():
        current_default = existing.default_value
        if isinstance(current_default, float) and current_default != 0.0:
            logger.debug(
                f"Skipping {graph.name} because it has nondefault existing default args"
            )
            return graph

    for name, inpnode in graph.inputs.items():
        supplied = call_node.kwargs.get(name)
        # Missing values and values that are themselves nodes are not copied.
        if supplied is None or isinstance(supplied, cg.Node):
            continue
        inpnode.kwargs["default_value"] = supplied

    return graph
[docs]
def coerce_shaders_to_materialresult(
    _call_node: cg.Node, subgraph: cg.ComputeGraph
) -> cg.ComputeGraph:
    """Rewrap a subgraph's outputs as a ``t.Material`` when that loses nothing.

    The subgraph is returned unchanged when its outputs are already a
    ``t.Material``, when it has neither a ``surface`` nor a ``bsdf`` output,
    or when it has any output that would not be carried into the
    ``t.Material`` (a warning is logged in that last case).

    Args:
        _call_node: Unused; accepted so the function matches the
            ``(call_node, graph)`` signature of the other passes in this file.
        subgraph: Graph whose ``outputs`` may be replaced in place.

    Returns:
        The same ``subgraph`` object.
    """
    if subgraph.outputs.toplevel_type() is t.Material:
        return subgraph

    outputs = subgraph.outputs.dict()

    # "bsdf" is used as the surface output when "surface" is missing or falsy.
    surface_key = "surface" if outputs.get("surface") else "bsdf"
    surface = outputs.get(surface_key)
    if surface is None:
        return subgraph

    shader_outputs = {
        "surface": surface,
        "displacement": outputs.get("displacement"),
        "volume": outputs.get("volume"),
    }

    # Check by key, not by count: comparing len(outputs) against the three
    # Material fields would let e.g. {"surface", "alpha"} through and silently
    # drop "alpha".
    carried_over = {surface_key, "displacement", "volume"}
    if any(key not in carried_over for key in outputs.keys()):
        logger.warning(
            f"{coerce_shaders_to_materialresult.__name__} skipping due to extra outputs: {outputs.keys()}"
        )
        return subgraph

    logger.debug(
        f"{coerce_shaders_to_materialresult.__name__} converted {subgraph.name} output"
    )
    subgraph.outputs = pytree.PyTree(t.Material(**shader_outputs))
    return subgraph
[docs]
def replace_ids(
    graph: cg.ComputeGraph,
    ids: set[int],
    val: Any,
):
    """
    Replace every argument in the graph whose ``id()`` is in ``ids`` with ``val``.

    Args:
        graph: The graph to rewrite; its nodes are mutated in place.
        ids: ``id()`` values of the argument objects to replace.
        val: The object substituted at every matching argument slot.

    Returns:
        The same ``graph`` object.
    """
    assert isinstance(graph, cg.ComputeGraph)

    edges = cg.traverse_depth_first(
        graph, yield_consts=True, yield_name=True, yield_parent=True
    )
    for slot, owner, target in edges:
        if id(target) in ids:
            if isinstance(slot, int):
                # Integer slot: index into the owner's positional args,
                # rebuilt as a fresh tuple with the one entry swapped.
                positional = [*owner.args]
                positional[slot] = val
                owner.args = tuple(positional)
            else:
                # Any other slot is used as a key into the owner's kwargs.
                owner.kwargs[slot] = val

    return graph
def extract_as_input(
    graph: cg.ComputeGraph,
    nodes: set[int],
    name: str,
    arg_type: type,
):
    """
    Add a new graph input called ``name`` and substitute it for the given objects.

    Args:
        graph: The graph to modify; its ``inputs`` are replaced in place.
        nodes: ``id()`` values of the argument objects the new input replaces.
        name: Key of the new input; also stored as the placeholder's
            ``varname`` metadata.
        arg_type: Stored as the placeholder's ``known_value_type`` metadata.

    Returns:
        Whatever ``replace_ids`` returns for the modified graph.
    """
    placeholder_metadata = {"known_value_type": arg_type, "varname": name}
    placeholder = cg.InputPlaceholderNode(
        default_value=None, metadata=placeholder_metadata
    )

    # Only a dict of inputs can take a new named entry.
    current_inputs = graph.inputs.obj()
    assert isinstance(current_inputs, dict), current_inputs
    current_inputs[name] = placeholder
    graph.inputs = pytree.PyTree(current_inputs)

    return replace_ids(graph, nodes, placeholder)