Source code for dask.graph_manipulation

"""Tools to modify already existing dask graphs. Unlike in :mod:`dask.optimization`, the
output collections produced by this module are typically not functionally equivalent to
their inputs.
"""

from __future__ import annotations

import uuid
from collections.abc import Callable, Hashable
from typing import Literal, TypeVar

from dask.base import (
    clone_key,
    get_collection_names,
    get_name_from_key,
    replace_name_in_key,
    tokenize,
    unpack_collections,
)
from dask.blockwise import blockwise
from dask.core import flatten
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, Layer, MaterializedLayer
from dask.typing import Graph, Key

__all__ = ("bind", "checkpoint", "clone", "wait_on")

T = TypeVar("T")


[docs]def checkpoint( *collections, split_every: float | Literal[False] | None = None, ) -> Delayed: """Build a :doc:`delayed` which waits until all chunks of the input collection(s) have been computed before returning None. Parameters ---------- collections Zero or more Dask collections or nested data structures containing zero or more collections split_every: int >= 2 or False, optional Determines the depth of the recursive aggregation. If greater than the number of input keys, the aggregation will be performed in multiple steps; the depth of the aggregation graph will be :math:`log_{split_every}(input keys)`. Setting to a low value can reduce cache size and network transfers, at the cost of more CPU and a larger dask graph. Set to False to disable. Defaults to 8. Returns ------- :doc:`delayed` yielding None """ if split_every is None: split_every = 8 elif split_every is not False: split_every = int(split_every) if split_every < 2: raise ValueError("split_every must be False, None, or >= 2") collections, _ = unpack_collections(*collections) if len(collections) == 1: return _checkpoint_one(collections[0], split_every) else: return delayed(chunks.checkpoint)( *(_checkpoint_one(c, split_every) for c in collections) )
def _checkpoint_one(collection, split_every) -> Delayed: tok = tokenize(collection) name = "checkpoint-" + tok keys_iter = flatten(collection.__dask_keys__()) try: next(keys_iter) next(keys_iter) except StopIteration: # Collection has 0 or 1 keys; no need for a map step layer: Graph = {name: (chunks.checkpoint, collection.__dask_keys__())} dsk = HighLevelGraph.from_collections(name, layer, dependencies=(collection,)) return Delayed(name, dsk) # Collection has 2+ keys; apply a two-step map->reduce algorithm so that we # transfer over the network and store in RAM only a handful of None's instead of # the full computed collection's contents dsks = [] map_names = set() map_keys = [] for prev_name in get_collection_names(collection): map_name = "checkpoint_map-" + tokenize(prev_name, tok) map_names.add(map_name) map_layer = _build_map_layer(chunks.checkpoint, prev_name, map_name, collection) map_keys += list(map_layer.get_output_keys()) dsks.append( HighLevelGraph.from_collections( map_name, map_layer, dependencies=(collection,) ) ) # recursive aggregation reduce_layer: dict = {} while split_every and len(map_keys) > split_every: k = (name, len(reduce_layer)) reduce_layer[k] = (chunks.checkpoint, map_keys[:split_every]) map_keys = map_keys[split_every:] + [k] reduce_layer[name] = (chunks.checkpoint, map_keys) dsks.append(HighLevelGraph({name: reduce_layer}, dependencies={name: map_names})) dsk = HighLevelGraph.merge(*dsks) return Delayed(name, dsk) def _can_apply_blockwise(collection) -> bool: """Return True if _map_blocks can be sped up via blockwise operations; False otherwise. FIXME this returns False for collections that wrap around around da.Array, such as pint.Quantity, xarray DataArray, Dataset, and Variable. """ try: from dask.bag import Bag if isinstance(collection, Bag): return True except ImportError: pass try: from dask.array import Array if isinstance(collection, Array): return True except ImportError: pass try: from dask.dataframe import DataFrame, Series return isinstance(collection, (DataFrame, Series)) except ImportError: return False def _build_map_layer( func: Callable, prev_name: str, new_name: str, collection, dependencies: tuple[Delayed, ...] = (), ) -> Layer: """Apply func to all keys of collection. Create a Blockwise layer whenever possible; fall back to MaterializedLayer otherwise. Parameters ---------- func Callable to be invoked on the graph node prev_name : str name of the layer to map from; in case of dask base collections, this is the collection name. Note how third-party collections, e.g. xarray.Dataset, can have multiple names. new_name : str name of the layer to map to collection Arbitrary dask collection dependencies Zero or more Delayed objects, which will be passed as arbitrary variadic args to func after the collection's chunk """ if _can_apply_blockwise(collection): # Use a Blockwise layer try: numblocks = collection.numblocks except AttributeError: numblocks = (collection.npartitions,) indices = tuple(i for i, _ in enumerate(numblocks)) kwargs = {"_deps": [d.key for d in dependencies]} if dependencies else {} return blockwise( func, new_name, indices, prev_name, indices, numblocks={prev_name: numblocks}, dependencies=dependencies, **kwargs, ) else: # Delayed, bag.Item, dataframe.core.Scalar, or third-party collection; # fall back to MaterializedLayer dep_keys = tuple(d.key for d in dependencies) return MaterializedLayer( { replace_name_in_key(k, {prev_name: new_name}): (func, k) + dep_keys for k in flatten(collection.__dask_keys__()) if get_name_from_key(k) == prev_name } )
[docs]def bind( children: T, parents, *, omit=None, seed: Hashable | None = None, assume_layers: bool = True, split_every: float | Literal[False] | None = None, ) -> T: """ Make ``children`` collection(s), optionally omitting sub-collections, dependent on ``parents`` collection(s). Two examples follow. The first example creates an array ``b2`` whose computation first computes an array ``a`` completely and then computes ``b`` completely, recomputing ``a`` in the process: >>> import dask >>> import dask.array as da >>> a = da.ones(4, chunks=2) >>> b = a + 1 >>> b2 = bind(b, a) >>> len(b2.dask) 9 >>> b2.compute() array([2., 2., 2., 2.]) The second example creates arrays ``b3`` and ``c3``, whose computation first computes an array ``a`` and then computes the additions, this time not recomputing ``a`` in the process: >>> c = a + 2 >>> b3, c3 = bind((b, c), a, omit=a) >>> len(b3.dask), len(c3.dask) (7, 7) >>> dask.compute(b3, c3) (array([2., 2., 2., 2.]), array([3., 3., 3., 3.])) Parameters ---------- children Dask collection or nested structure of Dask collections parents Dask collection or nested structure of Dask collections omit Dask collection or nested structure of Dask collections seed Hashable used to seed the key regeneration. Omit to default to a random number that will produce different keys at every call. assume_layers True Use a fast algorithm that works at layer level, which assumes that all collections in ``children`` and ``omit`` #. use :class:`~dask.highlevelgraph.HighLevelGraph`, #. define the ``__dask_layers__()`` method, and #. never had their graphs squashed and rebuilt between the creation of the ``omit`` collections and the ``children`` collections; in other words if the keys of the ``omit`` collections can be found among the keys of the ``children`` collections, then the same must also hold true for the layers. False Use a slower algorithm that works at keys level, which makes none of the above assumptions. split_every See :func:`checkpoint` Returns ------- Same as ``children`` Dask collection or structure of dask collection equivalent to ``children``, which compute to the same values. All nodes of ``children`` will be regenerated, up to and excluding the nodes of ``omit``. Nodes immediately above ``omit``, or the leaf nodes if the collections in ``omit`` are not found, are prevented from computing until all collections in ``parents`` have been fully computed. The keys of the regenerated nodes will be different from the original ones, so that they can be used within the same graph. """ if seed is None: seed = uuid.uuid4().bytes # parents=None is a special case invoked by the one-liner wrapper clone() below blocker = ( checkpoint(parents, split_every=split_every) if parents is not None else None ) omit, _ = unpack_collections(omit) if assume_layers: # Set of all the top-level layers of the collections in omit omit_layers = {layer for coll in omit for layer in coll.__dask_layers__()} omit_keys = set() else: omit_layers = set() # Set of *all* the keys, not just the top-level ones, of the collections in omit omit_keys = {key for coll in omit for key in coll.__dask_graph__()} unpacked_children, repack = unpack_collections(children) return repack( [ _bind_one(child, blocker, omit_layers, omit_keys, seed) for child in unpacked_children ] )[0]
def _bind_one( child: T, blocker: Delayed | None, omit_layers: set[str], omit_keys: set[Key], seed: Hashable, ) -> T: prev_coll_names = get_collection_names(child) if not prev_coll_names: # Collection with no keys; this is a legitimate use case but, at the moment of # writing, can only happen with third-party collections return child dsk = child.__dask_graph__() # type: ignore new_layers: dict[str, Layer] = {} new_deps: dict[str, set[str]] = {} if isinstance(dsk, HighLevelGraph): try: layers_to_clone = set(child.__dask_layers__()) # type: ignore except AttributeError: layers_to_clone = prev_coll_names.copy() else: if len(prev_coll_names) == 1: hlg_name = next(iter(prev_coll_names)) else: hlg_name = tokenize(*prev_coll_names) dsk = HighLevelGraph.from_collections(hlg_name, dsk) layers_to_clone = {hlg_name} clone_keys = dsk.get_all_external_keys() - omit_keys for layer_name in omit_layers: try: layer = dsk.layers[layer_name] except KeyError: continue clone_keys -= layer.get_output_keys() # Note: when assume_layers=True, clone_keys can contain keys of the omit collections # that are not top-level. This is OK, as they will never be encountered inside the # values of their dependent layers. if blocker is not None: blocker_key = blocker.key blocker_dsk = blocker.__dask_graph__() assert isinstance(blocker_dsk, HighLevelGraph) new_layers.update(blocker_dsk.layers) new_deps.update(blocker_dsk.dependencies) else: blocker_key = None layers_to_copy_verbatim = set() while layers_to_clone: prev_layer_name = layers_to_clone.pop() new_layer_name = clone_key(prev_layer_name, seed=seed) if new_layer_name in new_layers: continue layer = dsk.layers[prev_layer_name] layer_deps = dsk.dependencies[prev_layer_name] layer_deps_to_clone = layer_deps - omit_layers layer_deps_to_omit = layer_deps & omit_layers layers_to_clone |= layer_deps_to_clone layers_to_copy_verbatim |= layer_deps_to_omit new_layers[new_layer_name], is_bound = layer.clone( keys=clone_keys, seed=seed, bind_to=blocker_key ) new_dep = { clone_key(dep, seed=seed) for dep in layer_deps_to_clone } | layer_deps_to_omit if is_bound: new_dep.add(blocker_key) new_deps[new_layer_name] = new_dep # Add the layers of the collections from omit from child.dsk. Note that, when # assume_layers=False, it would be unsafe to simply do HighLevelGraph.merge(dsk, # omit[i].dsk). Also, collections in omit may or may not be parents of this specific # child, or of any children at all. while layers_to_copy_verbatim: layer_name = layers_to_copy_verbatim.pop() if layer_name in new_layers: continue layer_deps = dsk.dependencies[layer_name] layers_to_copy_verbatim |= layer_deps new_deps[layer_name] = layer_deps new_layers[layer_name] = dsk.layers[layer_name] rebuild, args = child.__dask_postpersist__() # type: ignore return rebuild( HighLevelGraph(new_layers, new_deps), *args, rename={prev_name: clone_key(prev_name, seed) for prev_name in prev_coll_names}, )
[docs]def clone(*collections, omit=None, seed: Hashable = None, assume_layers: bool = True): """Clone dask collections, returning equivalent collections that are generated from independent calculations. Examples -------- (tokens have been simplified for the sake of brevity) >>> import dask.array as da >>> x_i = da.asarray([1, 1, 1, 1], chunks=2) >>> y_i = x_i + 1 >>> z_i = y_i + 2 >>> dict(z_i.dask) # doctest: +SKIP {('array-1', 0): array([1, 1]), ('array-1', 1): array([1, 1]), ('add-2', 0): (<function operator.add>, ('array-1', 0), 1), ('add-2', 1): (<function operator.add>, ('array-1', 1), 1), ('add-3', 0): (<function operator.add>, ('add-2', 0), 1), ('add-3', 1): (<function operator.add>, ('add-2', 1), 1)} >>> w_i = clone(z_i, omit=x_i) >>> w_i.compute() array([4, 4, 4, 4]) >>> dict(w_i.dask) # doctest: +SKIP {('array-1', 0): array([1, 1]), ('array-1', 1): array([1, 1]), ('add-4', 0): (<function operator.add>, ('array-1', 0), 1), ('add-4', 1): (<function operator.add>, ('array-1', 1), 1), ('add-5', 0): (<function operator.add>, ('add-4', 0), 1), ('add-5', 1): (<function operator.add>, ('add-4', 1), 1)} The typical usage pattern for clone() is the following: >>> x = cheap_computation_with_large_output() # doctest: +SKIP >>> y = expensive_and_long_computation(x) # doctest: +SKIP >>> z = wrap_up(clone(x), y) # doctest: +SKIP In the above code, the chunks of x will be forgotten as soon as they are consumed by the chunks of y, and then they'll be regenerated from scratch at the very end of the computation. Without clone(), x would only be computed once and then kept in memory throughout the whole computation of y, needlessly consuming memory. Parameters ---------- collections Zero or more Dask collections or nested structures of Dask collections omit Dask collection or nested structure of Dask collections which will not be cloned seed See :func:`bind` assume_layers See :func:`bind` Returns ------- Same as ``collections`` Dask collections of the same type as the inputs, which compute to the same value, or nested structures equivalent to the inputs, where the original collections have been replaced. The keys of the regenerated nodes in the new collections will be different from the original ones, so that they can be used within the same graph. """ out = bind( collections, parents=None, omit=omit, seed=seed, assume_layers=assume_layers ) return out[0] if len(collections) == 1 else out
[docs]def wait_on( *collections, split_every: float | Literal[False] | None = None, ): """Ensure that all chunks of all input collections have been computed before computing the dependents of any of the chunks. The following example creates a dask array ``u`` that, when used in a computation, will only proceed when all chunks of the array ``x`` have been computed, but otherwise matches ``x``: >>> import dask.array as da >>> x = da.ones(10, chunks=5) >>> u = wait_on(x) The following example will create two arrays ``u`` and ``v`` that, when used in a computation, will only proceed when all chunks of the arrays ``x`` and ``y`` have been computed but otherwise match ``x`` and ``y``: >>> x = da.ones(10, chunks=5) >>> y = da.zeros(10, chunks=5) >>> u, v = wait_on(x, y) Parameters ---------- collections Zero or more Dask collections or nested structures of Dask collections split_every See :func:`checkpoint` Returns ------- Same as ``collections`` Dask collection of the same type as the input, which computes to the same value, or a nested structure equivalent to the input where the original collections have been replaced. The keys of the regenerated nodes of the new collections will be different from the original ones, so that they can be used within the same graph. """ blocker = checkpoint(*collections, split_every=split_every) def block_one(coll): tok = tokenize(coll, blocker) dsks = [] rename = {} for prev_name in get_collection_names(coll): new_name = "wait_on-" + tokenize(prev_name, tok) rename[prev_name] = new_name layer = _build_map_layer( chunks.bind, prev_name, new_name, coll, dependencies=(blocker,) ) dsks.append( HighLevelGraph.from_collections( new_name, layer, dependencies=(coll, blocker) ) ) dsk = HighLevelGraph.merge(*dsks) rebuild, args = coll.__dask_postpersist__() return rebuild(dsk, *args, rename=rename) unpacked, repack = unpack_collections(*collections) out = repack([block_one(coll) for coll in unpacked]) return out[0] if len(collections) == 1 else out
class chunks: """Callables to be inserted in the Dask graph""" @staticmethod def bind(node: T, *args, **kwargs) -> T: """Dummy graph node of :func:`bind` and :func:`wait_on`. Wait for both node and all variadic args to complete; then return node. """ return node @staticmethod def checkpoint(*args, **kwargs) -> None: """Dummy graph node of :func:`checkpoint`. Wait for all variadic args to complete; then return None. """ pass