# Source code for procfunc.transforms.infer_distribution

import logging
from collections import defaultdict
from typing import Any

import numpy as np

from procfunc import compute_graph as cg
from procfunc import types as t
from procfunc.color import hsv_color
from procfunc.random import randint, uniform
from procfunc.util import pytree

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class TODO:
    """
    Marker for a value that could not be inferred automatically.

    Instances are left in the generated graph (wrapped in a constant node by
    ``todo()``) so that code generated from the graph visibly contains
    ``TODO()`` wherever the user still has to supply something.
    """

    def __repr__(self) -> str:
        return "TODO()"


def todo():
    """Return a constant graph node holding a fresh ``TODO`` marker instance."""
    placeholder_value = TODO()
    return cg.ConstantNode(value=placeholder_value)


# Value types that _infer_hypercube_differing may collapse into a uniform
# [elementwise min, elementwise max] range when they differ between graphs.
_REDUCE_TYPES = (float, int, tuple, t.Color, np.ndarray)


def _minmax_arraylike(
    values: list[float | int | tuple | np.ndarray | t.Color | t.Vector],
) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the elementwise minimum and maximum over ``values``.

    Each value is converted with ``np.array`` and folded with ``np.minimum`` /
    ``np.maximum``, so array-like values broadcast against one another.
    ``values`` must be non-empty; an empty list raises ``IndexError``.
    """
    arrays = [np.array(v) for v in values]
    # Two distinct starting arrays so the returned low/high never alias.
    lowest, highest = arrays[0], arrays[0].copy()
    for arr in arrays[1:]:
        lowest = np.minimum(lowest, arr)
        highest = np.maximum(highest, arr)
    return lowest, highest


def _infer_hypercube_differing(
    rng_node: cg.Node, key: str, all_kwargs_k: list[Any]
) -> Any | TODO:
    """
    Reduce one argument's values (one value per source graph) to a single value.

    Args:
        rng_node: node passed as the RNG argument of a generated ``uniform`` call.
        key: name (or stringified position) of the argument; informational only.
        all_kwargs_k: the argument's value in each source graph.

    Returns:
        - the shared value itself, when every value compares equal;
        - a ``uniform(rng_node, low, high)`` call node spanning the elementwise
          min/max, when the values differ and are of a type in ``_REDUCE_TYPES``;
        - a ``todo()`` node otherwise: values of mixed types, unsupported types,
          or differing booleans.
    """
    if len({type(v) for v in all_kwargs_k}) != 1:
        return todo()

    first = all_kwargs_k[0]
    if all(np.array(v == first).all() for v in all_kwargs_k):
        return first

    # bool subclasses int, so it would otherwise pass the _REDUCE_TYPES check and
    # yield uniform(rng, False, True); a continuous range is meaningless for flags.
    # (np.bool_ scalars are already rejected by the _REDUCE_TYPES check below.)
    if isinstance(first, bool) or any(
        isinstance(v, np.ndarray) and v.dtype == np.bool_ for v in all_kwargs_k
    ):
        return todo()

    if not all(isinstance(v, _REDUCE_TYPES) for v in all_kwargs_k):
        return todo()

    low, high = _minmax_arraylike(all_kwargs_k)
    # TODO: consider metadata["prefer_inline"] for scalar ranges (currently ignored).
    return cg.FunctionCallNode(uniform, args=(rng_node, low, high), kwargs={})


def infer_hypercube_differing_node(
    rng_node: cg.Node,
    nodes: list[cg.Node],
    memo: dict[int, cg.Node] | None = None,
) -> cg.Node | TODO | None:
    """
    Merge corresponding nodes from several graphs into a single node.

    The nodes must share the same node type, ``target``, number of positional
    args, and set of keyword names. Otherwise a ``todo()`` node is returned.
    Node-valued arguments are merged recursively. Plain-valued arguments are
    reduced with ``_infer_hypercube_differing``, which turns differing numeric
    constants into ``uniform`` samples driven by ``rng_node``.

    Args:
        rng_node: node used as the RNG argument of generated ``uniform`` calls.
        nodes: one node per source graph, at corresponding positions.
        memo: cache from ``id(nodes[0])`` to the merged node, filled in place so
            nodes shared within the first graph are merged only once.

    Returns:
        ``None`` if every entry of ``nodes`` is ``None``. A ``todo()`` node if
        the nodes cannot be aligned. Otherwise a ``cg.Node`` built from the
        first node's ``target`` and ``kind`` with the merged args and kwargs.
    """
    if memo is None:
        memo = {}

    if all(node is None for node in nodes):
        return None
    if any(node is None for node in nodes):
        logger.debug(f"{infer_hypercube_differing_node=} exiting due to None")
        return todo()

    key = id(nodes[0])
    if key in memo:
        return memo[key]

    kinds = {type(node) for node in nodes}
    if len(kinds) > 1:
        logger.debug(
            f"{infer_hypercube_differing_node=} exiting for mismatching {kinds=}"
        )
        return todo()

    targets = {node.target for node in nodes}
    if len(targets) > 1:
        logger.debug(
            f"{infer_hypercube_differing_node=} exiting for mismatching {targets=}"
        )
        return todo()

    # Compare keyword *names*, not just their count. Equal counts with different
    # names would make the per-key lookup below raise KeyError.
    nargs = {len(node.args) for node in nodes}
    kwarg_keys = {frozenset(node.kwargs) for node in nodes}
    if len(nargs) > 1 or len(kwarg_keys) > 1:
        logger.debug(
            f"{infer_hypercube_differing_node=} exiting for differing {nargs=} {kwarg_keys=}"
        )
        return todo()

    def _infer_differing(argkey, argvals):
        # Node arguments recurse (sharing memo). An all-None argument stays None.
        # Anything else is reduced as a plain value.
        if all(isinstance(v, cg.Node) for v in argvals):
            return infer_hypercube_differing_node(rng_node, argvals, memo)
        if all(v is None for v in argvals):
            return None
        return _infer_hypercube_differing(rng_node, argkey, argvals)

    first = nodes[0]
    args = [
        _infer_differing(str(i), [node.args[i] for node in nodes])
        for i in range(len(first.args))
    ]
    kwargs = {
        k: _infer_differing(k, [node.kwargs[k] for node in nodes])
        for k in first.kwargs
    }

    res = cg.Node(first.target, first.kind, tuple(args), kwargs)
    memo[key] = res
    return res


def infer_distribution_hypercube(
    graphs: list[cg.ComputeGraph],
    memo: dict[int, cg.Node] | None = None,
) -> cg.ComputeGraph:
    """
    Merge several aligned graphs into one graph that samples their differing constants.

    The outputs of all graphs are matched by key and merged with
    ``infer_hypercube_differing_node``. Differing numeric constants become
    ``uniform(rng, low, high)`` calls driven by a new ``rng`` input placeholder.

    Args:
        graphs: at least two graphs. ``graphs[0]`` supplies the output keys and
            the inputs considered for the merged graph.
        memo: optional cache mapping ``id`` of nodes in ``graphs[0]`` to merged
            nodes. It is filled in place; a fresh dict is used when omitted.

    Returns:
        A single ``cg.ComputeGraph`` named ``"<common name prefix>_distribution"``
        whose outputs are wrapped in ``t.Material``. NOTE(review): this assumes
        every output key is a ``t.Material`` field.

    Raises:
        ValueError: if fewer than two graphs are given.
    """
    if len(graphs) <= 1:
        raise ValueError(
            f"{infer_distribution_hypercube.__name__} expected at least 2 graphs, got {len(graphs)}"
        )
    if memo is None:
        memo = {}

    rng_node = cg.InputPlaceholderNode(
        default_value=None,
        metadata={
            "varname": "rng",
            "known_value_type": "pf.RNG",  # TODO use actual type and resolve to string
        },
    )

    all_outputs = [g.outputs.dict() for g in graphs]
    outputs_mapped = {
        k: infer_hypercube_differing_node(rng_node, [g[k] for g in all_outputs], memo)
        for k in all_outputs[0].keys()
    }

    # Inputs are not merged directly. Those reached while merging the outputs
    # were recorded in memo (keyed by the first graph's node ids); the rest map
    # to None and are dropped below.
    first_inputs = graphs[0].inputs.obj()
    inputs_mapped = {k: memo.get(id(first_inputs[k])) for k in first_inputs.keys()}

    n_inputs_missing = sum(1 for v in inputs_mapped.values() if v is None)
    if n_inputs_missing > 0:
        logger.warning(
            f"{infer_distribution_hypercube=} {n_inputs_missing=} {inputs_mapped} {first_inputs}"
        )
        inputs_mapped = {k: v for k, v in inputs_mapped.items() if v is not None}

    # Longest common prefix of the graph names (e.g. "wood_a", "wood_b" -> "wood").
    # Empty names or names sharing no prefix yield an empty prefix.
    names = [graph.name for graph in graphs]
    prefix_len = 0
    for chars in zip(*names):
        if len(set(chars)) != 1:
            break
        prefix_len += 1
    prefix = names[0][:prefix_len].strip("_")

    res = cg.ComputeGraph(
        inputs=pytree.PyTree({**inputs_mapped, "rng": rng_node}),
        outputs=pytree.PyTree(t.Material(**outputs_mapped)),
        name=f"{prefix}_distribution",
        metadata={"func": infer_distribution_hypercube},
    )

    if logger.isEnabledFor(logging.DEBUG):
        n_inp_graph_nodes = len(list(cg.traverse_depth_first(graphs[0])))
        n_out_graph_nodes = len(list(cg.traverse_depth_first(res)))
        logger.debug(
            f"{infer_distribution_hypercube=} transformed {n_inp_graph_nodes=} to {n_out_graph_nodes=}"
        )

    return res
def _reduce_const(value: Any) -> cg.Node:
    """
    Wrap ``value`` in a ``cg.ConstantNode``.

    A 0-d ndarray is unwrapped to its scalar. An ndarray whose first axis has
    length 1 is replaced by ``value[0]``. Any other value is wrapped unchanged.
    """
    if isinstance(value, np.ndarray):
        if value.ndim == 0:
            # len() raises TypeError on 0-d arrays; index with () to get the scalar.
            value = value[()]
        elif len(value) == 1:
            value = value[0]
    return cg.ConstantNode(value=value)


def _infer_argument_distribution(
    values: list[Any],
    rng_node: cg.Node,
    colors_to_hsv: bool = True,
    use_randint: bool = False,
) -> cg.Node:
    """
    Build a node producing a value for one argument, given every value it was called with.

    Args:
        values: non-empty list of the concrete values seen at each call site.
        rng_node: node used as the RNG argument of generated sampling calls.
        colors_to_hsv: sample ``t.Color`` values as a uniform range in HSV space.
        use_randint: sample integral-valued numbers with ``randint``.

    Returns:
        One of the following, checked in this order:
        - a constant node, if all values are ``np.allclose`` to the first;
        - ``randint(rng, low, high)``, if ``use_randint`` and every value is an
          integral-valued int/float;
        - ``hsv_color(hsv=uniform(rng, low, high))``, if ``colors_to_hsv`` and
          the values are colors;
        - ``uniform(rng, low, high)`` over the elementwise min/max otherwise.
    """
    if all(np.allclose(values[0], x) for x in values[1:]):
        # argument had only one value (usually the default_value)
        return _reduce_const(values[0])
    elif use_randint and all(
        isinstance(x, (int, float)) and np.isclose(x, int(x)) for x in values
    ):
        # argument is integer range
        low, high = _minmax_arraylike(values)
        return cg.FunctionCallNode(
            randint, args=(rng_node, int(low), int(high)), kwargs={}
        )
    elif colors_to_hsv and any(isinstance(kv, t.Color) for kv in values):
        # treat color ranges as hsv ranges
        assert all(isinstance(kv, t.Color) for kv in values), values
        low, high = _minmax_arraylike([c.hsv for c in values])
        hsv = cg.FunctionCallNode(uniform, args=(rng_node, low, high), kwargs={})
        return cg.FunctionCallNode(hsv_color, args=(), kwargs={"hsv": hsv})
    else:
        # regular argument - usually float and numpy arrays
        low, high = _minmax_arraylike(values)
        return cg.FunctionCallNode(uniform, args=(rng_node, low, high), kwargs={})


def _infer_distribution_from_callnodes(
    callnodes: list[cg.Node],
    subgraph: cg.ComputeGraph,
    supported_types: tuple[type],
    colors_to_hsv: bool,
    use_randint: bool,
) -> cg.ComputeGraph | None:
    """
    Build a graph that calls ``subgraph`` with arguments sampled from its call sites.

    For each input of ``subgraph``, the values from ``callnodes`` are handled as
    follows (a call node that omits the argument contributes the input's
    ``default_value``; None values are ignored):
    - no non-None values: the argument is forwarded as a constant ``None``;
    - any value is a graph node: the argument becomes an input placeholder of
      the new graph;
    - all values are of ``supported_types``: a value or distribution is inferred
      with ``_infer_argument_distribution``;
    - otherwise: a ``todo()`` node is used and a warning is logged.

    Returns:
        The new graph, whose inputs are ``rng`` plus any forwarded placeholders.
        Returns ``None`` when every argument was forwarded as a placeholder
        (nothing was inferred).
    """
    rng_node = cg.InputPlaceholderNode(
        name="rng",
        default_value=None,
        metadata={
            "known_value_type": "pf.RNG",  # TODO use actual type and resolve to string
        },
    )

    new_inputs = {}
    base_inputs = subgraph.inputs.dict()
    for k in base_inputs.keys():
        default_value = base_inputs[k].kwargs.get("default_value")
        kwarg_values = [cn.kwargs.get(k, default_value) for cn in callnodes]
        kwarg_values = [x for x in kwarg_values if x is not None]

        if not kwarg_values:
            # Every usage left this argument as None (explicitly or via a None
            # default). There is nothing to build a distribution from.
            new_inputs[k] = _reduce_const(None)
        elif any(isinstance(v, cg.Node) for v in kwarg_values):
            # argument had dynamic values connected up, make it a function argument
            orig_input_type = base_inputs[k].metadata.get("known_value_type", None)
            new_inputs[k] = cg.InputPlaceholderNode(
                name=k,
                default_value=None,
                metadata={
                    "known_value_type": orig_input_type,
                },
            )
        elif all(isinstance(v, supported_types) for v in kwarg_values):
            new_inputs[k] = _infer_argument_distribution(
                kwarg_values, rng_node, colors_to_hsv, use_randint=use_randint
            )
        else:
            uniq_types = set(type(v) for v in kwarg_values)
            logger.warning(
                f"Could not infer distribution for {subgraph.name=} {k=} {uniq_types=}"
            )
            new_inputs[k] = todo()

    # If every argument is merely forwarded as a placeholder, nothing was inferred
    # and a wrapper graph would add nothing. The "rng" placeholder is created here
    # rather than taken from new_inputs, so it must not count toward this decision.
    if all(isinstance(v, cg.InputPlaceholderNode) for v in new_inputs.values()):
        return None

    result_call = cg.SubgraphCallNode(subgraph=subgraph, args=(), kwargs=new_inputs)

    placeholders = {"rng": rng_node}
    placeholders.update(
        {k: v for k, v in new_inputs.items() if isinstance(v, cg.InputPlaceholderNode)}
    )

    graph = cg.ComputeGraph(
        inputs=pytree.PyTree(placeholders),
        outputs=pytree.PyTree(result_call),
        name=f"{subgraph.name}_distribution",
        metadata={"func": infer_nodegroup_distributions},
    )
    return graph
def infer_nodegroup_distributions(
    graphs: list[cg.ComputeGraph],
    supported_types: tuple[type] = (float, int, tuple, t.Vector, t.Color, np.ndarray),
    colors_to_hsv: bool = True,
    use_randint: bool = False,
) -> list[cg.ComputeGraph]:
    """
    Build a parameter distribution for every subgraph invoked more than once.

    All nested subgraph call sites across ``graphs`` are grouped by subgraph
    object identity. For each subgraph with at least two call sites,
    ``_infer_distribution_from_callnodes`` derives a graph that samples the
    subgraph's arguments from the range of values seen at those call sites.

    Args:
        graphs: graphs to scan for nested subgraph calls.
        supported_types: argument value types eligible for range inference.
        colors_to_hsv: infer color ranges in HSV space.
        use_randint: infer integral-valued ranges with ``randint``.

    Returns:
        The derived distribution graphs. Subgraphs for which no graph could be
        built are skipped.
    """
    # id(subgraph) -> list of (call node, subgraph) pairs invoking that subgraph
    usages_by_subgraph = defaultdict(list)
    for graph in graphs:
        for call_node, subgraph in cg.traverse_nested_graphs(
            graph, yield_call_nodes=True
        ):
            usages_by_subgraph[id(subgraph)].append((call_node, subgraph))

    distributions = []
    for usages in usages_by_subgraph.values():
        # A single call site offers no spread of values to learn a range from.
        if len(usages) < 2:
            continue
        call_nodes = [call_node for call_node, _ in usages]
        _, shared_subgraph = usages[0]
        candidate = _infer_distribution_from_callnodes(
            callnodes=call_nodes,
            subgraph=shared_subgraph,
            supported_types=supported_types,
            colors_to_hsv=colors_to_hsv,
            use_randint=use_randint,
        )
        if candidate is not None:
            distributions.append(candidate)
    return distributions