Advanced graph manipulation

There are some situations where computations with Dask collections result in suboptimal memory usage (e.g. an entire Dask DataFrame being loaded into memory). This can happen when Dask's scheduler does not automatically delay the computation of nodes in a task graph to avoid holding their output in memory for prolonged periods of time, or when recalculating nodes is much cheaper than holding their output in memory.

This page highlights a set of graph manipulation utilities that can help avoid these scenarios. In particular, the utilities described below rewrite the underlying Dask graph of Dask collections, producing equivalent collections with different sets of keys.

Consider the following example:

>>> import dask.array as da
>>> x = da.random.default_rng().normal(size=500_000_000, chunks=100_000)
>>> x_mean = x.mean()
>>> y = (x - x_mean).max().compute()

The above example computes the largest value of a distribution after subtracting its mean. Computing x_mean requires loading every chunk of x into memory. However, since x is needed again later in the computation to compute y, the entire x array is kept in memory. For large Dask Arrays this can be very problematic: here x holds 500,000,000 float64 values, i.e. 4 GB.

To alleviate the need for the entire x array to be kept in memory, one could rewrite the last line as follows:

>>> from dask.graph_manipulation import bind
>>> xb = bind(x, x_mean)
>>> y = (xb - x_mean).max().compute()

Here we use bind() to create a new Dask Array, xb, which produces exactly the same output as x but whose underlying Dask graph has different keys from x, and which will only be computed after x_mean has been calculated.

As a result, each chunk of x is computed and immediately reduced by mean(); once x_mean is known, the chunks are recomputed and each one is immediately pipelined into the subtraction and the max() reduction. Peak memory usage is much smaller because the full x array is never held in memory at once. The tradeoff is a longer compute time, since x is computed twice.
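
A scaled-down version of the example above can be used to check that bind() changes only the keys, not the result. The small deterministic da.arange array is an assumption chosen so the snippet runs quickly; at this size it does not show the memory savings themselves, only that xb is a renamed but equivalent copy of x.

```python
import numpy as np
import dask.array as da
from dask.graph_manipulation import bind

# Small, deterministic stand-in for the large random array above (assumption)
x = da.arange(1_000, chunks=100, dtype="f8")
x_mean = x.mean()

# Original formulation: the chunks of x feed both x_mean and the subtraction
y_plain = (x - x_mean).max()

# bind(): xb computes to the same values as x, but has new keys and its
# chunks are only generated once x_mean has been fully computed
xb = bind(x, x_mean)
y_bound = (xb - x_mean).max()

print(x.name != xb.name)                          # True: different keys
print(np.array_equal(x.compute(), xb.compute()))  # True: same values
print(y_plain.compute(), y_bound.compute())       # 499.5 499.5
```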

API

checkpoint(*collections[, split_every])

Build a Dask Delayed which waits until all chunks of the input collection(s) have been computed before returning None.

wait_on(*collections[, split_every])

Ensure that all chunks of all input collections have been computed before computing the dependents of any of the chunks.

bind(children, parents, *[, omit, seed, ...])

Make children collection(s), optionally omitting sub-collections, dependent on parents collection(s).

clone(*collections[, omit, seed, assume_layers])

Clone dask collections, returning equivalent collections that are generated from independent calculations.

Definitions

dask.graph_manipulation.checkpoint(*collections, split_every: Optional[Union[float, Literal[False]]] = None) → dask.delayed.Delayed

Build a Dask Delayed which waits until all chunks of the input collection(s) have been computed before returning None.

Parameters
collections

Zero or more Dask collections or nested data structures containing zero or more collections

split_every: int >= 2 or False, optional

Determines the depth of the recursive aggregation. If smaller than the number of input keys, the aggregation will be performed in multiple steps; the depth of the aggregation graph will be roughly \(\log_{\text{split\_every}}(\text{number of input keys})\). Setting it to a low value can reduce cache size and network transfers, at the cost of more CPU and a larger Dask graph.

Set to False to disable. Defaults to 8.

Returns
Dask Delayed yielding None
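
A minimal sketch of checkpoint() on two small arrays, also comparing two split_every settings. The array sizes are arbitrary; the comparison only relies on a low split_every adding intermediate aggregation tasks over the 5 chunk keys of x, while False gathers them in a single final task.

```python
import dask.array as da
from dask.graph_manipulation import checkpoint

x = da.ones(10, chunks=2)   # 5 chunks
y = da.zeros(10, chunks=5)  # 2 chunks

# A single Delayed that completes once every chunk of x and y is computed
done = checkpoint(x, y)
print(done.compute())  # None

# Same input, two aggregation strategies over the chunk keys of x
deep = checkpoint(x, split_every=2)    # multi-step aggregation tree
flat = checkpoint(x, split_every=False)  # one final aggregation task
print(len(deep.dask) > len(flat.dask))  # True: extra intermediate tasks
```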

dask.graph_manipulation.wait_on(*collections, split_every: Optional[Union[float, Literal[False]]] = None)

Ensure that all chunks of all input collections have been computed before computing the dependents of any of the chunks.

The following example creates a dask array u that, when used in a computation, will only proceed when all chunks of the array x have been computed, but otherwise matches x:

>>> import dask.array as da
>>> from dask.graph_manipulation import wait_on
>>> x = da.ones(10, chunks=5)
>>> u = wait_on(x)

The following example will create two arrays u and v that, when used in a computation, will only proceed when all chunks of the arrays x and y have been computed but otherwise match x and y:

>>> x = da.ones(10, chunks=5)
>>> y = da.zeros(10, chunks=5)
>>> u, v = wait_on(x, y)

Parameters
collections

Zero or more Dask collections or nested structures of Dask collections

split_every

See checkpoint()

Returns
Same as collections

Dask collection of the same type as the input, which computes to the same value, or a nested structure equivalent to the input where the original collections have been replaced. The keys of the regenerated nodes of the new collections will be different from the original ones, so that they can be used within the same graph.
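
Because wait_on() accepts nested structures, the sketch below passes a dict containing an array and a list, and checks that the output mirrors that structure with renamed but equivalent arrays. The dict and list layout is an arbitrary illustration.

```python
import numpy as np
import dask.array as da
from dask.graph_manipulation import wait_on

x = da.ones(10, chunks=5)
y = da.zeros(10, chunks=5)

# Nested input: the output has the same structure, with each collection
# replaced by a blocked copy that has new keys
out = wait_on({"x": x, "others": [y]})
u = out["x"]
(v,) = out["others"]

print(u.name != x.name, v.name != y.name)        # True True
print(np.array_equal(u.compute(), x.compute()))  # True
```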

dask.graph_manipulation.bind(children: dask.graph_manipulation.T, parents, *, omit=None, seed: collections.abc.Hashable | None = None, assume_layers: bool = True, split_every: Optional[Union[float, Literal[False]]] = None) → dask.graph_manipulation.T

Make children collection(s), optionally omitting sub-collections, dependent on parents collection(s). Two examples follow.

The first example creates an array b2 whose computation first computes an array a completely and then computes b completely, recomputing a in the process:

>>> import dask
>>> import dask.array as da
>>> from dask.graph_manipulation import bind
>>> a = da.ones(4, chunks=2)
>>> b = a + 1
>>> b2 = bind(b, a)
>>> len(b2.dask)
9
>>> b2.compute()
array([2., 2., 2., 2.])

The second example creates arrays b3 and c3, whose computation first computes an array a and then computes the additions, this time not recomputing a in the process:

>>> c = a + 2
>>> b3, c3 = bind((b, c), a, omit=a)
>>> len(b3.dask), len(c3.dask)
(7, 7)
>>> dask.compute(b3, c3)
(array([2., 2., 2., 2.]), array([3., 3., 3., 3.]))

Parameters
children

Dask collection or nested structure of Dask collections

parents

Dask collection or nested structure of Dask collections

omit

Dask collection or nested structure of Dask collections whose nodes will not be regenerated

seed

Hashable used to seed the key regeneration. Omit to default to a random number that will produce different keys at every call.

assume_layers
True

Use a fast algorithm that works at layer level, which assumes that all collections in children and omit

  1. use HighLevelGraph,

  2. define the __dask_layers__() method, and

  3. never had their graphs squashed and rebuilt between the creation of the omit collections and the children collections; in other words, if the keys of the omit collections can be found among the keys of the children collections, then the same must also hold true for the layers.

False

Use a slower algorithm that works at keys level, which makes none of the above assumptions.

split_every

See checkpoint()

Returns
Same as children

Dask collection or structure of Dask collections equivalent to children, which computes to the same values. All nodes of children will be regenerated, up to and excluding the nodes of omit. Nodes immediately above omit, or the leaf nodes if the collections in omit are not found, are prevented from computing until all collections in parents have been fully computed. The keys of the regenerated nodes will be different from the original ones, so that they can be used within the same graph.

dask.graph_manipulation.clone(*collections, omit=None, seed: collections.abc.Hashable = None, assume_layers: bool = True)

Clone dask collections, returning equivalent collections that are generated from independent calculations.

Parameters
collections

Zero or more Dask collections or nested structures of Dask collections

omit

Dask collection or nested structure of Dask collections which will not be cloned

seed

See bind()

assume_layers

See bind()

Returns
Same as collections

Dask collections of the same type as the inputs, which compute to the same value, or nested structures equivalent to the inputs, where the original collections have been replaced. The keys of the regenerated nodes in the new collections will be different from the original ones, so that they can be used within the same graph.
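
One way to see what omit does is to inspect the materialized graph of a clone. In this sketch, the keys of the omitted array x appear unchanged in the clone's graph, while the keys of the intermediate array y do not, because y was regenerated under new keys.

```python
import numpy as np
import dask.array as da
from dask.graph_manipulation import clone

x = da.ones(4, chunks=2)
y = x + 1
z = y + 2

w = clone(z, omit=x)
graph = dict(w.__dask_graph__())

# Nodes of x (listed in omit) are kept verbatim in the clone's graph...
print(all(k in graph for k in x.__dask_keys__()))  # True
# ...while every node above x is regenerated under a new key
print(any(k in graph for k in y.__dask_keys__()))  # False
print(np.array_equal(w.compute(), z.compute()))    # True
```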

Examples

(tokens have been simplified for the sake of brevity)

>>> import dask.array as da
>>> from dask.graph_manipulation import clone
>>> x_i = da.asarray([1, 1, 1, 1], chunks=2)
>>> y_i = x_i + 1
>>> z_i = y_i + 2
>>> dict(z_i.dask)
{('array-1', 0): array([1, 1]),
 ('array-1', 1): array([1, 1]),
 ('add-2', 0): (<function operator.add>, ('array-1', 0), 1),
 ('add-2', 1): (<function operator.add>, ('array-1', 1), 1),
 ('add-3', 0): (<function operator.add>, ('add-2', 0), 2),
 ('add-3', 1): (<function operator.add>, ('add-2', 1), 2)}
>>> w_i = clone(z_i, omit=x_i)
>>> w_i.compute()
array([4, 4, 4, 4])
>>> dict(w_i.dask)
{('array-1', 0): array([1, 1]),
 ('array-1', 1): array([1, 1]),
 ('add-4', 0): (<function operator.add>, ('array-1', 0), 1),
 ('add-4', 1): (<function operator.add>, ('array-1', 1), 1),
 ('add-5', 0): (<function operator.add>, ('add-4', 0), 2),
 ('add-5', 1): (<function operator.add>, ('add-4', 1), 2)}

The typical usage pattern for clone() is the following:

>>> x = cheap_computation_with_large_output()  
>>> y = expensive_and_long_computation(x)  
>>> z = wrap_up(clone(x), y)  

In the above code, the chunks of x will be forgotten as soon as they are consumed by the chunks of y, and then they’ll be regenerated from scratch at the very end of the computation. Without clone(), x would only be computed once and then kept in memory throughout the whole computation of y, needlessly consuming memory.
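
A runnable version of this pattern is sketched below. The three functions are hypothetical stand-ins for the placeholders above, kept tiny so the result is easy to verify: each element is 1 + (16 × 2) = 33. As a variant, the sketch also uses bind(x, y), which, per the bind() description above, both regenerates x under new keys and prevents the regenerated chunks from computing until y is fully computed.

```python
import numpy as np
import dask.array as da
from dask.graph_manipulation import bind, clone

# Hypothetical stand-ins for the placeholder functions above
def cheap_computation_with_large_output():
    return da.ones(16, chunks=4)

def expensive_and_long_computation(x):
    return (x * 2).sum()

def wrap_up(x, y):
    return x + y

x = cheap_computation_with_large_output()
y = expensive_and_long_computation(x)
z = wrap_up(clone(x), y)
print(z.compute())  # sixteen elements, all 33.0

# Variant: regenerate x only after y has been fully computed
z_bound = wrap_up(bind(x, y), y)
print(np.array_equal(z.compute(), z_bound.compute()))  # True
```

The bind() variant makes the "regenerate at the end" ordering explicit in the graph, at the cost of an extra checkpoint task over y.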