Source code for procfunc.tracer.trace

"""
torch.fx / jax-like function-compute-graph tracing tool, but specially designed for procedural generation functions


This tool was heavily inspired by torch.fx.symbolic_trace https://docs.pytorch.org/docs/2.6/fx.html
"""

import builtins
import inspect
import logging
import math
import random
from types import ModuleType
from typing import Any, Callable

import numpy as np
import numpy.linalg

import procfunc as pf
from procfunc import compute_graph as cg
from procfunc.util import pytree

from .patch import (
    Patcher,
    PatchFunctionTarget,
    TraceLevel,
)
from .proxy import RngProxy

logger = logging.getLogger(__name__)

# Registry filled by autowrap_module(). Each entry is a
# (module, allow_exec, trace_level) tuple; trace() passes the whole list to
# Patcher as `autopatch_wrap_modules`.
_autowrap_modules: list[tuple[ModuleType, bool, TraceLevel]] = []


def autowrap_module(
    module: ModuleType,
    allow_exec: bool = False,
    trace_level: TraceLevel = TraceLevel.PRIMITIVES,
):
    """
    Register a module in the autowrap registry.

    The arguments are stored together as one tuple at the end of the
    module-level `_autowrap_modules` list, which `trace` later passes to the
    Patcher as `autopatch_wrap_modules`.

    Args:
        module: The module to register.
        allow_exec: Stored with the module and forwarded unchanged.
            NOTE(review): presumably tells the patcher whether the module's
            functions may be executed during tracing - confirm in `.patch`.
        trace_level: Trace level stored with the module.
    """
    registration = (module, allow_exec, trace_level)
    _autowrap_modules.append(registration)
# Numeric modules registered for autowrapping by default, all with
# allow_exec=True and the default trace level of autowrap_module().
for _default_module in (math, np, numpy.linalg):
    autowrap_module(_default_module, allow_exec=True)
del _default_module  # do not leave the loop variable in the module namespace

# Modules passed to Patcher as `autopatch_remove_modules` by trace().
# NOTE(review): presumably the global-state RNG modules that traced functions
# must not use - confirm against the Patcher implementation.
_banned_modules: list[ModuleType] = [np.random, random]
def add_banned_module(module: ModuleType):
    """
    Add `module` to the end of the module-level `_banned_modules` list.

    `trace` passes that list to the Patcher as `autopatch_remove_modules`.
    """
    _banned_modules.extend([module])
# Explicit patch targets registered through add_wrap_target(); trace() passes
# this list to Patcher as `patch_functions`.
_patch_function_targets: list[PatchFunctionTarget] = []
def add_wrap_target(target: PatchFunctionTarget):
    """
    Register an explicit function patch target.

    `target` is added to the end of the module-level `_patch_function_targets`
    list, which `trace` passes to the Patcher as `patch_functions`.
    """
    _patch_function_targets.extend([target])
# Builtin functions registered as wrap targets; each is looked up by name in
# the namespace of the `builtins` module.
WRAP_BUILTINS = ["min", "max", "abs", "round", "sum"]

for _builtin_name in WRAP_BUILTINS:
    add_wrap_target(
        PatchFunctionTarget(
            name=_builtin_name,
            frame=builtins.__dict__,
            source_name="builtins",
            trace_level=TraceLevel.PRIMITIVES,
            allow_exec=True,
            normalize=False,
        )
    )

# Constructors registered as wrap targets; each is looked up by name in the
# top-level `procfunc` namespace and labelled with source_name "mathutils".
WRAP_CONSTRUCTORS = ["Vector", "Color", "Euler"]

for _name in WRAP_CONSTRUCTORS:
    add_wrap_target(
        PatchFunctionTarget(
            name=_name,
            frame=pf.__dict__,
            source_name="mathutils",
            trace_level=TraceLevel.PRIMITIVES,
            allow_exec=True,
            normalize=False,
        )
    )

# Namespaces (module __dict__ objects) registered through add_search_scope();
# trace() passes this list to Patcher as `search_scopes`.
_search_scopes: list[dict] = []


def add_search_scope(module: ModuleType):
    """
    Register an intermediate module whose namespace should be searched for
    references to targets that are already being patched.

    Example: the original to_mesh_object is already a target, but
    procfunc.nodes holds its own reference to it. Calling add_search_scope on
    the `nodes` module lets that reference be wrapped as well.
    """
    namespace = vars(module)
    _search_scopes.append(namespace)


def _map_args(
    func: Callable,
    **inputs: Any,
) -> dict[str, cg.Proxy | Any]:
    """
    Build the keyword arguments used to call `func` while tracing.

    Every parameter in the signature of `func` gets exactly one entry:

    * a value supplied in `inputs` is used as-is, except that a
      `np.random.Generator` (either bare, or held as the `value` of a
      `cg.ConstantNode`) is wrapped in an `RngProxy` created with dirty=False;
    * otherwise the parameter's declared default is used, if it has one;
    * otherwise a `cg.Proxy` around a fresh `cg.InputPlaceholderNode` named
      after the parameter is used.

    Keys of `inputs` that do not match any parameter name are ignored.

    Returns:
        Mapping of parameter name to the value to pass for it.
    """
    mapped: dict[str, cg.Proxy | Any] = {}

    for param_name, param in inspect.signature(func).parameters.items():
        if param_name in inputs:
            supplied = inputs[param_name]
            if isinstance(supplied, np.random.Generator):
                # Bare generator: give it a constant node of its own.
                rng_node = cg.ConstantNode(value=supplied)
                mapped[param_name] = RngProxy(rng_node, supplied, dirty=False)
            elif isinstance(supplied, cg.ConstantNode) and isinstance(
                supplied.value, np.random.Generator
            ):
                # Generator already held by a constant node: reuse that node.
                mapped[param_name] = RngProxy(supplied, supplied.value, dirty=False)
            else:
                mapped[param_name] = supplied
        elif param.default is not param.empty:
            mapped[param_name] = param.default
        else:
            placeholder = cg.InputPlaceholderNode(
                name=param_name,
                default_value=None,
                metadata={"varname": param_name},
            )
            mapped[param_name] = cg.Proxy(placeholder)

    return mapped
def trace(
    func: Callable,
    trace_level: TraceLevel = TraceLevel.GENERATORS,
    name: str | None = None,
    **inputs: Any,
):
    """
    Turn a python function into a graph datastructure.

    Using this datastructure is (usually) equivalent to executing the function.

    Args:
        func: The function to trace
        trace_level: Granularity of the graph. Functions at this level become
            leaves; finer functions are traced through. choice() peeks through
            all options when trace_level >= RANDOM_CONTROL, or resolves to the
            chosen branch when finer.
        name: Name of the resulting graph. Defaults to ``func.__name__``.
        **inputs: Values bound to the parameters of ``func`` by name.
            Parameters that are neither supplied here nor have a default are
            replaced by placeholder proxies and become the graph's inputs.

    Returns:
        A ``cg.ComputeGraph`` whose inputs are the placeholder nodes created
        for unbound parameters and whose outputs are the nodes taken from the
        return value of ``func`` (non-proxy return values become constant
        nodes).

    Raises:
        TypeError: If ``name`` is omitted and ``func`` has no string
            ``__name__`` to fall back on.
        RuntimeError: If another trace is already in progress.
    """
    logger.debug(f"Tracing {func} {id(func)=} with {inputs=} {trace_level=}")

    if name is None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        name = getattr(func, "__name__", None)
        if not isinstance(name, str):
            raise TypeError(
                f"Can't infer a graph name for {func!r}: it has no string "
                "__name__. Pass name=... explicitly."
            )

    proxy_args = _map_args(func, **inputs)

    if pf.context.globals.current_trace_level is not None:
        # TODO we can lift this restriction fairly(?) easily by having a global patcher & saving/restoring this state
        # rather than setting to false at end of function. see fx.trace for an example
        raise RuntimeError(
            f"Can't trace {name}, tracing is already in progress for another function. \n"
            "Nested tracing is not yet supported - contact the developers to request"
        )

    pf.context.globals.current_trace_level = trace_level.value
    patcher = None
    try:
        # The patcher is built inside the try so that a failure in its
        # constructor still resets the global trace level below; otherwise
        # every later trace() would report "already in progress".
        patcher = Patcher(
            trace_level=trace_level,
            autopatch_wrap_modules=_autowrap_modules,
            autopatch_remove_modules=_banned_modules,
            search_scopes=_search_scopes,
            patch_functions=_patch_function_targets,
        )
        func = patcher.apply_preexecute_patches(func, trace_level)
        # Fall back to `name` so a callable without __name__ can be traced
        # when the caller supplied a name explicitly.
        executing_name = getattr(func, "__name__", name)
        logger.debug(f"Executing {executing_name} {id(func)=}")
        func_result = func(**proxy_args)
    finally:
        pf.context.globals.current_trace_level = None
        if patcher is not None:
            patcher.unpatch_all()

    # NOTE(review): `func` is by now the object returned by
    # apply_preexecute_patches, not necessarily the callable passed in -
    # confirm that this is the one meant to be recorded.
    metadata = {
        "operations": [
            (trace, {"func": func, "trace_level": trace_level}),
        ],
    }

    def extract_node(v):
        # Proxies contribute their graph node; anything else is a constant.
        if isinstance(v, cg.Proxy):
            return v.node
        return cg.ConstantNode(value=v)

    outputs = pytree.PyTree(func_result)
    outputs = outputs.map(extract_node)

    # Only placeholder-backed proxies are graph inputs; caller-supplied values
    # and defaults are not.
    input_nodes = {
        k: v.node
        for k, v in proxy_args.items()
        if isinstance(v, cg.Proxy) and isinstance(v.node, cg.InputPlaceholderNode)
    }

    compgraph = cg.ComputeGraph(
        inputs=pytree.PyTree(input_nodes),
        outputs=outputs,
        name=name,
        metadata=metadata,
    )

    return compgraph