Source code for icomo.tree_tools

"""Tools for working with nested structures of lists, tuples or dictionaries."""

from collections.abc import Generator, Sequence
from types import EllipsisType
from typing import Any, Optional

from jaxtyping import ArrayLike, PyTree


[docs] def walk_tree(tree: PyTree) -> Generator[tuple[list, Any]]: """Walk through a tree and yield the indices and values of the leaves. Parameters ---------- tree : The tree to walk through. Can only be a nested tree of lists, tuples and/or dictionaries. Yields ------ indices, value The indices as a list and values of the leaves. """ cursor = [] generators_list = [] value = tree key = None while True: if isinstance(value, dict | list | tuple): cursor.append(key) if isinstance(value, dict): generator = iter(value.items()) else: generator = enumerate(value) generators_list.append(generator) key, value = next(generators_list[-1], (None, None)) cursor[-1] = key if key is None: generators_list.pop() cursor.pop() if len(generators_list) == 0: break continue elif not isinstance(value, dict | list | tuple): yield cursor, value
[docs] def nested_indexing( tree: PyTree, indices: str | int | Sequence[str | int], add: Optional[ArrayLike] = None, at: int | EllipsisType | slice | None | Sequence[EllipsisType | slice | int] = None, ) -> Any: """Return the element of a nested structure of lists or tuples. Parameters ---------- tree : The nested structure of lists or tuples. indices : The indices of the element to return. add : optional The element to add to the nested structure of lists or tuples. at : optional Specifies the position where to add the element. Returns ------- element The element of the nested structure of lists or tuples. """ element = tree if not isinstance(indices, tuple | list): indices = [indices] for depth, index in enumerate(indices): if depth == len(indices) - 1: if add is not None: if at is not None: element[index] = element[index].at[at].add(add) else: element[index] = element[index] + add element = element[index] return element