Custom Collections

For many problems, the built-in Dask collections (dask.array, dask.dataframe, dask.bag, and dask.delayed) are sufficient. For cases where they aren’t, it’s possible to create your own Dask collections. Here we describe the required methods to fulfill the Dask collection interface.

Note

This is considered an advanced feature. For most cases the built-in collections are probably sufficient.

Before reading this you should read and understand:

The Dask Collection Interface

To create your own Dask collection, you need to fulfill the interface defined by the dask.typing.DaskCollection protocol. Note that there is no required base class.

It is recommended to also read Internals of the Core Dask Methods to see how this interface is used inside Dask.

Collection Protocol

class dask.typing.DaskCollection(*args, **kwargs)

Protocol defining the interface of a Dask collection.

abstract __dask_graph__() → collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any]

The Dask task graph.

The core Dask collections (Array, DataFrame, Bag, and Delayed) use a HighLevelGraph to represent the collection task graph. It is also possible to represent the task graph as a low level graph using a Python dictionary.

Returns
Mapping

The Dask task graph. If the instance returns a dask.highlevelgraph.HighLevelGraph then the __dask_layers__() method must be implemented, as defined by the HLGDaskCollection protocol.

abstract __dask_keys__() → list[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...], ForwardRef('NestedKeys')]]

The output keys of the task graph.

Note that keys for a Dask collection are subject to additional constraints beyond those described in the task graph specification documentation. These additional constraints are described below.

All keys must either be non-empty strings or tuples where the first element is a non-empty string, followed by zero or more arbitrary str, bytes, int, float, or tuples thereof. The non-empty string is commonly known as the collection name. All collections embedded in the dask package have exactly one name, but this is not a requirement.

These are all valid outputs:

  • []

  • ["x", "y"]

  • [[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]

Returns
list

A possibly nested list of keys that represent the outputs of the graph. After computation, the results will be returned in the same layout, with the keys replaced with their corresponding outputs.
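The key constraints above can be expressed as a small check. The following is a minimal sketch in plain Python and is not part of Dask: the helper names `is_key_part`, `is_valid_collection_key`, and `all_keys_valid` are hypothetical and only restate the rules as written in this section.

```python
def is_key_part(obj):
    """True for str, bytes, int, float, or (nested) tuples of those."""
    if isinstance(obj, (str, bytes, int, float)):
        return True
    return isinstance(obj, tuple) and all(is_key_part(o) for o in obj)


def is_valid_collection_key(key):
    """A key is a non-empty string, or a tuple whose first element is one."""
    if isinstance(key, str):
        return key != ""
    if isinstance(key, tuple) and len(key) > 0:
        name, *rest = key
        return (
            isinstance(name, str)
            and name != ""
            and all(is_key_part(part) for part in rest)
        )
    return False


def all_keys_valid(keys):
    """Walk a possibly nested list of keys and validate every leaf."""
    return all(
        all_keys_valid(k) if isinstance(k, list) else is_valid_collection_key(k)
        for k in keys
    )


# The three layouts listed above are accepted.
print(all_keys_valid([]))
print(all_keys_valid(["x", "y"]))
print(all_keys_valid([[("y", "a", 0), ("y", "a", 1)],
                      [("y", "b", 0), ("y", "b", 1)]]))
# A tuple key whose first element is not a non-empty string is rejected.
print(all_keys_valid([(0, "x")]))
```

The first three calls print True and the last prints False. A real collection would not normally need such a check; it is shown only to make the rule concrete.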

__dask_optimize__: Any

Given a graph and keys, return a new optimized graph.

This method can be either a staticmethod or a classmethod, but not an instancemethod. For example implementations see the definitions of __dask_optimize__ in the core Dask collections: dask.array.Array, dask.dataframe.DataFrame, etc.

Note that graphs and keys are merged before calling __dask_optimize__; as such, the graph and keys passed to this method may represent more than one collection sharing the same optimize method.

Parameters
dsk: Graph

The merged graphs from all collections sharing the same __dask_optimize__ method.

keys: Sequence[Key]

A list of the outputs from __dask_keys__ from all collections sharing the same __dask_optimize__ method.

**kwargs: Any

Extra keyword arguments forwarded from the call to compute or persist. Can be used or ignored as needed.

Returns
MutableMapping

The optimized Dask graph.
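To illustrate what an optimize callable receives and returns, here is a minimal sketch of culling written in plain Python. It is a stand-in under stated assumptions, not Dask's own implementation: the names `flatten_keys`, `find_dependencies`, and `simple_optimize` are hypothetical, and tasks are assumed to be tuples of the form `(callable, arg, ...)` stored in a plain dictionary.

```python
from operator import add


def flatten_keys(keys):
    """Yield every key from a possibly nested list of keys."""
    for k in keys:
        if isinstance(k, list):
            yield from flatten_keys(k)
        else:
            yield k


def find_dependencies(dsk, expr):
    """Collect the keys of ``dsk`` referenced anywhere inside ``expr``."""
    deps, stack = set(), [expr]
    while stack:
        item = stack.pop()
        try:
            if item in dsk:
                deps.add(item)
                continue
        except TypeError:  # unhashable item, e.g. a list of arguments
            pass
        if isinstance(item, (tuple, list)):
            stack.extend(item)
    return deps


def simple_optimize(dsk, keys, **kwargs):
    """Keep only the tasks needed to compute ``keys`` (extra kwargs ignored)."""
    needed, stack = set(), list(flatten_keys(keys))
    while stack:
        key = stack.pop()
        if key not in needed:
            needed.add(key)
            stack.extend(find_dependencies(dsk, dsk[key]))
    return {k: v for k, v in dsk.items() if k in needed}


# Merged graph of two hypothetical collections; "unused" feeds neither output.
dsk = {"a": 1, "b": (add, "a", 2), "c": (add, "a", 3), "unused": (add, "a", 10)}
keys = [["b"], ["c"]]  # one entry per collection sharing this optimizer
print(sorted(simple_optimize(dsk, keys)))
```

Because the keys of several collections arrive together, the function flattens them before walking the graph, and it returns a new dictionary rather than mutating its input.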

abstract __dask_postcompute__() → tuple[collections.abc.Callable, tuple]

Finalizer function and optional arguments to construct final result.

Upon computation, each key in the collection will have an in-memory result; the postcompute function combines each key’s result into a final in-memory representation. For example, dask.array.Array concatenates the arrays at each chunk into a final in-memory array.

Returns
PostComputeCallable

Callable that receives the sequence of the results of each final key along with optional arguments. An example signature would be finalize(results: Sequence[Any], *args).

tuple[Any, …]

Optional arguments passed to the function following the key results (the *args part of the PostComputeCallable). If no additional arguments are to be passed then this must be an empty tuple.
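As a rough illustration of the finalizer contract, the sketch below shows a finalizer that takes the per-key results plus one extra argument. The `Words` class and `join_words` function are hypothetical and implement only the two methods needed for the demonstration, so `Words` is not a complete Dask collection; the list of results is written by hand in place of a scheduler's output.

```python
def join_words(results, sep):
    """Finalizer: ``results`` holds one value per output key, in key order."""
    return sep.join(results)


class Words:
    """Hypothetical partial collection used only to show __dask_postcompute__."""

    def __init__(self, keys, sep=" "):
        self._keys = keys
        self._sep = sep

    def __dask_keys__(self):
        return self._keys

    def __dask_postcompute__(self):
        # The extra arguments must be a tuple, even when there is only one.
        return join_words, (self._sep,)


words = Words(["w0", "w1", "w2"], sep="-")
results = ["lazy", "task", "graph"]  # stand-in for the scheduler's output

finalize, extra_args = words.__dask_postcompute__()
final = finalize(results, *extra_args)
print(final)  # lazy-task-graph
```

Note that the extra arguments are unpacked positionally after the results, which is why a collection with nothing to pass returns an empty tuple.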

abstract __dask_postpersist__() → tuple[dask.typing.PostPersistCallable, tuple]

Rebuilder function and optional arguments to construct a persisted collection.

See also the documentation for dask.typing.PostPersistCallable.

Returns
PostPersistCallable

Callable that rebuilds the collection. The signature should be rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None) (as defined by the PostPersistCallable protocol). The callable should return an equivalent Dask collection with the same keys as self, but with results that are computed through a different graph. In the case of dask.persist(), the new graph will have just the output keys and the values already computed.

tuple[Any, …]

Optional arguments passed to the rebuild callable. If no additional arguments are to be passed then this must be an empty tuple.

__dask_scheduler__: staticmethod[SchedulerGetCallable]

The default scheduler get to use for this object.

Usually attached to the class as a staticmethod, e.g.:

>>> import dask.threaded
>>> class MyCollection:
...     # Use the threaded scheduler by default
...     __dask_scheduler__ = staticmethod(dask.threaded.get)

abstract __dask_tokenize__() → collections.abc.Hashable

Value that must fully represent the object.

abstract compute(**kwargs: Any) → Any

Compute this dask collection.

This turns a lazy Dask collection into its in-memory equivalent. For example a Dask array turns into a NumPy array and a Dask dataframe turns into a Pandas dataframe. The entire dataset must fit into memory before calling this operation.

Parameters
scheduler: string, optional

Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.

optimize_graph: bool, optional

If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.

kwargs

Extra keywords to forward to the scheduler function.

Returns
The collection’s computed result.

See also

dask.compute

abstract persist(**kwargs: Any) → dask.typing.CollType

Persist this dask collection into memory.

This turns a lazy Dask collection into a Dask collection with the same metadata, but now with the results fully computed or actively computing in the background.

The behavior of this function differs significantly depending on the active task scheduler. If the task scheduler supports asynchronous computing, as is the case for the dask.distributed scheduler, then persist will return immediately and the return value’s task graph will contain Dask Future objects. However, if the task scheduler only supports blocking computation, then the call to persist will block and the return value’s task graph will contain concrete Python results.

This function is particularly useful when using distributed systems, because the results will be kept in distributed memory, rather than returned to the local process as with compute.

Parameters
scheduler: string, optional

Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.

optimize_graph: bool, optional

If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.

**kwargs

Extra keywords to forward to the scheduler function.

Returns
New dask collections backed by in-memory data

See also

dask.persist

abstract visualize(filename: str = 'mydask', format: str | None = None, optimize_graph: bool = False, **kwargs: Any) → DisplayObject | None

Render the computation of this object’s task graph using graphviz.

Requires graphviz to be installed.

Parameters
filename: str or None, optional

The name of the file to write to disk. If the provided filename doesn’t include an extension, ‘.png’ will be used by default. If filename is None, no file will be written, and we communicate with dot using only pipes.

format: {‘png’, ‘pdf’, ‘dot’, ‘svg’, ‘jpeg’, ‘jpg’}, optional

Format in which to write output file. Default is ‘png’.

optimize_graph: bool, optional

If True, the graph is optimized before rendering. Otherwise, the graph is displayed as is. Default is False.

color: {None, ‘order’}, optional

Options to color nodes. Provide the cmap= keyword to select an additional colormap.

**kwargs

Additional keyword arguments to forward to to_graphviz.

Returns
result: IPython.display.Image, IPython.display.SVG, or None

See dask.dot.dot_graph for more information.

See also

dask.visualize
dask.dot.dot_graph

Notes

For more information on optimization see here:

https://docs.dask.org/en/latest/optimize.html

Examples

>>> x.visualize(filename='dask.pdf')  
>>> x.visualize(filename='dask.pdf', color='order')  

HLG Collection Protocol

Collections backed by Dask’s High Level Graphs must implement an additional method, defined by this protocol:

class dask.typing.HLGDaskCollection(*args, **kwargs)

Protocol defining a Dask collection that uses HighLevelGraphs.

This protocol is nearly identical to DaskCollection, with the addition of the __dask_layers__ method (required for collections backed by high level graphs).

abstract __dask_layers__() → collections.abc.Sequence[str]

Names of the HighLevelGraph layers.

Scheduler get Protocol

The SchedulerGetCallable protocol defines the signature that a Dask collection’s __dask_scheduler__ definition must adhere to.

class dask.typing.SchedulerGetCallable(*args, **kwargs)

Protocol defining the signature of a __dask_scheduler__ callable.

__call__(dsk: collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any], keys: Union[collections.abc.Sequence[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]]], str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], **kwargs: Any) → Any

Method called as the default scheduler for a collection.

Parameters
dsk

The task graph.

keys

Key(s) corresponding to the desired data.

**kwargs

Additional arguments.

Returns
Any

Result(s) associated with keys
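To make the get signature concrete, the following is a toy, single-threaded evaluator with the same call shape. It is only a sketch and is not one of Dask's schedulers: `simple_get` and its helpers are hypothetical names, tasks are assumed to be tuples of the form `(callable, arg, ...)`, and this sketch chooses to return a tuple when given a list of keys.

```python
from operator import add, mul


def _is_key(dsk, obj):
    """True if ``obj`` is hashable and names a task in the graph."""
    try:
        return obj in dsk
    except TypeError:
        return False


def _evaluate(dsk, expr, cache):
    """Recursively evaluate a key, a task tuple, or a literal value."""
    if _is_key(dsk, expr):
        if expr not in cache:
            cache[expr] = _evaluate(dsk, dsk[expr], cache)
        return cache[expr]
    if isinstance(expr, tuple) and expr and callable(expr[0]):
        func, *args = expr
        return func(*[_evaluate(dsk, a, cache) for a in args])
    return expr  # a literal such as 1 or 2.5


def simple_get(dsk, keys, **kwargs):
    """Toy get: accepts one key or a (nested) list of keys; kwargs ignored."""
    cache = {}

    def fetch(k):
        if isinstance(k, list):
            return tuple(fetch(i) for i in k)
        if not _is_key(dsk, k):
            raise KeyError(k)
        return _evaluate(dsk, k, cache)

    return fetch(keys)


dsk = {"k0": 1,
       ("x", "k1"): 2,
       ("x", 1): (add, "k0", ("x", "k1")),
       ("x", 2): (mul, ("x", "k1"), 2),
       ("x", 3): (add, ("x", "k1"), ("x", 1))}

print(simple_get(dsk, ("x", 3)))                                      # 5
print(simple_get(dsk, [("x", "k1"), ("x", 1), ("x", 2), ("x", 3)]))  # (2, 3, 4, 5)
```

A function with this shape could be attached to a class as `__dask_scheduler__ = staticmethod(simple_get)`, although a real collection would normally point at one of Dask's own schedulers.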

Post-persist Callable Protocol

Collections must define a __dask_postpersist__ method which returns a callable that adheres to the PostPersistCallable interface.

class dask.typing.PostPersistCallable(*args, **kwargs)

Protocol defining the signature of a __dask_postpersist__ callable.

__call__(dsk: collections.abc.Mapping[typing.Union[str, bytes, int, float, tuple[typing.Union[str, bytes, int, float, tuple[ForwardRef('Key'), ...]], ...]], typing.Any], *args: Any, rename: collections.abc.Mapping[str, str] | None = None) → dask.typing.CollType_co

Method called to rebuild a persisted collection.

Parameters
dsk: Mapping

A mapping which contains at least the output keys returned by __dask_keys__().

*args: Any

Additional optional arguments. If no extra arguments are necessary, it must be an empty tuple.

rename: Mapping[str, str], optional

If defined, it indicates that output keys may be changing too; e.g. if the previous output of __dask_keys__() was [("a", 0), ("a", 1)], after calling rebuild(dsk, *extra_args, rename={"a": "b"}) it must become [("b", 0), ("b", 1)]. The rename mapping may not contain the collection name(s); in such case the associated keys do not change. It may contain replacements for unexpected names, which must be ignored.

Returns
Collection

An equivalent Dask collection with the same keys, but computed through a different graph.
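The rename behaviour described above can be sketched in plain Python. Everything here is illustrative: `Chunks` is a hypothetical partial collection, and `rename_key` is a hand-written stand-in for a key-renaming helper rather than a Dask function.

```python
def rename_key(key, rename):
    """Replace the collection name of ``key`` if ``rename`` mentions it."""
    if isinstance(key, str):
        return rename.get(key, key)
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    return key


class Chunks:
    """Hypothetical partial collection holding a graph and its output keys."""

    def __init__(self, dsk, keys):
        self._dsk = dsk
        self._keys = keys

    def __dask_keys__(self):
        return self._keys

    def __dask_postpersist__(self):
        # Rebuilder plus the extra positional arguments it needs.
        return rebuild, (self._keys,)


def rebuild(dsk, keys, *, rename=None):
    """Rebuild a Chunks collection, renaming output keys when requested."""
    if rename:
        keys = [rename_key(k, rename) for k in keys]
    return Chunks(dsk, keys)


old = Chunks({("a", 0): 10, ("a", 1): 20}, [("a", 0), ("a", 1)])
rebuilder, extra_args = old.__dask_postpersist__()

# New graph stored under the name "b"; "zzz" is an unexpected name and is ignored.
new_graph = {("b", 0): 10, ("b", 1): 20}
renamed = rebuilder(new_graph, *extra_args, rename={"a": "b", "zzz": "q"})
print(renamed.__dask_keys__())  # [('b', 0), ('b', 1)]

# Without rename (or when the name is absent from it) the keys are unchanged.
same = rebuilder({("a", 0): 10, ("a", 1): 20}, *extra_args)
print(same.__dask_keys__())     # [('a', 0), ('a', 1)]
```

Passing `rename` as a keyword-only argument keeps it from colliding with the positional extra arguments.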

Internals of the Core Dask Methods

Dask has a few core functions (and corresponding methods) that implement common operations:

  • compute: Convert one or more Dask collections into their in-memory counterparts

  • persist: Convert one or more Dask collections into equivalent Dask collections with their results already computed and cached in memory

  • optimize: Convert one or more Dask collections into equivalent Dask collections sharing one large optimized graph

  • visualize: Given one or more Dask collections, draw out the graph that would be passed to the scheduler during a call to compute or persist

Here we briefly describe the internals of these functions to illustrate how they relate to the above interface.

Compute

The operation of compute can be broken into three stages:

  1. Graph Merging & Optimization

    First, the individual collections are converted to a single large graph and nested list of keys. How this happens depends on the value of the optimize_graph keyword, which each function takes:

    • If optimize_graph is True (default), then the collections are first grouped by their __dask_optimize__ methods. All collections with the same __dask_optimize__ method have their graphs merged and keys concatenated, and then a single call to each respective __dask_optimize__ is made with the merged graphs and keys. The resulting graphs are then merged.

    • If optimize_graph is False, then all the graphs are merged and all the keys concatenated.

    After this stage there is a single large graph and nested list of keys which represents all the collections.

  2. Computation

    After the graphs are merged and any optimizations performed, the resulting large graph and nested list of keys are passed on to the scheduler. The scheduler to use is chosen as follows:

    • If a get function is specified directly as a keyword, use that

    • Otherwise, if a global scheduler is set, use that

    • Otherwise fall back to the default scheduler for the given collections. Note that an error will be raised if the collections don’t all share the same __dask_scheduler__.

    Once the appropriate scheduler get function is determined, it is called with the merged graph, keys, and extra keyword arguments. After this stage, results is a nested list of values. The structure of this list mirrors that of keys, with each key substituted with its corresponding result.

  3. Postcompute

    After the results are generated, the output values of compute need to be built. This is what the __dask_postcompute__ method is for. __dask_postcompute__ returns two things:

    • A finalize function, which takes in the results for the corresponding keys

    • A tuple of extra arguments to pass to finalize after the results

    To build the outputs, the list of collections and results is iterated over, and the finalizer for each collection is called on its respective results.

In pseudocode, this process looks like the following:

def compute(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    if kwargs.pop('optimize_graph', True):
        # If optimization is turned on, group the collections by
        # optimization method, and apply each method only once to the merged
        # sub-graphs.
        optimization_groups = groupby_optimization_methods(collections)
        graphs = []
        for optimize_method, cols in optimization_groups:
            # Merge the graphs and keys for the subset of collections that
            # share this optimization method
            sub_graph = merge_graphs([x.__dask_graph__() for x in cols])
            sub_keys = [x.__dask_keys__() for x in cols]
            # kwargs are forwarded to ``__dask_optimize__`` from compute
            optimized_graph = optimize_method(sub_graph, sub_keys, **kwargs)
            graphs.append(optimized_graph)
        graph = merge_graphs(graphs)
    else:
        graph = merge_graphs([x.__dask_graph__() for x in collections])
    # Keys are always the same
    keys = [x.__dask_keys__() for x in collections]

    # 2. Computation
    # --------------
    # Determine appropriate get function based on collections, global
    # settings, and keyword arguments
    get = determine_get_function(collections, **kwargs)
    # Pass the merged graph, keys, and kwargs to ``get``
    results = get(graph, keys, **kwargs)

    # 3. Postcompute
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        finalize, extra_args = collection.__dask_postcompute__()
        out = finalize(res, *extra_args)
        output.append(out)

    # `dask.compute` always returns tuples
    return tuple(output)
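The pseudocode above leaves `groupby_optimization_methods` undefined. One possible way to write it, shown here purely as an assumption about its behaviour, is to group collections by the identity of their `__dask_optimize__` callable. The classes `A`, `B`, and `C` are hypothetical placeholders that define nothing but that attribute.

```python
from collections import defaultdict


def groupby_optimization_methods(collections):
    """Group collections that share the same __dask_optimize__ callable."""
    groups = defaultdict(list)
    for coll in collections:
        # For a staticmethod this attribute lookup yields the plain function,
        # so collections of different classes can land in the same group.
        groups[coll.__dask_optimize__].append(coll)
    return list(groups.items())


def shared_optimize(dsk, keys, **kwargs):
    return dsk


def other_optimize(dsk, keys, **kwargs):
    return dsk


class A:
    __dask_optimize__ = staticmethod(shared_optimize)


class B:
    __dask_optimize__ = staticmethod(shared_optimize)


class C:
    __dask_optimize__ = staticmethod(other_optimize)


a, b, c = A(), B(), C()
for method, cols in groupby_optimization_methods([a, c, b]):
    print(method.__name__, [type(x).__name__ for x in cols])
```

With this grouping, `shared_optimize` would be called once with the merged graphs of `a` and `b`, and `other_optimize` once with the graph of `c`.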

Persist

Persist is very similar to compute, except for how the return values are created. It too has three stages:

  1. Graph Merging & Optimization

    Same as in compute.

  2. Computation

    Same as in compute, except in the case of the distributed scheduler, where the values in results are futures instead of values.

  3. Postpersist

    Similar to __dask_postcompute__, __dask_postpersist__ is used to rebuild values in a call to persist. __dask_postpersist__ returns two things:

    • A rebuild function, which takes in a persisted graph. The keys of this graph are the same as __dask_keys__ for the corresponding collection, and the values are computed results (for the single-machine scheduler) or futures (for the distributed scheduler).

    • A tuple of extra arguments to pass to rebuild after the graph

    To build the outputs of persist, the list of collections and results is iterated over, and the rebuilder for each collection is called on the graph for its respective results.

In pseudocode, this looks like the following:

def persist(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...
    keys = ...

    # 2. Computation
    # --------------
    # **Same as in compute**
    results = ...

    # 3. Postpersist
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        # res has the same structure as keys
        keys = collection.__dask_keys__()
        # Get the computed graph for this collection.
        # Here flatten converts a nested list into a single list
        subgraph = {k: r for (k, r) in zip(flatten(keys), flatten(res))}

        # Rebuild the output dask collection with the computed graph
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(subgraph, *extra_args)

        output.append(out)

    # dask.persist always returns tuples
    return tuple(output)
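The step that pairs flattened keys with flattened results can be tried in isolation. This is a small sketch with a hand-written `flatten` helper; it assumes, for illustration only, that both the keys and the results arrive as nested lists with the same layout, and the result values are made up.

```python
def flatten(seq):
    """Yield the leaves of a nested list; tuples are treated as leaves."""
    for item in seq:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


# Nested output keys of a hypothetical 2x2 collection named "y".
keys = [[("y", "a", 0), ("y", "a", 1)],
        [("y", "b", 0), ("y", "b", 1)]]

# Made-up results in the same layout as the keys.
res = [[1.0, 2.0],
       [3.0, 4.0]]

# The persisted graph maps each output key directly to its result.
subgraph = {k: r for k, r in zip(flatten(keys), flatten(res))}
print(subgraph)
```

Because only lists are flattened, tuple keys such as `("y", "a", 0)` survive intact as dictionary keys.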

Optimize

The operation of optimize can be broken into two stages:

  1. Graph Merging & Optimization

    Same as in compute.

  2. Rebuilding

    Similar to persist, the rebuild function and arguments from __dask_postpersist__ are used to reconstruct equivalent collections from the optimized graph.

In pseudocode, this looks like the following:

def optimize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Rebuilding
    # -------------
    # Rebuild each dask collection using the same large optimized graph
    output = []
    for collection in collections:
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(graph, *extra_args)
        output.append(out)

    # dask.optimize always returns tuples
    return tuple(output)

Visualize

Visualize is the simplest of the four core functions. It only has two stages:

  1. Graph Merging & Optimization

    Same as in compute.

  2. Graph Drawing

    The resulting merged graph is drawn using graphviz and outputs to the specified file.

In pseudocode, this looks like the following:

def visualize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Graph Drawing
    # ----------------
    # Draw the graph with graphviz's `dot` tool and return the result.
    return dot_graph(graph, **kwargs)

Adding the Core Dask Methods to Your Class

Defining the above interface will allow your object to be used by the core Dask functions (dask.compute, dask.persist, dask.visualize, etc.). To add corresponding method versions of these, you can subclass dask.base.DaskMethodsMixin, which adds implementations of compute, persist, and visualize based on the interface above.

Example Dask Collection

Here we create a Dask collection representing a tuple. Every element in the tuple is represented as a task in the graph. Note that this is for illustration purposes only; the same user experience could be achieved using normal tuples with elements of dask.delayed:

# Saved as dask_tuple.py
import dask
import dask.threaded  # imported explicitly because dask.threaded.get is used below
from dask.base import DaskMethodsMixin, replace_name_in_key
from dask.optimization import cull

def tuple_optimize(dsk, keys, **kwargs):
    # We cull unnecessary tasks here. See
    # https://docs.dask.org/en/stable/optimize.html for more
    # information on optimizations in Dask.
    dsk2, _ = cull(dsk, keys)
    return dsk2

# We subclass from DaskMethodsMixin to add common dask methods to
# our class (compute, persist, and visualize). This is nice but not
# necessary for creating a Dask collection (you can define them
# yourself).
class Tuple(DaskMethodsMixin):
    def __init__(self, dsk, keys):
        # The init method takes in a dask graph and a set of keys to use
        # as outputs.
        self._dsk = dsk
        self._keys = keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    # use the `tuple_optimize` function defined above
    __dask_optimize__ = staticmethod(tuple_optimize)

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        # We need to return a callable with the signature
        # rebuild(dsk, *extra_args, rename: Mapping[str, str] = None)
        return Tuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        if rename is not None:
            keys = [replace_name_in_key(key, rename) for key in keys]
        return Tuple(dsk, keys)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case it's the list of keys
        # to be computed.
        return self._keys

Demonstrating this class:

>>> from dask_tuple import Tuple
>>> from operator import add, mul

# Define a dask graph
>>> dsk = {"k0": 1,
...        ("x", "k1"): 2,
...        ("x", 1): (add, "k0", ("x", "k1")),
...        ("x", 2): (mul, ("x", "k1"), 2),
...        ("x", 3): (add, ("x", "k1"), ("x", 1))}

# The output keys for this graph.
# The first element of each tuple must be the same across the whole collection;
# the remainder are arbitrary, unique str, bytes, int, or floats
>>> keys = [("x", "k1"), ("x", 1), ("x", 2), ("x", 3)]

>>> x = Tuple(dsk, keys)

# Compute turns Tuple into a tuple
>>> x.compute()
(2, 3, 4, 5)

# Persist turns Tuple into a Tuple, with each task already computed
>>> x2 = x.persist()
>>> isinstance(x2, Tuple)
True
>>> x2.__dask_graph__()
{('x', 'k1'): 2, ('x', 1): 3, ('x', 2): 4, ('x', 3): 5}
>>> x2.compute()
(2, 3, 4, 5)

# Run-time typechecking
>>> from dask.typing import DaskCollection
>>> isinstance(x, DaskCollection)
True

Checking if an object is a Dask collection

To check if an object is a Dask collection, use dask.base.is_dask_collection:

>>> from dask.base import is_dask_collection
>>> from dask import delayed

>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)
True
>>> is_dask_collection(1)
False

Implementing Deterministic Hashing

Dask implements its own deterministic hash function to generate keys based on the value of arguments. This function is available as dask.base.tokenize. Many common types already have implementations of tokenize, which can be found in dask/base.py.

When creating your own custom classes, you may need to register a tokenize implementation. There are two ways to do this:

  1. The __dask_tokenize__ method

    Where possible, it is recommended to define the __dask_tokenize__ method. This method takes no arguments and should return a value fully representative of the object. It is a good idea to call dask.base.normalize_token from it before returning any non-trivial objects.

  2. Register a function with dask.base.normalize_token

    If defining a method on the class isn’t possible, or you need to customize the tokenize function for a class whose super-class is already registered (for example if you need to sub-class built-ins), you can register a tokenize function with the normalize_token dispatch. The function takes the object as its only argument and should return a value fully representative of it, as described above.

In both cases the implementation should be the same, where only the location of the definition is different.

Note

Both Dask collections and normal Python objects can have implementations of tokenize using either of the methods described above.

Example

>>> from dask.base import tokenize, normalize_token

# Define a tokenize implementation using a method.
>>> class Point:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
...
...     def __dask_tokenize__(self):
...         # This tuple fully represents self
...         # Wrap non-trivial objects with normalize_token before returning them
...         return normalize_token(Point), self.x, self.y

>>> x = Point(1, 2)
>>> tokenize(x)
'5988362b6e07087db2bc8e7c1c8cc560'
>>> tokenize(x) == tokenize(x)  # token is idempotent
True
>>> tokenize(Point(1, 2)) == tokenize(Point(1, 2))  # token is deterministic
True
>>> tokenize(Point(1, 2)) == tokenize(Point(2, 1))  # tokens are unique
False


# Register an implementation with normalize_token
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z

>>> @normalize_token.register(Point3D)
... def normalize_point3d(x):
...     return normalize_token(Point3D), x.x, x.y, x.z

>>> y = Point3D(1, 2, 3)
>>> tokenize(y)
'5a7e9c3645aa44cf13d021c14452152e'

For more examples, see dask/base.py or any of the built-in Dask collections.