Advanced graph manipulation
There are some situations where computations with Dask collections will result in suboptimal memory usage (e.g. an entire Dask DataFrame is loaded into memory). This may happen when Dask’s scheduler doesn’t automatically delay the computation of nodes in a task graph to avoid occupying memory with their output for prolonged periods of time, or in scenarios where recalculating nodes is much cheaper than holding their output in memory.
This page highlights a set of graph manipulation utilities which can be used to help avoid these scenarios. In particular, the utilities described below rewrite the underlying Dask graph for Dask collections, producing equivalent collections with different sets of keys.
Consider the following example:
>>> import dask.array as da
>>> x = da.random.normal(size=500_000_000, chunks=100_000)
>>> x_mean = x.mean()
>>> y = (x - x_mean).max().compute()
The above example computes the largest value of a distribution after removing its bias.
This involves loading the chunks of x into memory in order to compute x_mean. However, since the x array is needed later in the computation to compute y, the entire x array is kept in memory. For large Dask Arrays this can be very problematic.
To alleviate the need for the entire x array to be kept in memory, one could rewrite the last line as follows:
>>> from dask.graph_manipulation import bind
>>> xb = bind(x, x_mean)
>>> y = (xb - x_mean).max().compute()
Here we use bind() to create a new Dask Array, xb, which produces exactly the same output as x, but whose underlying Dask graph has different keys than x, and which will only be computed after x_mean has been calculated.
This results in the chunks of x being computed and immediately reduced individually by mean; they are then recomputed and again immediately pipelined into the subtraction, followed by the reduction with max. Because the full x array is never held in memory at once, peak memory usage is much smaller. The tradeoff is a longer compute time, as x is computed twice.
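A scaled-down sketch of the example above can check the claim that xb matches x value for value while living under different keys. The array size is reduced here (an arbitrary choice so it runs quickly), and the memory savings themselves are not measured.

```python
import numpy as np
import dask
import dask.array as da
from dask.graph_manipulation import bind

# Scaled-down version of the example above (size chosen arbitrarily)
x = da.random.normal(size=1_000_000, chunks=100_000)
x_mean = x.mean()

# xb regenerates x's graph under new keys and only runs after x_mean
xb = bind(x, x_mean)

y_plain = (x - x_mean).max()
y_bound = (xb - x_mean).max()

# The random seeds are fixed when the graph is built, so both versions agree
plain, bound = dask.compute(y_plain, y_bound)
print(xb.name != x.name, np.isclose(plain, bound))
```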
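API
checkpoint(*collections[, split_every]): Build a Dask Delayed which waits until all chunks of the input collection(s) have been computed before returning None.
wait_on(*collections[, split_every]): Ensure that all chunks of all input collections have been computed before computing the dependents of any of the chunks.
bind(children, parents, *[, omit, seed, assume_layers, split_every]): Make children collection(s), optionally omitting subcollections, dependent on parents collection(s).
clone(*collections[, omit, seed, assume_layers]): Clone dask collections, returning equivalent collections that are generated from independent calculations.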
Definitions
 dask.graph_manipulation.checkpoint(*collections, split_every: float | Literal[False] | None = None) -> Delayed
Build a Dask Delayed which waits until all chunks of the input collection(s) have been computed before returning None.
 Parameters
 collections
Zero or more Dask collections or nested data structures containing zero or more collections
 split_every: int >= 2 or False, optional
Determines the depth of the recursive aggregation. If the number of input keys is greater than split_every, the aggregation will be performed in multiple steps; the depth of the aggregation graph will be \(\log_{\mathrm{split\_every}}(\text{input keys})\). Setting to a low value can reduce cache size and network transfers, at the cost of more CPU and a larger dask graph.
Set to False to disable. Defaults to 8.
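A worked instance of the depth formula, with numbers picked only for round arithmetic (512 input keys and the default split_every of 8), where each round combines at most 8 keys:

\[
\log_{8}(512) = 3, \qquad 512 \xrightarrow{\;\div 8\;} 64 \xrightarrow{\;\div 8\;} 8 \xrightarrow{\;\div 8\;} 1
\]

so the input keys are combined in three successive rounds.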
 Returns
 Dask Delayed yielding None
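A minimal sketch of calling checkpoint() on the kinds of inputs described above: one collection passed directly and another nested inside a dict and list, with split_every lowered to 2 so the 16 chunks of x are aggregated pairwise. The array sizes are arbitrary.

```python
import dask.array as da
from dask.delayed import Delayed
from dask.graph_manipulation import checkpoint

x = da.ones(128, chunks=8)  # 16 chunks
y = da.zeros(10, chunks=5)  # 2 chunks

# Collections can be passed directly or inside nested structures;
# split_every=2 aggregates the chunk dependencies two at a time.
done = checkpoint(x, {"y": [y]}, split_every=2)

print(isinstance(done, Delayed))
print(done.compute())  # None
```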
 dask.graph_manipulation.wait_on(*collections, split_every: float | Literal[False] | None = None)
Ensure that all chunks of all input collections have been computed before computing the dependents of any of the chunks.
The following example creates a dask array u that, when used in a computation, will only proceed when all chunks of the array x have been computed, but otherwise matches x:
>>> import dask.array as da
>>> from dask.graph_manipulation import wait_on
>>> x = da.ones(10, chunks=5)
>>> u = wait_on(x)
The following example will create two arrays u and v that, when used in a computation, will only proceed when all chunks of the arrays x and y have been computed, but otherwise match x and y:
>>> x = da.ones(10, chunks=5)
>>> y = da.zeros(10, chunks=5)
>>> u, v = wait_on(x, y)
 Parameters
 collections
Zero or more Dask collections or nested structures of Dask collections
 split_every
See checkpoint()
 Returns
 Same as collections
Dask collection of the same type as the input, which computes to the same value, or a nested structure equivalent to the input where the original collections have been replaced. The keys of the regenerated nodes of the new collections will be different from the original ones, so that they can be used within the same graph.
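A small sketch of the return value described above, reusing the arrays from the wait_on() examples: the outputs compute to the same values as the inputs, but their names (and therefore keys) differ, which is what lets them sit in the same graph as the originals.

```python
import numpy as np
import dask
import dask.array as da
from dask.graph_manipulation import wait_on

x = da.ones(10, chunks=5)
y = da.zeros(10, chunks=5)

u, v = wait_on(x, y)

# u and v compute to the same values as x and y ...
u_val, v_val = dask.compute(u, v)
print(np.array_equal(u_val, np.ones(10)), np.array_equal(v_val, np.zeros(10)))

# ... but under regenerated keys, so they can share a graph with x and y
print(u.name != x.name, v.name != y.name)
```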
 dask.graph_manipulation.bind(children: T, parents, *, omit=None, seed: Hashable | None = None, assume_layers: bool = True, split_every: float | Literal[False] | None = None) -> T
Make children collection(s), optionally omitting subcollections, dependent on parents collection(s). Two examples follow.
The first example creates an array b2 whose computation first computes an array a completely and then computes b completely, recomputing a in the process:
>>> import dask
>>> import dask.array as da
>>> from dask.graph_manipulation import bind
>>> a = da.ones(4, chunks=2)
>>> b = a + 1
>>> b2 = bind(b, a)
>>> len(b2.dask)
9
>>> b2.compute()
array([2., 2., 2., 2.])
The second example creates arrays b3 and c3, whose computation first computes an array a and then computes the additions, this time not recomputing a in the process:
>>> c = a + 2
>>> b3, c3 = bind((b, c), a, omit=a)
>>> len(b3.dask), len(c3.dask)
(7, 7)
>>> dask.compute(b3, c3)
(array([2., 2., 2., 2.]), array([3., 3., 3., 3.]))
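The difference between the two examples can be made visible by counting how often the chunks of a are produced. This sketch adds a hypothetical logging helper, record, applied with map_blocks (it is not part of dask.graph_manipulation), and uses the synchronous scheduler so the counter is updated in a single thread.

```python
import dask
import dask.array as da
from dask.graph_manipulation import bind

calls = []  # one entry per chunk of `a` that gets produced


def record(block):
    """Hypothetical helper: identity function that logs each call."""
    calls.append(block.shape)
    return block


a = da.ones(4, chunks=2).map_blocks(record, dtype=float)  # 2 chunks
b = a + 1

# No omit: every node of b, including a, is regenerated under new keys,
# so a's chunks are produced once for the checkpoint and once more for b2
b2 = bind(b, a)
calls.clear()  # drop any calls made while building the graph (e.g. meta inference)
(r2,) = dask.compute(b2, scheduler="sync")
calls_without_omit = len(calls)

# omit=a: only the addition is regenerated, so a's chunks are produced once
b3 = bind(b, a, omit=a)
calls.clear()
(r3,) = dask.compute(b3, scheduler="sync")
calls_with_omit = len(calls)

print(calls_without_omit, calls_with_omit)  # 4 2
```

Both b2 and b3 compute to array([2., 2., 2., 2.]); only the amount of repeated work differs.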
 Parameters
 children
Dask collection or nested structure of Dask collections
 parents
Dask collection or nested structure of Dask collections
 omit
Dask collection or nested structure of Dask collections
 seed
Hashable used to seed the key regeneration. Omit to default to a random number that will produce different keys at every call.
 assume_layers
 True
Use a fast algorithm that works at layer level, which assumes that all collections in children and omit use HighLevelGraph, define the __dask_layers__() method, and never had their graphs squashed and rebuilt between the creation of the omit collections and the children collections; in other words, if the keys of the omit collections can be found among the keys of the children collections, then the same must also hold true for the layers.
 False
Use a slower algorithm that works at keys level, which makes none of the above assumptions.
 split_every
See checkpoint()
 Returns
 Same as children
Dask collection or structure of dask collections equivalent to children, which compute to the same values. All nodes of children will be regenerated, up to and excluding the nodes of omit. Nodes immediately above omit, or the leaf nodes if the collections in omit are not found, are prevented from computing until all collections in parents have been fully computed. The keys of the regenerated nodes will be different from the original ones, so that they can be used within the same graph.
 dask.graph_manipulation.clone(*collections, omit=None, seed: Optional[collections.abc.Hashable] = None, assume_layers: bool = True)
Clone dask collections, returning equivalent collections that are generated from independent calculations.
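 Parameters
 collections
Zero or more Dask collections or nested structures of Dask collections
 omit
Dask collection or nested structure of Dask collections which will not be cloned
 seed
See bind()
 assume_layers
See bind()
 Returns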
 Same as collections
Dask collections of the same type as the inputs, which compute to the same value, or nested structures equivalent to the inputs, where the original collections have been replaced. The keys of the regenerated nodes in the new collections will be different from the original ones, so that they can be used within the same graph.
Examples
(tokens have been simplified for the sake of brevity)
>>> import dask.array as da
>>> from dask.graph_manipulation import clone
>>> x_i = da.asarray([1, 1, 1, 1], chunks=2)
>>> y_i = x_i + 1
>>> z_i = y_i + 2
>>> dict(z_i.dask)
{('array-1', 0): array([1, 1]),
 ('array-1', 1): array([1, 1]),
 ('add-2', 0): (<function operator.add>, ('array-1', 0), 1),
 ('add-2', 1): (<function operator.add>, ('array-1', 1), 1),
 ('add-3', 0): (<function operator.add>, ('add-2', 0), 2),
 ('add-3', 1): (<function operator.add>, ('add-2', 1), 2)}
>>> w_i = clone(z_i, omit=x_i)
>>> w_i.compute()
array([4, 4, 4, 4])
>>> dict(w_i.dask)
{('array-1', 0): array([1, 1]),
 ('array-1', 1): array([1, 1]),
 ('add-4', 0): (<function operator.add>, ('array-1', 0), 1),
 ('add-4', 1): (<function operator.add>, ('array-1', 1), 1),
 ('add-5', 0): (<function operator.add>, ('add-4', 0), 2),
 ('add-5', 1): (<function operator.add>, ('add-4', 1), 2)}
The typical usage pattern for clone() is the following:
>>> x = cheap_computation_with_large_output()
>>> y = expensive_and_long_computation(x)
>>> z = wrap_up(clone(x), y)
In the above code, the chunks of x will be forgotten as soon as they are consumed by the chunks of y, and then they’ll be regenerated from scratch at the very end of the computation. Without clone(), x would only be computed once and then kept in memory throughout the whole computation of y, needlessly consuming memory.
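A runnable sketch of that pattern, with hypothetical stand-ins for the placeholder functions: a dask.array.arange plays the cheap computation with a large output, a sum of squares plays the expensive computation, and taking the maximum of the clone plays the wrap-up step. Only the values and key renaming are checked; the memory behaviour described above is not measured.

```python
import numpy as np
import dask.array as da
from dask.graph_manipulation import clone

# Hypothetical stand-ins for the placeholder functions above
x = da.arange(100_000, chunks=10_000)  # cheap_computation_with_large_output()
y = (x ** 2).sum()                     # expensive_and_long_computation(x)
x2 = clone(x)                          # same values, regenerated keys
z = x2.max() + y                       # wrap_up(clone(x), y)

result = z.compute()
print(x2.name != x.name)  # True
print(result == (np.arange(100_000) ** 2).sum() + 99_999)  # True
```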