Source code for procfunc.compute_graph.proxy

"""General-purpose Proxy wrapper for Node with all dunders."""

import operator
from dataclasses import dataclass
from typing import Generic, TypeVar

from .node import FunctionCallNode, GetAttributeNode, MethodCallNode, Node

# Type parameter of ``Proxy(Generic[T])``. Presumably it names the type of the
# value a proxy stands in for during tracing; it is not checked at runtime.
T = TypeVar("T")


@dataclass
class Proxy(Generic[T]):
    """General-purpose wrapper for a ``Node`` that records operations as graph nodes.

    Nothing done to a ``Proxy`` is evaluated. Instead, each operation builds a
    new node and returns it wrapped in another proxy:

    * attribute access -> ``GetAttributeNode`` (wrapped in ``AttributeProxy``)
    * indexing -> ``FunctionCallNode(operator.getitem, ...)``
    * arithmetic / comparison operators -> ``FunctionCallNode``; these are
      installed on this class by the module after the class definition.

    Operations that need a concrete value (``len()``, iteration, truthiness)
    raise ``ValueError``. Calling a plain ``Proxy`` raises
    ``NotImplementedError``.
    """

    node: Node

    def __repr__(self):
        return f"Proxy({self.node!r})"

    def __getattr__(self, attr: str) -> "AttributeProxy":
        """Record ``self.<attr>`` as a ``GetAttributeNode`` and wrap it.

        ``__getattr__`` only runs for names that normal lookup did not find,
        so some names must be reported as missing rather than traced:

        * ``node``: reaching here means the instance was created without
          ``__init__`` (e.g. by ``copy``/``pickle`` via ``cls.__new__``).
          Reading ``self.node`` below would re-enter this method until
          ``RecursionError``.
        * ``__dunder__`` names: these are protocol probes made by Python and
          stdlib machinery (e.g. ``copy.deepcopy`` asks for ``__deepcopy__``).
          Tracing them would hand those tools a bogus callable.

        Raises:
            AttributeError: for ``node`` or any ``__dunder__`` name.
        """
        if attr == "node" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        node = GetAttributeNode(source=self.node, attribute_name=attr)
        return AttributeProxy(node)

    def __call__(self, *args, **kwargs) -> "Proxy":
        raise NotImplementedError("Proxy.__call__ is not implemented")

    def __len__(self) -> int:
        raise ValueError(
            "Tracing does not allow __len__ since real values are not evaluated"
        )

    def __iter__(self):
        # Defined explicitly: with only __getitem__, Python would fall back
        # to the legacy sequence-iteration protocol.
        raise ValueError(
            "Proxy does not support __iter__. Use explicit indexing instead."
        )

    def __getitem__(self, idx) -> "Proxy":
        """Record ``self[idx]`` as ``operator.getitem(self.node, idx)``."""
        # NOTE(review): only a top-level Proxy index is unwrapped. Proxies
        # nested inside a tuple or slice (``x[p, 0]``, ``x[p:]``) are passed
        # through as-is; confirm FunctionCallNode handles that.
        idx_node = idx.node if isinstance(idx, Proxy) else idx
        getitem_node = FunctionCallNode(
            func=operator.getitem,
            args=(self.node, idx_node),
            kwargs={},
        )
        return Proxy(getitem_node)

    def __bool__(self):
        raise ValueError(
            "Base Proxy does not allow __bool__ during tracing since real values are unknown"
        )
# Forward operators traced by ``Proxy``: dunder name -> function recorded in
# the graph. Unary entries (``__neg__`` ... ``__invert__``) get no extra
# operands, so the installed method records ``func(self.node)``.
NODE_DUNDER_METHODS = {
    "__add__": operator.add,
    "__sub__": operator.sub,
    "__mul__": operator.mul,
    "__matmul__": operator.matmul,
    "__truediv__": operator.truediv,
    "__floordiv__": operator.floordiv,
    "__mod__": operator.mod,
    "__pow__": operator.pow,
    "__lshift__": operator.lshift,
    "__rshift__": operator.rshift,
    "__and__": operator.and_,
    "__xor__": operator.xor,
    "__or__": operator.or_,
    "__neg__": operator.neg,
    "__pos__": operator.pos,
    "__abs__": operator.abs,
    "__invert__": operator.invert,
    "__eq__": operator.eq,
    "__ne__": operator.ne,
    "__lt__": operator.lt,
    "__le__": operator.le,
    "__gt__": operator.gt,
    "__ge__": operator.ge,
}

# Operators that get a reflected ``__r<op>__`` method on Proxy. Names follow
# the ``operator`` module, hence the trailing underscore on ``and_`` / ``or_``.
# Entries with no forward dunder in NODE_DUNDER_METHODS are skipped by
# ``_add_proxy_reflection``: "div" is Python 2 only, and there is no
# ``__rgetitem__`` protocol.
NODE_REFLECTABLE_METHODS = [
    "add",
    "sub",
    "mul",
    "floordiv",
    "truediv",
    "div",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and_",
    "or_",
    "xor",
    "getitem",
    "matmul",
]


def _add_proxy_operator(cls, name, operator_func):
    """Install ``name`` on ``cls`` as a method that traces ``operator_func``.

    The installed method records
    ``FunctionCallNode(func=operator_func, args=(self.node, *args), kwargs=kwargs)``,
    with any Proxy arguments replaced by their nodes. It returns the new node
    wrapped in a ``Proxy``.
    """

    def proxy_method(self, *args, **kwargs):
        # Graph nodes reference other nodes, not their Proxy wrappers.
        node_args = tuple(arg.node if isinstance(arg, Proxy) else arg for arg in args)
        node_kwargs = {
            k: v.node if isinstance(v, Proxy) else v for k, v in kwargs.items()
        }
        node = FunctionCallNode(
            func=operator_func,
            args=(self.node, *node_args),
            kwargs=node_kwargs,
        )
        return Proxy(node)

    # Give the generated method a meaningful name for tracebacks/introspection.
    proxy_method.__name__ = name
    proxy_method.__qualname__ = f"{cls.__qualname__}.{name}"
    setattr(cls, name, proxy_method)


def _add_proxy_reflection(cls, name: str):
    """Install the reflected operator ``__r<op>__`` for ``name`` on ``cls``.

    ``name`` is an ``operator``-module-style name such as ``"mul"`` or
    ``"and_"``. The reflected method reuses the forward operator from
    NODE_DUNDER_METHODS with the operands swapped. Names without a forward
    entry are silently skipped.
    """
    # "and_" -> "and": operator-module names carry a trailing underscore, but
    # dunder names do not. The stripped form is used for BOTH the lookup and
    # the installed name; the original used the unstripped name when
    # installing, which defined "__rand___" instead of "__rand__".
    base = name.rstrip("_")
    operator_func = NODE_DUNDER_METHODS.get(f"__{base}__")
    if operator_func is None:
        return  # no matching forward op, skip

    def proxy_method(self, rhs):
        # __r<op>__(self, rhs) implements ``rhs <op> self``: keep the forward
        # operator but swap the operand order.
        rhs_node = rhs.node if isinstance(rhs, Proxy) else rhs
        node = FunctionCallNode(
            func=operator_func,
            args=(rhs_node, self.node),
            kwargs={},
        )
        return Proxy(node)

    reflected_name = f"__r{base}__"
    proxy_method.__name__ = reflected_name
    proxy_method.__qualname__ = f"{cls.__qualname__}.{reflected_name}"
    setattr(cls, reflected_name, proxy_method)


# Add all dunder methods to Proxy
for name, operator_func in NODE_DUNDER_METHODS.items():
    _add_proxy_operator(Proxy, name, operator_func)

# Add reflected methods
for name in NODE_REFLECTABLE_METHODS:
    _add_proxy_reflection(Proxy, name)
@dataclass(eq=False)
class AttributeProxy(Proxy):
    """Special proxy for attribute access that supports the peekthrough optimization.

    Wraps a ``GetAttributeNode``. Calling it (``proxy.xyz(...)``) records a
    single ``MethodCallNode`` on the source object, instead of an attribute
    fetch followed by a separate call.

    ``eq=False`` matters here. A plain ``@dataclass`` would generate a
    field-comparing ``__eq__`` on this class, shadowing the traced ``__eq__``
    the module installs on ``Proxy``. Then ``proxy.attr == x`` would return a
    bool (or fall back to identity) instead of recording an ``operator.eq``
    node, while ``!=`` would still be traced.
    """

    def __init__(self, node: Node):
        super().__init__(node)
        assert isinstance(node, GetAttributeNode), node

    def __call__(self, *args, **kwargs) -> Proxy:
        """Record ``source.<attribute_name>(*args, **kwargs)`` as one method-call node.

        Someone did ``func = proxy.xyz`` and then ``func()``, or equivalently
        just ``proxy.xyz()``. We can convert that into a single node: a method
        call on the object node. ``torch.fx.symbolic_trace`` calls this a
        *peekthrough optimization*.

        TODO: this means that ``self`` (the GetAttributeNode) is very often
        dropped from the graph. We need to account for this if we check for
        drops.
        """
        assert isinstance(self.node, GetAttributeNode), self.node
        # Graph nodes reference other nodes, not their Proxy wrappers.
        node_args = tuple(arg.node if isinstance(arg, Proxy) else arg for arg in args)
        node_kwargs = {
            k: v.node if isinstance(v, Proxy) else v for k, v in kwargs.items()
        }
        # self.node.args[0] is presumably the source node passed to
        # GetAttributeNode(source=...); confirm against GetAttributeNode.
        call_node = MethodCallNode(
            callee=self.node.args[0],
            method_name=self.node.attribute_name,
            args=node_args,
            kwargs=node_kwargs,
        )
        return Proxy(call_node)