Source code for procfunc.codegen.identifiers

import keyword
import logging
import re
from collections import Counter, defaultdict
from typing import Iterable, Literal, TypeVar

import bpy

from procfunc import compute_graph as cg

# module-level logger named after this module (`__name__`)
logger = logging.getLogger(__name__)


def pascal_to_snake(name: str) -> str:
    """
    Convert a PascalCase / camelCase / dotted / spaced name to snake_case.

    Runs of `.`, `/`, `-` and whitespace become word breaks, and so do the
    boundaries of capitalised words (`"HTTPServer"` -> `"http_server"`,
    `"MixShader"` -> `"mix_shader"`). A falsy `name` yields `""`.
    """
    text = name or ""
    text = re.sub(r"[./\-\s]+", "_", text)
    # split an acronym from a following capitalised word: "HTTPServer" -> "HTTP_Server"
    text = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", text)
    # split a lowercase letter / digit from a following capital: "mixShader" -> "mix_Shader"
    text = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", text)
    words = [word.lower() for word in text.split("_") if word]
    return "_".join(words)


def snake_to_pascal(name: str) -> str:
    """
    Convert snake_case to PascalCase.

    Each `_`-separated word is passed through `str.capitalize`, so the rest of
    every word is lowercased (`"hTTP_x"` -> `"HttpX"`); empty words vanish.
    """
    words = name.split("_")
    return "".join(map(str.capitalize, words))


def bpy_name_to_pythonid(name: str) -> str:
    """
    Turn a Blender datablock / socket name into a snake_case Python identifier.

    Steps: PascalCase -> snake_case (via `pascal_to_snake`); `.`, space and `,`
    become `_` and `=` is dropped; any other character outside `[0-9a-z_]`
    becomes `_`; repeated underscores collapse and edge underscores are
    stripped. Leading number-only fields are moved to the end
    (`"2D Size"` -> `"d_size_2"`), and Python keywords get a trailing `_`
    (`"Class"` -> `"class_"`). Numeric duplicate suffixes are kept as a
    trailing field (`"Noise Texture.001"` -> `"noise_texture_001"`).

    Names with no usable characters (e.g. `""` or `"()"`) yield `""`. A name
    made only of digit fields (e.g. `"123"`) is returned as-is and is therefore
    not a valid identifier; check the result with `is_valid_snake_identifier`.
    """
    # Blender names are often PascalCase
    name = pascal_to_snake(name)

    # must come after pascal_to_snake, otherwise words could be jammed together
    name = name.replace(".", "_").replace(" ", "_").replace(",", "_")
    name = name.replace("=", "")
    name = name.lower()

    # drop anything that isn't valid in a python identifier (e.g. parentheses
    # in socket names like "Joint ID (do not set)")
    name = re.sub(r"[^0-9a-z_]", "_", name)
    name = re.sub(r"_+", "_", name).strip("_")

    # Move leading number-only fields to the end so the identifier starts with a
    # letter. If every field starts with a digit there is nothing to rotate to.
    # `part[:1]` keeps this safe for the empty name, where parts == [""].
    parts = name.split("_")
    first_word = next(
        (i for i, part in enumerate(parts) if not part[:1].isdigit()), None
    )
    if first_word:
        parts = parts[first_word:] + parts[:first_word]
    name = "_".join(parts)

    if keyword.iskeyword(name):
        name = name + "_"

    return name


def is_valid_snake_identifier(name: str) -> bool:
    """
    Return True if `name` is a lowercase, non-keyword Python identifier.

    The lowercase requirement is opinionated: otherwise-valid identifiers such
    as `Foo` are rejected. The empty string is rejected.

    Raises:
        ValueError: if `name` is None.
    """
    if name == "":
        return False
    if name is None:
        raise ValueError(f"Name {name=!r} is None")
    return not (
        "." in name
        or " " in name
        or name[0].isdigit()
        or name != name.lower()  # this is opinionated
        or not name.isidentifier()
        or keyword.iskeyword(name)
    )


def _find_reducible(
    curr_names: dict[int, str],
    mode: Literal["prefix", "postfix"] = "postfix",
    separator: str = ".",
    existing: dict[int, str] | None = None,
    limit_min_fields: int = 2,
    min_str_len_reduce: int = 10,
) -> set[int]:
    """
    Find the ids whose name can safely lose one `separator`-delimited field.

    "prefix" mode drops the first field, "postfix" mode drops the last one. A
    node is selected only if its shortened name:
      * keeps more than `limit_min_fields` fields, is at least
        `min_str_len_reduce` characters long and is a valid snake identifier;
      * is claimed by no other node and by no `existing` name;
      * does not equal the (unchanged) name of a node that is not reduced.

    Args:
        curr_names: current name per node id.
        mode: which end of the name loses a field.
        separator: field separator.
        existing: names that are already taken; their ids are never selected.
        limit_min_fields: a reduced name must keep more than this many fields.
        min_str_len_reduce: minimum length of a reduced name.

    Returns:
        ids from `curr_names` to shorten by one field this round.
    """
    # candidate name -> [ids claiming it, number of claims]
    new_matched_counts = defaultdict(lambda: [[], 0])

    if existing is not None:
        for node_id, name in existing.items():
            entry = new_matched_counts[name]
            entry[0].append(node_id)
            entry[1] += 1

    def _next_name(name: str) -> str:
        # Drop one field from the end selected by `mode`; a name without the
        # separator is returned unchanged.
        if mode == "postfix":
            head, sep, _ = name.rpartition(separator)
            return head if sep else name
        _, sep, tail = name.partition(separator)
        return tail if sep else name

    def _stop_reduce(newname: str) -> bool:
        return (
            len(newname.split(separator)) <= limit_min_fields
            or len(newname) < min_str_len_reduce
            or not is_valid_snake_identifier(newname)
        )

    for node_id, name in curr_names.items():
        matchval = _next_name(name)
        if _stop_reduce(matchval):
            # not reducible: the node claims its current name instead
            matchval = name
        if separator in name:
            new_matched_counts[matchval][0].append(node_id)
        new_matched_counts[matchval][1] += 1

    continue_reduce_ids: set[int] = set()
    for name, (ids, count) in new_matched_counts.items():
        # NOTE(review): this also requires the *next* reduction of the candidate
        # to be allowed, so names stop one field earlier than `_stop_reduce`
        # alone would allow - confirm whether that is intended.
        if count != 1 or _stop_reduce(_next_name(name)):
            continue
        continue_reduce_ids.update(
            nodeid for nodeid in ids if existing is None or nodeid not in existing
        )

    # A node that is not reduced keeps its current name, which may equal the
    # candidate of a reduced node (e.g. "aaaa_bbbb_cccc_dddd_eeee" ->
    # "bbbb_cccc_dddd_eeee" while "bbbb_cccc_dddd_eeee" itself stays). Block
    # such reductions; a blocked node keeps its name too, so repeat until stable.
    taken = set(existing.values()) if existing is not None else set()
    while True:
        staying = taken | {
            name
            for node_id, name in curr_names.items()
            if node_id not in continue_reduce_ids
        }
        blocked = {
            node_id
            for node_id in continue_reduce_ids
            if _next_name(curr_names[node_id]) in staying
        }
        if not blocked:
            return continue_reduce_ids
        continue_reduce_ids -= blocked


def duplicate_names(names: dict[int, str]) -> list[tuple[str, int]]:
    """
    Return `(name, count)` for every name used by more than one key of `names`.

    Pairs are ordered by first occurrence of the name; an empty list means all
    names are unique.
    """
    return [
        (name, count) for name, count in Counter(names.values()).items() if count > 1
    ]


def reduce_name_prefix_suffix(
    names: dict[int, str],
    mode: Literal["prefix", "postfix"] = "postfix",
    separator: str = ".",
    existing: dict[int, str] | None = None,
) -> dict[int, str]:
    """
    Remove as many `separator`-separated fields from the names as possible while
    keeping every name unique.

    Each round, `_find_reducible` picks the names that can lose one field, and
    they are shortened; rounds repeat until nothing is reducible.

    Args:
        names: name per node id. Not modified.
        mode: "prefix" removes leading fields. "postfix" is not implemented and
            raises NotImplementedError as soon as any name is reducible.
        separator: field separator.
        existing: names already taken elsewhere; their ids are never shortened
            and shortened names may not collide with them.

    Returns:
        a new dict with the reduced name per node id.

    Raises:
        NotImplementedError: for `mode="postfix"` when some name is reducible.
        ValueError: if the reduction produced duplicate names (internal error).
    """
    curr_names = names.copy()

    while continue_reduce := _find_reducible(
        curr_names, mode=mode, separator=separator, existing=existing
    ):
        logger.debug("%s reducing ", reduce_name_prefix_suffix.__name__)
        if mode != "prefix":
            raise NotImplementedError(f"Unsupported mode: {mode}")
        # TODO: this is not correct, we need to find the longest prefix that is
        # shared by all nodes in the set.
        for node_id in continue_reduce:
            # _find_reducible only selects names that contain the separator
            curr_names[node_id] = curr_names[node_id].split(separator, 1)[1]

    output_dups = duplicate_names(curr_names)
    if output_dups:
        raise ValueError(
            f"{reduce_name_prefix_suffix.__name__} created duplicate names {output_dups} - "
            f"should be impossible, please contact the developers. input names were {names.values()}"
        )

    return curr_names


def apply_panel_names_to_input_names(
    node_tree: bpy.types.NodeTree,
    names: dict[str, str],
    only_dedup: bool = False,
) -> dict[str, str]:
    """
    Prefix the names of sockets that sit inside an interface panel with the
    panel's name.

    `names` maps socket identifier -> name and is updated in place (and also
    returned). Each socket inside a panel is renamed to `<panel>_<socket>`,
    with both parts passed through `bpy_name_to_pythonid`. Afterwards, every
    `.` in any name becomes `_`. If the tree has no panels, `names` is returned
    untouched.

    Args:
        node_tree: node tree whose interface panels are inspected.
        names: socket identifier -> name; must contain every socket that sits
            inside a panel.
        only_dedup: if True, a panel member whose name occurs exactly once
            among the non-PANEL items of `items_tree` is not prefixed (it only
            gets `bpy_name_to_pythonid` applied).

    Raises:
        ValueError: if a socket appears in more than one panel.
    """
    interface_items = list(node_tree.interface.items_tree.values())
    panel_sockets = [item for item in interface_items if item.item_type == "PANEL"]
    if not panel_sockets:
        return names

    if logger.isEnabledFor(logging.DEBUG):
        socketnames = [panel.name for panel in panel_sockets]
        logger.debug(
            f"apply_panel_names_to_input_names on {node_tree.name=} {only_dedup=} found panel sockets {socketnames=}"
        )

    # number of non-panel interface items carrying each socket name
    basename_counts = Counter(
        item.name for item in interface_items if item.item_type != "PANEL"
    )

    seen_panel_members: set[str] = set()
    for panel_socket in panel_sockets:
        for socket in panel_socket.interface_items.values():
            if socket.identifier in seen_panel_members:
                raise ValueError(
                    f"Panel socket {socket.identifier=} {socket.name=} appeared in multiple panels. "
                    "Contact the developers if this is needed."
                )
            seen_panel_members.add(socket.identifier)

            # "." separates panel and socket until the final pass below
            panel_name = bpy_name_to_pythonid(panel_socket.name)
            socket_name = bpy_name_to_pythonid(names[socket.identifier])
            needs_prefix = not (only_dedup and basename_counts[socket.name] == 1)
            names[socket.identifier] = (
                f"{panel_name}.{socket_name}" if needs_prefix else socket_name
            )

    # any "." left over must become "_" to form a valid python identifier
    names.update({key: value.replace(".", "_") for key, value in names.items()})
    return names


# generic key type for name mappings, so dedup_names_with_suffix returns a dict
# keyed the same way as its input
TKey = TypeVar("TKey")


def dedup_names_with_suffix(
    names: dict[TKey, str],
    existing: dict[TKey, str] | None = None,
    separator: str = ".",
    order: Iterable[TKey] | None = None,
    first_use_suffix: bool = False,
) -> dict[TKey, str]:
    """
    Make names unique by appending `{separator}{count}` suffixes.

    A trailing all-digit field (e.g. `foo.001`) is stripped from each name
    first, so suffixes are reassigned consistently. Every string value of
    `existing` is reserved. For a dotted value such as `pkg.mod.fn`, its first
    `.`-field (`pkg`) is reserved too. Keys present in `existing` are left out
    of the result.

    Args:
        names: proposed name per key. Not modified.
        existing: already-assigned names; non-str values are ignored.
        separator: placed between a name and its numeric suffix; also the
            separator used to detect the trailing digit field.
        order: order in which keys claim names (earlier keys keep the bare
            name). Keys not in `names` are ignored; keys of `names` that are
            missing from `order` are left out of the result. Defaults to the
            order of `names`.
        first_use_suffix: if True, a name shared by several keys is suffixed
            even on first use (`x_0`, `x_1`) instead of (`x`, `x_1`).

    Returns:
        a unique name per key visited in `order` and absent from `existing`.

    Raises:
        ValueError: if the result still contains duplicates (internal error).
    """
    # Work on a copy: the original rewrote the caller's dict in place here.
    stripped: dict[TKey, str] = {}
    for nid, name in names.items():
        parts = name.split(separator)
        if len(parts) > 1 and parts[-1].isdigit():
            name = separator.join(parts[:-1])
        stripped[nid] = name
    total_counts = Counter(stripped.values())

    # name -> next suffix number to try; presence of a key means "taken"
    seen_counts: dict[str, int] = {}
    if existing is not None:
        for name in existing.values():
            if not isinstance(name, str):
                continue
            seen_counts[name] = 1
            if "." in name:
                # also reserve the leading name of a dotted expression
                seen_counts[name.split(".")[0]] = 1

    if order is None:
        order = list(stripped.keys())

    result_names: dict[TKey, str] = {}
    for node_id in order:
        if node_id not in stripped:
            continue
        if existing and node_id in existing:
            continue
        orig_name = stripped[node_id]
        count = seen_counts.get(orig_name, 0)
        if count == 0 and not (first_use_suffix and total_counts[orig_name] > 1):
            result_names[node_id] = orig_name
        else:
            newname = f"{orig_name}{separator}{count}"
            while newname in seen_counts:
                count += 1
                newname = f"{orig_name}{separator}{count}"
            seen_counts[newname] = count + 1
            result_names[node_id] = newname
        seen_counts[orig_name] = count + 1

    # (name, count) pairs for repeated names, same format as duplicate_names()
    dups = [
        (name, count)
        for name, count in Counter(result_names.values()).items()
        if count > 1
    ]
    if dups:
        raise ValueError(
            f"{dedup_names_with_suffix.__name__} created duplicate names {dups} - "
            "should be impossible, please contact the developers. "
            f"{result_names.values()=}"
        )
    return result_names
def _propogate_one_step( child: cg.Node, parent: cg.Node | None, argname: str | int, node_names_parts: dict[int, list[str]], limit_n_fields: int | None, usages: dict[int, list[cg.Node]], fold_map: dict[int, bool] | None, skip_propogate_words: list[str] | None = None, ): if parent is None: assert isinstance(argname, str), (argname, parent, child) return [argname] elif ( fold_map is not None and fold_map.get(id(parent), False) and id(parent) not in node_names_parts ): assert isinstance(argname, str), (argname, parent, child) return [argname] elif id(parent) not in node_names_parts: raise ValueError( f"Visited {id(child)} as {argname=} of {id(parent)} {parent=} before parent was named?" ) parent_parts = node_names_parts[id(parent)] while len(parent_parts) > 1 and parent_parts[-1] in skip_propogate_words: parent_parts = parent_parts[:-1] if limit_n_fields is not None and len(parent_parts) > limit_n_fields - 1: prefix = parent_parts[0] if parent_parts[0] != "result" else parent_parts[1] return [prefix, argname] elif fold_map is not None and all( fold_map.get(id(usage_parent), False) for usage_parent in usages[id(child)] ): # no need to have a different name from parent if that parent name is always folded / never used return parent_parts # elif ( # len(usages[id(child)]) == 1 # and child.kind == parent.kind # and child.target == parent.target # ): # # repeatedly applying a function to the same variable can reuse the same name # return parent_parts else: return parent_parts + [argname] def propogate_names_with_parts( graph: cg.ComputeGraph, fixed_names: dict[int, str] | None = None, seen_subgraphs: set[int] | None = None, limit_n_fields: int | None = None, fold_map: dict[int, bool] | None = None, skip_propogate_words: list[str] | None = None, ) -> dict[int, list[str]]: # logger.debug(f"{propogate_names_with_parts.__name__} for {graph.name}") node_names_parts: dict[int, list[str]] = {} fixed_name_ids: set[int] = set() if fixed_names is not None: for nid, v in 
fixed_names.items(): node_names_parts[nid] = list(v.split("_")) fixed_name_ids.add(nid) for i, (name, outnode) in enumerate(graph.outputs.items()): if name == "": if len(graph.outputs) == 1: name = graph.name + "_result" name = f"result_{i}" node_names_parts[id(outnode)] = [name] if seen_subgraphs is None: seen_subgraphs = set() usages = cg.usages_per_node(graph) for argname, parent, child in cg.traverse_breadth_first( graph, yield_parent=True, yield_name=True, ): if parent is None: continue assert isinstance(parent, cg.Node), (parent, child, argname) assert argname is not None, (parent, child, argname) result = _propogate_one_step( child, parent, argname, node_names_parts, limit_n_fields, usages, fold_map, skip_propogate_words=skip_propogate_words, ) resname = "_".join(result) if not is_valid_snake_identifier(resname): parent_name = "_".join(node_names_parts.get(id(parent), [])) logger.warning( f"Propogate from {parent} {parent_name=} to {child=} produced invalid identifier {resname=}" ) continue if id(child) not in node_names_parts: node_names_parts[id(child)] = result elif id(child) not in fixed_name_ids: newlen = sum(len(x) for x in result) oldlen = sum(len(x) for x in node_names_parts[id(child)]) if newlen < oldlen: node_names_parts[id(child)] = result return node_names_parts def _infill_names_propogate( graph: cg.ComputeGraph, node_names: dict[int, str], fold_map: dict[int, bool], existing: dict[int, str], skip_propogate_words: list[str] | None = None, ): propogated = propogate_names_with_parts( graph, fixed_names=node_names, limit_n_fields=4, fold_map=fold_map, skip_propogate_words=skip_propogate_words, ) propogated_names = { id: "_".join(parts) for id, parts in propogated.items() if not fold_map.get(id, False) } propogated_names = dedup_names_with_suffix( propogated_names, existing={**existing, **node_names}, separator="_", order=reversed([id(n) for n in cg.traverse_depth_first(graph)]), first_use_suffix=True, ) # propogated_names = 
reduce_name_prefix_suffix( # propogated_names, # mode="prefix", # separator="_", # existing={**existing, **node_names}, # ) if intersection := set(propogated_names.values()).intersection( set(node_names.values()) ): raise ValueError(f"Propogated and node names had overlap: {intersection=}") return propogated_names def _name_from_functionality(node: cg.Node) -> str: match node: case cg.FunctionCallNode(func=func): return func.__name__ case cg.MethodCallNode(method_name=method_name): return method_name case cg.SubgraphCallNode(subgraph=subgraph): return subgraph.name case cg.GetAttributeNode(attribute_name=attribute_name): return attribute_name case cg.MutatedArgumentNode(): return _name_from_functionality(node.args[0]) case cg.ConstantNode(): return "const" case cg.InputPlaceholderNode(input_name=name): return name if name else "input" case _: raise NotImplementedError(f"Unsupported node: {node}") def _infill_names_function( graph: cg.ComputeGraph, node_names: dict[int, str], fold_map: dict[int, bool], existing: dict[int, str], ): result = {} for node in cg.traverse_depth_first(graph): if id(node) in existing: continue if fold_map[id(node)]: continue if node_names.get(id(node), None) is not None: continue result[id(node)] = _name_from_functionality(node) result = dedup_names_with_suffix( result, existing={**existing, **node_names}, separator="_", order=reversed([id(n) for n in cg.traverse_depth_first(graph)]), ) if intersection := set(result.values()).intersection(set(node_names.values())): raise ValueError(f"Function and node names had overlap: {intersection=}") return result def _fixed_name_for_node( node: cg.Node, scope_expressions: dict[int, str], avoid_parts: list[str] = [], ) -> str | None: """ If `node` is significant enough to be given a name, rather than being named after its later usage, then we will return a str name for it here """ match node: case cg.Node(metadata={"varname": varname}): return varname case cg.SubgraphCallNode(subgraph=subgraph): return 
subgraph.name + "_result" case cg.FunctionCallNode(func=func): func_resolve = scope_expressions.get(id(func), None) module = getattr(func, "__module__", "") or "" if not isinstance(func_resolve, str): return None if not ( ".geo." in func_resolve or ".shader." in func_resolve or module.startswith("infinigen_v2.") ): return None name = "_".join( part for part in func.__name__.split("_") if part not in avoid_parts ) if "." in func_resolve: module_alias = func_resolve.rsplit(".", 1)[0] if name == module_alias: name = name + "_result" return name case _: return None NONDESCRIPTIVE_NODE_NAME_PARTS = [ "mesh", "geometry", "bsdf", "result", "distribution", ] def nodenames_from_fixed_and_infill( graph: cg.ComputeGraph, fold_map: dict[int, bool], scope_expressions: dict[int, str], avoid_parts: list[str] | None = None, ) -> dict[int, str]: if avoid_parts is None: avoid_parts = NONDESCRIPTIVE_NODE_NAME_PARTS node_names = {} for name, output in graph.outputs.items(): if name == "": name = _name_from_functionality(output) if not fold_map.get(id(output), False): node_names[id(output)] = name for name, input in graph.inputs.items(): node_names[id(input)] = name for node in cg.traverse_depth_first(graph): if fold_map.get(id(node), False): continue if name := _fixed_name_for_node( node, scope_expressions, avoid_parts=avoid_parts ): node_names[id(node)] = name node_names = dedup_names_with_suffix( node_names, existing=scope_expressions, separator="_", ) # node_names = reduce_name_prefix_suffix( # node_names, # mode="prefix", # separator="_", # existing=scope_expressions, # ) result = _infill_names_propogate( graph, node_names, fold_map, existing=scope_expressions, skip_propogate_words=avoid_parts, ) result.update(node_names) for name in result.values(): if not is_valid_snake_identifier(name): logger.warning(f"Invalid node name: {name}") return result