Source code for procfunc.tracer.proxy

from dataclasses import dataclass

import numpy as np

from procfunc import compute_graph as cg


@dataclass
class RngProxy(cg.Proxy):
    """
    Proxy that pairs a graph node with a real ``np.random.Generator``.

    Rng nodes get special-case handling so they can be traced through the
    graph. This lets the tracer know what the results of choice() and other
    random control flow will be, BUT only if the user has kept the rng
    non-dirty, i.e. the rng used for choice() descends from the root via
    spawn() calls only.

    This is essential so that the result of choice() during tracing is the
    same as the result of choice() during generation.
    """

    rng: np.random.Generator
    # True if any randomness has been taken from this rng besides splitting it
    # via spawn.
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        rng: np.random.Generator,
        dirty: bool = False,
    ):
        """
        Wrap ``node`` and remember the concrete generator it stands for.

        Args:
            node: graph node this proxy represents; its ``metadata`` dict is
                tagged with ``known_type`` and ``varname`` entries.
            rng: the concrete generator mirrored by this proxy.
            dirty: initial value of the dirty flag.
        """
        super().__init__(node)
        node.metadata["known_type"] = np.random.Generator
        node.metadata["varname"] = "rng"
        self.rng = rng
        self.dirty = dirty

    def spawn(self, n_children: int) -> "RngSpawnResultProxy":
        """
        Returns a mock node which can only be unpacked into its constituent
        mock rngs.

        Args:
            n_children: number of child generators to spawn; must be an int.

        Returns:
            A ``RngSpawnResultProxy`` holding the spawned child generators and
            inheriting this proxy's dirty flag.
        """
        assert isinstance(n_children, int)

        spawn_node = cg.MethodCallNode(
            callee=self.node,
            method_name="spawn",
            args=(n_children,),
            kwargs={},
        )

        # Spawn on the real generator as well, so each child proxy carries a
        # concrete child rng.
        child_rngs = list(self.rng.spawn(n_children))

        return RngSpawnResultProxy(
            node=spawn_node,
            from_rng_proxy=self,
            child_rngs=child_rngs,
            dirty=self.dirty,
        )

    def __getattr__(self, name: str):
        """
        Fallback attribute access for names not found by normal lookup.

        Accessing any public attribute of the wrapped generator (other than
        ``spawn``) marks this proxy as dirty. The lookup itself is then
        delegated to ``cg.Proxy.__getattr__``.

        Raises:
            AttributeError: if the wrapped generator has no attribute ``name``,
                or if ``name`` is ``"rng"`` (the instance has no rng yet).
        """
        # __getattr__ only runs after normal lookup failed, so being asked for
        # "rng" means it has not been set on this instance (e.g. a blank
        # instance created by copy/pickle). Reading self.rng below would then
        # re-enter this method forever; fail with AttributeError instead.
        if name == "rng":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        rng_has_attr = hasattr(self.rng, name)
        if not rng_has_attr:
            raise AttributeError(
                f"__getattr__ {name} is invalid because {type(self.rng)} has no attribute {name}"
            )

        # Private names and spawn do not count as taking randomness.
        if not name.startswith("_") and name != "spawn":
            self.dirty = True

        return super().__getattr__(name)
@dataclass
class RngSpawnResultProxy(cg.Proxy):
    """
    Proxy for the result of ``RngProxy.spawn``.

    Holds the concrete child generators so that indexing or unpacking yields
    one ``RngProxy`` per child.
    """

    from_rng_proxy: "RngProxy"
    child_rngs: list[np.random.Generator]
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        from_rng_proxy: "RngProxy",
        child_rngs: list[np.random.Generator],
        dirty: bool = False,
    ):
        """
        Wrap ``node`` and store the spawned children.

        Args:
            node: graph node for the spawn call; its ``metadata`` dict gets a
                ``varname`` entry.
            from_rng_proxy: the proxy whose ``spawn`` produced this result.
            child_rngs: the concrete child generators.
            dirty: dirty flag handed down to every child proxy; a warning is
                emitted when it is true.
        """
        super().__init__(node)
        node.metadata["varname"] = "rngs"

        self.from_rng_proxy = from_rng_proxy
        self.child_rngs = child_rngs
        self.dirty = dirty

        if not self.dirty:
            return

        import warnings

        warnings.warn(
            "RngSpawnResultProxy has dirty=True, tracing results may be incomplete"
        )

    def __getitem__(self, idx: int) -> "RngProxy":
        """
        Return a proxy for the child generator at position ``idx``.

        Raises:
            IndexError: if ``idx`` is negative or not below the child count.
        """
        n_children = len(self.child_rngs)
        if not (idx >= 0 and idx < n_children):
            raise IndexError(
                f"Index {idx} out of range for {len(self.child_rngs)} children"
            )

        item_node = cg.MethodCallNode(
            callee=self.node,
            method_name="__getitem__",
            args=(idx,),
            kwargs={},
            metadata={"known_value_type": np.random.Generator},
        )
        child_rng = self.child_rngs[idx]
        return RngProxy(node=item_node, rng=child_rng, dirty=self.dirty)

    def __iter__(self):
        """
        specialcase to allow __iter__() since SpawnResult has known size

        allows x,y,z = rng.spawn(3) syntax to work in functions that need to
        be traceable
        """
        # Lazily yield one child proxy per position, in order.
        yield from map(self.__getitem__, range(len(self.child_rngs)))