Source code for procfunc.tracer.patch

import enum
import functools
import logging
from dataclasses import dataclass
from types import ModuleType
from typing import Any, Callable, TypeAlias, TypeVar

from procfunc import compute_graph as cg
from procfunc.tracer.proxy import RngProxy
from procfunc.util import pytree

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name of the boolean attribute that is set on wrapper callables so later passes
# can recognise them as already patched and skip them.
PATCHING_FLAG_ATTR = "_gen_tracing_is_patched"
# Method names on a Module-like class. Within this file `_MODULE_CALLFUNC` is only
# referenced by commented-out patching code and `_MODULE_GENFUNC` is not referenced
# at all. NOTE(review): presumably used by other modules; confirm.
_MODULE_CALLFUNC = "__call__"
_MODULE_GENFUNC = "_generate"

# Generic placeholder type used in annotations of the wrapper factories.
Tfunc = TypeVar("Tfunc")

# Signature of a custom wrapper factory: it is called with the target being wrapped
# and the active Patcher, and returns the callable to use as that target's wrapper.
TWrapperCreate: TypeAlias = Callable[["PatchFunctionTarget", "Patcher"], Callable]


class TraceLevel(enum.IntEnum):
    """
    Granularity levels for tracing. Higher = coarser. Lower = finer.

    The resulting graph will contain nodes which are this level.
    Being an IntEnum, levels can be compared directly with ``<`` / ``<=``.
    """

    GRAMMAR = 100  # _distribution functions
    RANDOM_CONTROL = 60  # pf.control.choice
    RANDOM_PARAMS = 50  # np.random calls, all other operations of numbers.
    GENERATORS = 40
    NODEGROUPS = 30  # @node_function
    PRIMITIVES = 20
@dataclass
class PatchFunctionTarget:
    """Describes one callable, stored at ``frame[name]``, that the Patcher treats as a wrap/ban target."""

    # Namespace dict that holds the callable (e.g. a module's ``__dict__``).
    frame: dict
    # Key under which the callable is stored in ``frame``.
    name: str
    # Compared to the user's requested trace level to decide whether this target will
    # be a leaf: it becomes a leaf proxy when ``trace_level <= Patcher.trace_level``.
    trace_level: TraceLevel
    # When True, the leaf wrapper converts positional args to kwargs
    # (via ``cg.normalize_args_to_kwargs``) before recording the call.
    normalize: bool = True
    # When True, a call that involves no Proxy arguments executes the real function
    # instead of being recorded in the graph.
    allow_exec: bool = False
    # Optional factory that builds the wrapper instead of the built-in leaf/non-leaf wrappers.
    custom_trace_wrapper_create: TWrapperCreate | None = None
    source_name: str | None = None  # used for logging/debugging
    mutates: list[str] | None = None  # list of argument names that the function mutates
@dataclass
class Patch:
    """
    One reversible replacement of the entry ``fn_name`` of ``frame``.

    Subclasses decide how the replacement is installed and removed.
    """

    frame: dict[str, Any]
    fn_name: str
    orig_fn: Callable
    patched_fn: Callable

    def patch(self):
        """Install ``patched_fn``; must be provided by a subclass."""
        raise NotImplementedError()

    def unpatch(self):
        """Restore ``orig_fn``; must be provided by a subclass."""
        raise NotImplementedError()


@dataclass
class PatchSetItem(Patch):
    """Patch installed through item assignment (``frame[fn_name] = ...``)."""

    def patch(self):
        key, replacement = self.fn_name, self.patched_fn
        self.frame[key] = replacement

    def unpatch(self):
        key, original = self.fn_name, self.orig_fn
        self.frame[key] = original


@dataclass
class PatchSetAttr(Patch):
    """Patch installed through attribute assignment (``setattr(frame, fn_name, ...)``)."""

    def patch(self):
        message = f"Patching {self.fn_name} in {id(self.frame)} from {self.orig_fn} to {self.patched_fn}"
        logger.debug(message)
        setattr(self.frame, self.fn_name, self.patched_fn)

    def unpatch(self):
        setattr(self.frame, self.fn_name, self.orig_fn)


def _targets_from_module(
    module: ModuleType,
    seen: set[int],
    trace_level: TraceLevel,
    normalize: bool = False,
    allow_exec: bool = False,
) -> list[PatchFunctionTarget]:
    """
    Collect wrap targets for every eligible callable of a module such as `math`,
    `numpy` or `blendfunc`, where we want to wrap everything the module contains.

    Such functions are used as primitives - their internals are not traced.
    They may be executed directly if they are called with non-dynamic args.

    Args:
        module: the module to gather targets from
        seen: ids of callables that were already processed and must be skipped;
            the id of every newly gathered callable is added to this set
        trace_level: trace level assigned to every gathered target
        normalize: whether to try to convert args to kwargs where available.
            fails for some modules such as numpy
        allow_exec: whether to allow the function to be executed if it is called
            with non-dynamic args

    Returns:
        a list of PatchFunctionTargets, one for each valid callable in the module.
        Names starting with _, classes, non-callables, already-seen objects and
        names missing from __all__ (when the module defines it) are ignored.
    """
    namespace = module.__dict__
    public_names = getattr(module, "__all__", None)

    def _is_eligible(attr_name: str, obj: Any) -> bool:
        # Private names, classes and plain data are never targets.
        if attr_name.startswith("_"):
            return False
        if issubclass(type(obj), type) or not callable(obj):
            return False
        # Each callable object is gathered at most once across modules.
        if id(obj) in seen:
            return False
        return public_names is None or attr_name in public_names

    targets = []
    for attr_name, obj in namespace.items():
        if not _is_eligible(attr_name, obj):
            continue
        seen.add(id(obj))
        targets.append(
            PatchFunctionTarget(
                frame=namespace,
                name=attr_name,
                source_name=module.__name__,
                normalize=normalize,
                allow_exec=allow_exec,
                trace_level=trace_level,
            )
        )

    # logger.debug(
    #     f"Gathered {len(results)} targets from {module.__name__} with {normalize=} {allow_exec=}"
    # )
    return targets
[docs] class Patcher: """ A 'patch' is an invasive modification of another module (e.g numpy, math) which makes their functions behave differently for tracing This class creates patches and records them so that we can undo them later """
[docs] def __init__( self, trace_level: TraceLevel, autopatch_wrap_modules: list[tuple[ModuleType, bool, TraceLevel]] | None = None, autopatch_remove_modules: list[ModuleType] | None = None, patch_functions: list[PatchFunctionTarget] | None = None, search_scopes: list[dict] | None = None, ): self.trace_level = trace_level if autopatch_wrap_modules is None: autopatch_wrap_modules = [] if autopatch_remove_modules is None: autopatch_remove_modules = [] if patch_functions is None: patch_functions = [] if search_scopes is None: search_scopes = [] self.patched: list[Patch] = [] self.visited_frames: set[int] = set() modules_seen = set() autowrap_targets = [] for m, allow_exec, mod_trace_level in autopatch_wrap_modules: autowrap_targets += _targets_from_module( m, modules_seen, trace_level=mod_trace_level, allow_exec=allow_exec, ) ban_targets = [] for m in autopatch_remove_modules: ban_targets += _targets_from_module( m, modules_seen, trace_level=TraceLevel.PRIMITIVES ) self.patch_function_ids = { # keys ensure we can have separate unique values for same func in different frames id(target.frame[target.name]): target for target in autowrap_targets + patch_functions } self.ban_function_ids = { id(target.frame[target.name]): target for target in ban_targets } self._autopatch_wrap_modules = autopatch_wrap_modules self._autopatch_remove_modules = autopatch_remove_modules for scope in search_scopes: logger.debug(f"Adding search scope {id(scope)}") self.search_autowrap_targets(scope) logger.debug( f"{Patcher.__name__} found {len(self.patch_function_ids)=}, {len(self.ban_function_ids)=}" )
[docs] def apply_preexecute_patches( self, func: Callable, trace_level: TraceLevel | None = None, ): for target in self.patch_function_ids.values(): orig_fn = target.frame.get(target.name) if orig_fn is None: raise NotImplementedError( f"{target.name} not found in frame (possibly a builtin)" ) if getattr(orig_fn, PATCHING_FLAG_ATTR, False): continue func_wrapper = self.create_wrapper(orig_fn) patch = PatchSetItem( frame=target.frame, fn_name=target.name, orig_fn=orig_fn, patched_fn=func_wrapper, ) self.patch(patch) for _id, target in self.ban_function_ids.items(): orig_fn = target.frame.get(target.name) if orig_fn is None: raise NotImplementedError( f"{target.name} not found in frame (possibly a builtin)" ) patch = PatchSetItem( frame=target.frame, fn_name=target.name, orig_fn=orig_fn, patched_fn=_create_banned_func_wrapper(orig_fn), ) self.patch(patch) for module, _allow_exec, _trace_level in self._autopatch_wrap_modules: self.search_autowrap_targets(module.__dict__) for module in self._autopatch_remove_modules: self.search_autowrap_targets(module.__dict__) self.search_autowrap_targets(func.__globals__) # func may be itself be a target we need to wrap func = self.create_wrapper(func) return func
# call_wrapper = _create_module_call_wrapper(self) # module_call_patch = PatchSetAttr( # frame=Module, # fn_name=_MODULE_CALLFUNC, # orig_fn=_ORIG_MODULE_CALL, # patched_fn=call_wrapper, # ) # self.patch(module_call_patch) # module_getattr_patch = PatchSetAttr( # frame=Module, # fn_name="__getattribute__", # orig_fn=_ORIG_MODULE_GETATTR, # patched_fn=_create_module_getattr_wrapper(), # ) # patcher.patch(module_getattr_patch)
[docs] def create_wrapper( self, func: Callable, ): if getattr(func, PATCHING_FLAG_ATTR, False): raise ValueError( f"Function {func.__name__} is already wrapped, should have already been skipped" ) wrap_target = self.patch_function_ids.get(id(func)) ban_target = self.ban_function_ids.get(id(func)) is_wrap = wrap_target is not None is_ban = ban_target is not None # Can't wrap functions without __globals__ unless they're explicit targets if not hasattr(func, "__globals__") and not is_wrap and not is_ban: return func match is_wrap, is_ban: case (False, False): wrapper = _create_nonleaf_wrap_discover_wrapper(func, self) case (True, False): if wrap_target.custom_trace_wrapper_create is not None: wrapper = wrap_target.custom_trace_wrapper_create(wrap_target, self) elif wrap_target.trace_level <= self.trace_level: wrapper = _create_leaf_func_proxy_wrapper(wrap_target, self) else: wrapper = _create_nonleaf_wrap_discover_wrapper(func, self) case False, True: wrapper = _create_banned_func_wrapper(func) case True, True: raise ValueError( f"Got {func.__name__} which was both wrapped and banned? {wrap_target=} {ban_target=}" ) # if logger.isEnabledFor(logging.DEBUG): # logger.debug( # f"Created wrapper for {func.__name__} {id(func)} -> {id(wrapper)} {is_wrap=} {is_ban=}" # ) return wrapper
[docs] def search_autowrap_targets(self, frame: dict): # for efficiency - dont bother re-searching frames, since all the functions will be tagged/wrapped already if id(frame) in self.visited_frames: logger.debug(f"Already-visited frame {id(frame)}") return self.visited_frames.add(id(frame)) for name, value in frame.items(): skip = ( getattr(value, PATCHING_FLAG_ATTR, False) or (name.startswith("__") and name.endswith("__")) or not callable(value) ) if skip: continue patch = PatchSetItem( frame=frame, fn_name=name, orig_fn=value, patched_fn=self.create_wrapper(value), ) self.patch(patch)
[docs] def patch( self, patch: Patch, ): if hasattr(patch.orig_fn, PATCHING_FLAG_ATTR): logger.debug(f"skipping already patched {patch.orig_fn} {id(patch)}") return try: setattr(patch.patched_fn, PATCHING_FLAG_ATTR, True) except (TypeError, AttributeError) as _e: # logger.debug( # f"Failed to setattr {PATCHING_FLAG_ATTR} on {patch.patched_fn} {type(patch.patched_fn)=} {e}" # ) pass # cant cache bools on some immutable types like Tuple patch.patch() self.patched.append(patch)
[docs] def unpatch_all(self): for patch in self.patched: patch.unpatch() self.patched.clear()
def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.unpatch_all() return False
def _callable_name(func: Callable) -> str:
    """Return ``func.__name__`` or, for callables without one (e.g. functools.partial), its repr."""
    return getattr(func, "__name__", None) or repr(func)


def _create_leaf_func_proxy_wrapper(
    target: PatchFunctionTarget,
    patcher: Patcher,
) -> Callable:
    """
    Build a wrapper that records a call to the target as a graph node instead of
    executing it.

    Args:
        target: describes the function (``target.frame[target.name]``) and how its
            calls are to be recorded.
        patcher: the active Patcher; its ``trace_level`` decides whether calls that
            only involve RngProxy arguments are executed for real.

    Returns:
        A wrapper flagged with PATCHING_FLAG_ATTR. Calling it returns either the
        real result (see ``allow_exec`` below) or a ``cg.Proxy`` around a new
        ``cg.FunctionCallNode``.
    """
    func = target.frame[target.name]
    func_name = _callable_name(func)

    def _unwrap_proxy(v):
        return v.node if isinstance(v, cg.Proxy) else v

    def _unbox_rng(v):
        return v.rng if isinstance(v, RngProxy) else v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.debug("Executing leaf_wrapper! func=%s", func_name)

        # zero proxy args means the user is trying to call this function with real
        # args during tracing, e.g they might do: bias = np.zeros(3)
        # some functions will allow this and just execute the function and return the result
        if target.allow_exec:
            all_leaves, _ = pytree.flatten((args, kwargs))
            proxy_leaves = [v for v in all_leaves if isinstance(v, cg.Proxy)]
            if not proxy_leaves:
                return func(*args, **kwargs)

            # When the only proxies are RNG proxies and the requested trace level is
            # finer than RANDOM_PARAMS, run the function on the underlying rngs.
            rng_only = all(isinstance(v, RngProxy) for v in proxy_leaves)
            if rng_only and patcher.trace_level < TraceLevel.RANDOM_PARAMS:
                concrete_args = tuple(
                    pytree.PyTree(a).map(_unbox_rng).obj() for a in args
                )
                concrete_kwargs = {
                    k: pytree.PyTree(v).map(_unbox_rng).obj()
                    for k, v in kwargs.items()
                }
                return func(*concrete_args, **concrete_kwargs)

        if target.normalize:
            args, kwargs = cg.normalize_args_to_kwargs(func, args, kwargs)

        # Convert any Proxy args to their underlying nodes, including those nested in containers
        node_args = tuple(pytree.PyTree(a).map(_unwrap_proxy).obj() for a in args)
        node_kwargs = {
            k: pytree.PyTree(v).map(_unwrap_proxy).obj() for k, v in kwargs.items()
        }

        node = cg.FunctionCallNode(func=func, args=node_args, kwargs=node_kwargs)

        # Handle mutations by updating proxies for mutated arguments.
        # NOTE(review): only arguments present in kwargs are handled, so a mutated
        # argument passed positionally with normalize=False is not tracked.
        for param_name in target.mutates or []:
            original_proxy = kwargs.get(param_name)
            if not isinstance(original_proxy, cg.Proxy):
                continue
            mutated_node = cg.MutatedArgumentNode(
                mutator_call_node=node, original_node=original_proxy.node
            )
            logger.debug(
                "MutatedArgumentNode created: %s(%s=...) -> %s",
                func_name,
                param_name,
                mutated_node,
            )
            original_proxy.node = mutated_node

        return cg.Proxy(node)

    if hasattr(func, "reduce"):
        # numpy.random.Generator breaks if these special reduce functions-inside-functions are not copied over?
        # TODO: handle the general case of metadata attrs on wrapped functions, I believe torch.fx does this
        wrapper.reduce = func.reduce

    setattr(wrapper, PATCHING_FLAG_ATTR, True)

    return wrapper


def _create_nonleaf_wrap_discover_wrapper(
    func: Callable,
    patcher: Patcher,
):
    """
    Build a wrapper that executes ``func`` for real, after making sure the
    callables it can reach are themselves wrapped.

    On each call the wrapper asks the patcher to search ``func.__globals__`` (when
    present) and wraps callable arguments that are registered wrap targets.

    Raises:
        ValueError: if ``func`` is already a wrapper.
    """
    # TODO: should record non-leaf functions in the graph while also executing their internals?

    func_name = _callable_name(func)

    if getattr(func, PATCHING_FLAG_ATTR, False):
        raise ValueError(f"Function {func_name} is already wrapped")

    def maybe_wrap(v):
        # Wrap any callable args that are registered targets but weren't reachable
        # via module-dict patching (e.g. passed as function-valued arguments)
        if not callable(v):
            return v
        if getattr(v, PATCHING_FLAG_ATTR, False):
            return v
        if id(v) in patcher.patch_function_ids:
            logger.debug(
                "maybe_wrap: wrapping callable arg %r id=%s found in patch_function_ids",
                getattr(v, "__name__", v),
                id(v),
            )
            return patcher.create_wrapper(v)
        logger.debug(
            "maybe_wrap: callable arg %r id=%s NOT in patch_function_ids (size=%s)",
            getattr(v, "__name__", v),
            id(v),
            len(patcher.patch_function_ids),
        )
        return v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Builtins and callable objects have no __globals__. Do not hand the patcher
        # a throwaway dict in that case: its id() would be remembered as visited and
        # could later be reused by a real namespace, which would then be skipped.
        frame = getattr(func, "__globals__", None)
        if frame is not None:
            logger.debug("Searching autowrap targets for %s %s", func_name, id(frame))
            patcher.search_autowrap_targets(frame)
            logger.debug("Finished autowrap targets for %s %s", func_name, id(frame))

        args = tuple(maybe_wrap(a) for a in args)
        kwargs = {k: maybe_wrap(v) for k, v in kwargs.items()}

        return func(*args, **kwargs)

    setattr(wrapper, PATCHING_FLAG_ATTR, True)

    return wrapper


def _create_banned_func_wrapper(func: Callable) -> Callable:
    """
    Build a wrapper that refuses to run ``func``.

    Returns:
        A wrapper flagged with PATCHING_FLAG_ATTR which raises ValueError whenever
        it is called, whatever the arguments.
    """
    func_name = _callable_name(func)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        raise ValueError(
            f"Tracing failed - {func_name} is banned from use in traceable functions"
        )

    setattr(wrapped_func, PATCHING_FLAG_ATTR, True)

    return wrapped_func