from __future__ import annotations
import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import toolz
import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required
if TYPE_CHECKING:
    # Needed for annotations only; the code below that uses HighLevelGraph at
    # runtime imports it locally inside the function that needs it.
    from dask.highlevelgraph import HighLevelGraph

# Names of the optimizer stages understood by ``optimize_until``, listed in
# the order in which that function reaches them.
OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]
def _unpack_collections(o):
    """Return the expression behind ``o`` when it wraps one.

    ``Expr`` instances are returned untouched.  Any other object exposing an
    ``expr`` attribute is replaced by that attribute, except for ``Delayed``
    objects, which are passed through as they are.  Everything else is
    returned unchanged.
    """
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o
    wraps_expr = hasattr(o, "expr") and not isinstance(o, Delayed)
    return o.expr if wraps_expr else o
class Expr:
    """Base class of all expression nodes.

    A node stores its inputs in ``operands`` (matched positionally against
    ``_parameters``) and is identified by ``_name``, which is built from the
    lower-cased class name and a deterministic token.
    """

    # Ordered operand names; ``operands[i]`` belongs to ``_parameters[i]``.
    _parameters: list[str] = []
    # Values ``__new__`` falls back to for parameters that were not passed.
    _defaults: dict[str, Any] = {}
    # When true, ``__reduce__`` ships the already computed
    # ``functools.cached_property`` values of this class with the pickle.
    _pickle_functools_cache: bool = True
    # Operand values in ``_parameters`` order.
    operands: list
    # Token used to build ``_name``; computed lazily when None.
    _determ_token: str | None
def __new__(cls, *args, _determ_token=None, **kwargs):
    """Create the node and fill ``operands`` from ``args`` and ``kwargs``.

    Positional arguments come first.  Every remaining entry of
    ``cls._parameters`` is taken from ``kwargs`` or, failing that, from
    ``cls._defaults``.  Keyword arguments that match no parameter trip the
    assertion.  Each operand is passed through ``_unpack_collections``.
    """
    operands = [*args]
    remaining = cls._parameters[len(operands) :]
    for parameter in remaining:
        try:
            value = kwargs.pop(parameter)
        except KeyError:
            value = cls._defaults[parameter]
        operands.append(value)
    assert not kwargs, kwargs

    inst = object.__new__(cls)
    inst._determ_token = _determ_token
    inst.operands = list(map(_unpack_collections, operands))
    # ``_name`` is typically cached; touch it once so the cache is
    # populated before the instance is handed out.
    inst._name
    return inst
def _tune_down(self):
    """Hook looked up by ``rewrite(kind="tune")`` for this node.

    ``rewrite`` treats a ``None`` result as "leave the node unchanged".
    """
    return None
def _tune_up(self, parent):
    """Hook through which ``rewrite(kind="tune")`` lets a child rewrite ``parent``.

    ``rewrite`` treats a ``None`` result as "leave the parent unchanged".
    """
    return None
def finalize_compute(self):
    """Return the expression to use when finalizing a compute.

    The base implementation returns ``self`` unchanged.
    """
    return self
def _operands_for_repr(self):
return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]
def __str__(self):
s = ", ".join(self._operands_for_repr())
return f"{type(self).__name__}({s})"
def __repr__(self):
    """Return the same text as ``str(self)``."""
    return str(self)
def _tree_repr_argument_construction(self, i, op, header):
try:
param = self._parameters[i]
default = self._defaults[param]
except (IndexError, KeyError):
param = self._parameters[i] if i < len(self._parameters) else ""
default = "--no-default--"
if repr(op) != repr(default):
if param:
header += f" {param}={op!r}"
else:
header += repr(op)
return header
def _tree_repr_lines(self, indent=0, recursive=True):
return " " * indent + repr(self)
def tree_repr(self):
    """Join the lines from ``_tree_repr_lines`` with ``os.linesep``."""
    lines = self._tree_repr_lines()
    return os.linesep.join(lines)
def analyze(self, filename: str | None = None, format: str | None = None) -> None:
    """Forward to ``dask.dataframe.dask_expr.diagnostics.analyze``.

    Parameters
    ----------
    filename, format:
        Passed through unchanged to the diagnostics function.

    Raises
    ------
    TypeError
        If ``self`` is not a dask.dataframe expression.
    """
    from dask.dataframe.dask_expr._expr import Expr as DFExpr
    from dask.dataframe.dask_expr.diagnostics import analyze

    if not isinstance(self, DFExpr):
        raise TypeError(
            "analyze is only supported for dask.dataframe.Expr objects."
        )
    # NOTE(review): annotated ``-> None`` but the helper's result is returned.
    return analyze(self, filename=filename, format=format)
def explain(
    self, stage: OptimizerStage = "fused", format: str | None = None
) -> None:
    """Forward to ``dask.dataframe.dask_expr.diagnostics.explain``.

    ``stage`` and ``format`` are passed through positionally.
    """
    from dask.dataframe.dask_expr.diagnostics import explain

    # NOTE(review): annotated ``-> None`` but the helper's result is returned.
    return explain(self, stage, format)
def pprint(self):
    """Print every item yielded by ``_tree_repr_lines`` on its own line."""
    for line in self._tree_repr_lines():
        print(line)
def __hash__(self):
    """Hash the expression by its ``_name``."""
    return hash(self._name)
def __dask_tokenize__(self):
if not self._determ_token:
# If the subclass does not implement a __dask_tokenize__ we'll want
# to tokenize all operands.
# Note how this differs to the implementation of
# Expr.deterministic_token
self._determ_token = _tokenize_deterministic(type(self), *self.operands)
return self._determ_token
def __dask_keys__(self):
"""The keys for this expression
This is used to determine the keys of the output collection
when this expression is computed.
Returns
-------
keys: list
The keys for this expression
"""
return [(self._name, i) for i in range(self.npartitions)]
@staticmethod
def _reconstruct(*args):
typ, *operands, token, cache = args
inst = typ(*operands, _determ_token=token)
for k, v in cache.items():
inst.__dict__[k] = v
return inst
def __reduce__(self):
    """Pickle support: rebuild through ``Expr._reconstruct``.

    Raises ``RuntimeError`` when the ``dask-expr-no-serialize`` config value
    is truthy.  If the class allows it, the ``functools.cached_property``
    values defined directly on the class that have already been computed on
    this instance are shipped along.
    """
    if dask.config.get("dask-expr-no-serialize", False):
        raise RuntimeError(f"Serializing a {type(self)} object")

    klass = type(self)
    cache = {}
    if klass._pickle_functools_cache:
        computed = self.__dict__
        cache = {
            attr: getattr(self, attr)
            for attr, descriptor in vars(klass).items()
            if isinstance(descriptor, functools.cached_property) and attr in computed
        }
    state = (klass, *self.operands, self.deterministic_token, cache)
    return Expr._reconstruct, state
def _depth(self, cache=None):
"""Depth of the expression tree
Returns
-------
depth: int
"""
if cache is None:
cache = {}
if not self.dependencies():
return 1
else:
result = []
for expr in self.dependencies():
if expr._name in cache:
result.append(cache[expr._name])
else:
result.append(expr._depth(cache) + 1)
cache[expr._name] = result[-1]
return max(result)
def __setattr__(self, name: str, value: Any) -> None:
if name in ["operands", "_determ_token"]:
object.__setattr__(self, name, value)
return
try:
params = type(self)._parameters
operands = object.__getattribute__(self, "operands")
operands[params.index(name)] = value
except ValueError:
raise AttributeError(
f"{type(self).__name__} object has no attribute {name}"
)
def operand(self, key):
    """Return the operand stored for parameter ``key``.

    This gives unambiguous access even when ``key`` is shadowed by a method
    or property of the same name.
    """
    operands = self.operands
    position = type(self)._parameters.index(key)
    return operands[position]
def dependencies(self):
    """Return the operands that are themselves ``Expr`` instances, in order."""
    found = []
    for candidate in self.operands:
        if isinstance(candidate, Expr):
            found.append(candidate)
    return found
def _task(self, key: Key, index: int) -> Task:
    """The task for the i'th partition

    Parameters
    ----------
    key:
        The key of the task to create. The default ``_layer`` passes
        ``(self._name, index)``.
    index:
        The index of the partition of this dataframe

    Examples
    --------
    >>> class Add(Expr):
    ...     def _task(self, key, index):
    ...         return Task(
    ...             key,
    ...             operator.add,
    ...             TaskRef((self.left._name, index)),
    ...             TaskRef((self.right._name, index))
    ...         )

    Returns
    -------
    task:
        The Dask task to compute this partition

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.

    See Also
    --------
    Expr._layer
    """
    raise NotImplementedError(
        "Expressions should define either _layer (full dictionary) or _task"
        f" (single task). This expression {type(self)} defines neither"
    )
def _layer(self) -> dict:
"""The graph layer added by this expression.
Simple expressions that apply one task per partition can choose to only
implement `Expr._task` instead.
Examples
--------
>>> class Add(Expr):
... def _layer(self):
... return {
... name: Task(
... name,
... operator.add,
... TaskRef((self.left._name, i)),
... TaskRef((self.right._name, i))
... )
... for i, name in enumerate(self.__dask_keys__())
... }
Returns
-------
layer: dict
The Dask task graph added by this expression
See Also
--------
Expr._task
Expr.__dask_graph__
"""
return {
(self._name, i): self._task((self._name, i), i)
for i in range(self.npartitions)
}
def rewrite(self, kind: str, rewritten):
    """Rewrite an expression

    This leverages the ``._{kind}_down`` and ``._{kind}_up``
    methods defined on each class

    Parameters
    ----------
    kind: str
        Selects the pair of hook methods to call.
    rewritten: dict
        Cache keyed by ``_name``. A hit is returned immediately, and the
        result for every ``Expr`` operand is stored in it.

    Returns
    -------
    expr:
        output expression. If a hook returns something that is not an
        ``Expr``, that object is returned as is.
    """
    if self._name in rewritten:
        return rewritten[self._name]
    expr = self
    down_name = f"_{kind}_down"
    up_name = f"_{kind}_up"
    # Each pass restarts from the top as soon as the node changed, and
    # stops once neither the node, its children's up-hooks nor the
    # recursive rewrite of the operands produced a different name.
    while True:
        _continue = False
        # Rewrite this node
        out = getattr(expr, down_name)()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out
        if out._name != expr._name:
            expr = out
            continue
        # Allow children to rewrite their parents
        for child in expr.dependencies():
            out = getattr(child, up_name)(expr)
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out is not expr and out._name != expr._name:
                expr = out
                _continue = True
                break
        if _continue:
            continue
        # Rewrite all of the children
        new_operands = []
        changed = False
        for operand in expr.operands:
            if isinstance(operand, Expr):
                new = operand.rewrite(kind=kind, rewritten=rewritten)
                rewritten[operand._name] = new
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)
        if changed:
            # Rebuild the node from the rewritten operands and go again
            expr = type(expr)(*new_operands)
            continue
        else:
            break
    return expr
def simplify_once(self, dependents: defaultdict, simplified: dict):
    """Simplify an expression

    This leverages the ``._simplify_down`` and ``._simplify_up``
    methods defined on each class

    Parameters
    ----------
    dependents: defaultdict[list]
        The dependents for every node. Weak references to parents are
        appended to it while the operands are visited.
    simplified: dict
        Cache of simplified expressions for these dependents.

    Returns
    -------
    expr:
        output expression. If a hook returns something that is not an
        ``Expr``, that object is returned as is.
    """
    # Check if we've already simplified for these dependents
    if self._name in simplified:
        return simplified[self._name]
    expr = self
    # The loop body runs a single time; every path ends in ``return`` or
    # in the unconditional ``break`` at the bottom.
    while True:
        out = expr._simplify_down()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out
        if out._name != expr._name:
            expr = out
        # Allow children to simplify their parents
        for child in expr.dependencies():
            out = child._simplify_up(expr, dependents)
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out is not expr and out._name != expr._name:
                expr = out
                break
        # Rewrite all of the children
        new_operands = []
        changed = False
        for operand in expr.operands:
            if isinstance(operand, Expr):
                # Bandaid for now, waiting for Singleton
                dependents[operand._name].append(weakref.ref(expr))
                new = operand.simplify_once(
                    dependents=dependents, simplified=simplified
                )
                simplified[operand._name] = new
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)
        if changed:
            expr = type(expr)(*new_operands)
        break
    return expr
def optimize(self, fuse: bool = False) -> Expr:
    """Optimize via ``optimize_until``.

    The target stage is ``"fused"`` when ``fuse`` is true and
    ``"simplified-physical"`` otherwise.
    """
    stage: OptimizerStage
    if fuse:
        stage = "fused"
    else:
        stage = "simplified-physical"
    return optimize_until(self, stage)
def fuse(self) -> Expr:
    """Final step of ``optimize_until``; the base implementation is a no-op."""
    return self
def simplify(self) -> Expr:
    """Apply ``simplify_once`` repeatedly until the name stops changing.

    Raises
    ------
    RuntimeError
        If a pass produces a name that an earlier pass already produced.
    """
    current = self
    visited = set()
    while True:
        dependents = collect_dependents(current)
        candidate = current.simplify_once(dependents=dependents, simplified={})
        if candidate._name == current._name:
            return current
        if candidate._name in visited:
            raise RuntimeError(
                f"Optimizer does not converge. {current!r} simplified to {candidate!r} which was already seen. "
                "Please report this issue on the dask issue tracker with a minimal reproducer."
            )
        visited.add(candidate._name)
        current = candidate
def _simplify_down(self):
    """Hook called by ``simplify_once`` on the node itself.

    Returning ``None`` (as done here) leaves the node unchanged.
    """
    return
def _simplify_up(self, parent, dependents):
    """Hook through which ``simplify_once`` lets a child rewrite ``parent``.

    Returning ``None`` (as done here) leaves the parent unchanged.
    """
    return
def lower_once(self, lowered: dict):
    """Lower this node and its operands by a single step.

    Parameters
    ----------
    lowered: dict
        Cache keyed by ``_name``; consulted first and filled with the
        result of this call.

    Returns
    -------
    expr:
        The lowered expression, or whatever non-``Expr`` object ``_lower``
        returned.
    """
    name = self._name
    if name in lowered:
        return lowered[name]

    # Lower this node; ``None`` means "nothing to do"
    out = self._lower()
    if out is None:
        out = self
    if not isinstance(out, Expr):
        return out

    # Lower all children
    changed = False
    new_operands = []
    for operand in out.operands:
        if not isinstance(operand, Expr):
            new_operands.append(operand)
            continue
        new = operand.lower_once(lowered)
        if new._name != operand._name:
            changed = True
        new_operands.append(new)
    if changed:
        out = type(out)(*new_operands)

    return lowered.setdefault(name, out)
def lower_completely(self) -> Expr:
    """Lower an expression completely

    ``lower_once`` is called in a loop, with one cache shared by all
    iterations, until the name no longer changes. No other optimization
    (like ``simplify``) is applied.

    Returns
    -------
    expr:
        output expression

    See Also
    --------
    Expr.lower_once
    Expr._lower
    """
    lowered: dict = {}
    current = self
    while (candidate := current.lower_once(lowered))._name != current._name:
        current = candidate
    return current
def _lower(self):
    """Hook called by ``lower_once``; ``None`` means nothing to lower."""
    return
@functools.cached_property
def _funcname(self) -> str:
    """Lower-cased ``funcname`` of the class; the prefix of ``_name``."""
    return funcname(type(self)).lower()
@property
def deterministic_token(self):
    """Token identifying this expression; computed on first access.

    A missing token is obtained from ``__dask_tokenize__`` and stored.
    Note how this differs from ``__dask_tokenize__`` itself, which
    tokenizes the operands.
    """
    token = self._determ_token
    if not token:
        token = self.__dask_tokenize__()
        self._determ_token = token
    return token
@functools.cached_property
def _name(self) -> str:
return f"{self._funcname}-{self.deterministic_token}"
@property
def _meta(self):
    """Not available on the base class; always raises ``NotImplementedError``."""
    raise NotImplementedError()
@classmethod
def _annotations_tombstone(cls) -> _AnnotationsTombstone:
    """Return a fresh ``_AnnotationsTombstone`` instance."""
    return _AnnotationsTombstone()
def __dask_annotations__(self):
    """Annotations of this expression; the base class has none."""
    return {}
def __dask_graph__(self):
    """Traverse expression tree, collect layers

    Every expression reachable through ``dependencies()`` contributes its
    ``_layer()`` exactly once (deduplicated by ``_name``); the layers are
    merged into one dictionary.

    Subclasses generally do not want to override this method unless custom
    logic is required to treat (e.g. ignore) specific operands during graph
    generation.

    See also
    --------
    Expr._layer
    Expr._task
    """
    pending = [self]
    visited = set()
    layers = []
    while pending:
        node = pending.pop()
        if node._name in visited:
            continue
        visited.add(node._name)
        layers.append(node._layer())
        pending.extend(node.dependencies())
    return toolz.merge(layers)
@property
def dask(self):
    """The graph of this expression, as returned by ``__dask_graph__``."""
    return self.__dask_graph__()
def substitute(self, old, new) -> Expr:
    """Substitute a specific term within the expression

    Note that replacing non-`Expr` terms may produce
    unexpected results, and is not recommended.
    Substituting boolean values is not allowed.

    Parameters
    ----------
    old:
        Old term to find and replace.
    new:
        New term to replace instances of `old` with.

    Returns
    -------
    expr:
        The result of ``_substitute``, started with an empty set of
        already visited names.

    Examples
    --------
    >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
    df + 20
    """
    return self._substitute(old, new, _seen=set())
def _substitute(self, old, new, _seen):
    """Recursive worker behind ``substitute``.

    Parameters
    ----------
    old, new:
        Term to look for and its replacement.
    _seen: set
        Names of expressions already found to contain nothing to replace;
        these are returned unchanged without being traversed again.

    Raises
    ------
    TypeError
        If ``old`` is a ``bool``.
    """
    if self._name in _seen:
        return self
    # Check if we are replacing a literal
    if isinstance(old, Expr):
        substitute_literal = False
        if self._name == old._name:
            return new
    else:
        substitute_literal = True
        if isinstance(old, bool):
            raise TypeError("Arguments to `substitute` cannot be bool.")
    new_exprs = []
    update = False
    for operand in self.operands:
        if isinstance(operand, Expr):
            val = operand._substitute(old, new, _seen)
            if operand._name != val._name:
                update = True
            new_exprs.append(val)
        elif (
            "Fused" in type(self).__name__
            and isinstance(operand, list)
            and all(isinstance(op, Expr) for op in operand)
        ):
            # Special handling for `Fused`.
            # We make no promise to dive through a
            # list operand in general, but NEED to
            # do so for the `Fused.exprs` operand.
            val = []
            for op in operand:
                val.append(op._substitute(old, new, _seen))
                if val[-1]._name != op._name:
                    update = True
            new_exprs.append(val)
        elif (
            substitute_literal
            and not isinstance(operand, bool)
            and isinstance(operand, type(old))
            and operand == old
        ):
            # Literal operand of the same type that compares equal to ``old``
            new_exprs.append(new)
            update = True
        else:
            new_exprs.append(operand)
    if update:  # Only recreate if something changed
        return type(self)(*new_exprs)
    else:
        # Remember that nothing below this node needed replacing
        _seen.add(self._name)
        return self
def substitute_parameters(self, substitutions: dict) -> Expr:
    """Substitute specific `Expr` parameters

    Parameters
    ----------
    substitutions:
        Mapping of parameter keys to new values. Keys that
        are not found in ``self._parameters`` will be ignored.

    Returns
    -------
    expr:
        ``self`` when nothing was replaced, otherwise a new instance of the
        same type built from the updated operands.
    """
    if not substitutions:
        return self

    params = self._parameters
    replaced_any = False
    new_operands = []
    for position, operand in enumerate(self.operands):
        named = position < len(params)
        if named and params[position] in substitutions:
            operand = substitutions[params[position]]
            replaced_any = True
        new_operands.append(operand)

    if not replaced_any:
        return self
    return type(self)(*new_operands)
def _node_label_args(self):
    """Operands to include in the node label by `visualize`

    The default is the list returned by ``dependencies()``.
    """
    return self.dependencies()
def _to_graphviz(
    self,
    rankdir="BT",
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    **kwargs,
):
    """Build a ``graphviz.Digraph`` of the expression tree.

    Parameters
    ----------
    rankdir: str
        Stored as the ``rankdir`` graph attribute.
    graph_attr, node_attr, edge_attr: dict, optional
        Attribute dictionaries for the ``Digraph``. They are copied, so the
        dictionaries passed in are never modified.
    **kwargs
        Merged into the graph attributes last.

    Returns
    -------
    g: graphviz.Digraph
        One node per unique expression (by ``_name``) and one edge from
        every dependency to the expression depending on it.
    """
    from dask.dot import label, name

    graphviz = import_required(
        "graphviz",
        "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
        "python library and the `graphviz` system library.\n\n"
        "Please either conda or pip install as follows:\n\n"
        " conda install python-graphviz # either conda install\n"
        " python -m pip install graphviz # or pip install and follow installation instructions",
    )

    # Work on copies: the defaults set below must not leak into the
    # dictionaries owned by the caller.
    graph_attr = dict(graph_attr or {})
    node_attr = dict(node_attr or {})
    edge_attr = dict(edge_attr or {})
    graph_attr["rankdir"] = rankdir
    node_attr["shape"] = "box"
    node_attr["fontname"] = "helvetica"
    graph_attr.update(kwargs)
    g = graphviz.Digraph(
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
    )

    # Collect every unique expression together with its dependencies
    stack = [self]
    seen = set()
    dependencies = {}
    while stack:
        expr = stack.pop()
        if expr._name in seen:
            continue
        seen.add(expr._name)
        deps = expr.dependencies()
        dependencies[expr] = set(deps)
        stack.extend(deps)

    cache = {}
    for expr in dependencies:
        # Make node label
        label_args = [
            funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
            for dep in expr._node_label_args()
        ]
        _label = funcname(type(expr))
        if label_args:
            _label = f"{_label}({', '.join(label_args)})"
        node_label = label(_label, cache=cache)
        g.node(name(expr), label=str(node_label), fontsize="20")

    for expr, deps in dependencies.items():
        expr_name = name(expr)
        for dep in deps:
            g.edge(name(dep), expr_name)
    return g
def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
    """
    Visualize the expression graph.
    Requires ``graphviz`` to be installed.

    Parameters
    ----------
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and the graph is
        rendered in the Jupyter notebook only.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'svg'.
    **kwargs
        Additional keyword arguments to forward to ``to_graphviz``.

    Returns
    -------
    g:
        The graph object built by ``_to_graphviz``.
    """
    from dask.dot import graphviz_to_file

    g = self._to_graphviz(**kwargs)
    # The return value of ``graphviz_to_file`` is discarded here
    graphviz_to_file(g, filename, format)
    return g
def walk(self) -> Generator[Expr]:
    """Iterate through all expressions in the tree

    Each expression is yielded once, identified by its ``_name``.

    Returns
    -------
    nodes
        Generator of Expr instances in the graph.
        Ordering is a depth-first search of the expression tree
    """
    pending = [self]
    visited = set()
    while pending:
        current = pending.pop()
        if current._name in visited:
            continue
        visited.add(current._name)
        pending.extend(current.dependencies())
        yield current
def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
    """Search the expression graph for a specific operation type

    Parameters
    ----------
    operation
        The operation type to search for; either an ``Expr`` subclass or a
        tuple of them.

    Returns
    -------
    nodes
        Generator of `operation` instances. Ordering corresponds
        to a depth-first search of the expression graph.
    """
    is_valid = (
        isinstance(operation, tuple)
        and all(issubclass(candidate, Expr) for candidate in operation)
        or issubclass(operation, Expr)  # type: ignore[arg-type]
    )
    assert is_valid, "`operation` must be an `Expr` subclass)"
    return (node for node in self.walk() if isinstance(node, operation))
def __getattr__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError as err:
if key.startswith("_meta"):
# Avoid a recursive loop if/when `self._meta*`
# produces an `AttributeError`
raise RuntimeError(
f"Failed to generate metadata for {self}. "
"This operation may not be supported by the current backend."
)
# Allow operands to be accessed as attributes
# as long as the keys are not already reserved
# by existing methods/properties
_parameters = type(self)._parameters
if key in _parameters:
idx = _parameters.index(key)
return self.operands[idx]
raise AttributeError(
f"{err}\n\n"
"This often means that you are attempting to use an unsupported "
f"API function.."
)
class SingletonExpr(Expr):
    """A singleton Expr class

    Subclasses are treated as singletons: instances are deduplicated by
    ``expr._name``, which is typically based on the dask.tokenize output.

    This is a crucial performance optimization for expressions that walk
    through an optimizer and are recreated repeatedly, but it isn't safe for
    objects that cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        # The registry is created the first time it is needed
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        key = inst._name
        registry = cls._instances
        # An existing instance is only reused when the class does not
        # define its own ``__init__``.
        reuse = key in registry and cls.__init__ == object.__init__
        if reuse:
            return registry[key]
        registry[key] = inst
        return inst
def collect_dependents(expr) -> defaultdict:
    """Map the ``_name`` of every dependency to weak references of its parents.

    The tree below ``expr`` is traversed once per unique ``_name``.
    """
    dependents = defaultdict(list)
    pending = [expr]
    visited = set()
    while pending:
        parent = pending.pop()
        if parent._name in visited:
            continue
        visited.add(parent._name)
        children = parent.dependencies()
        pending.extend(children)
        for child in children:
            dependents[child._name].append(weakref.ref(parent))
    return dependents
def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    Runs ``optimize_until`` up to the ``"fused"`` stage when ``fuse`` is
    true, and up to ``"simplified-physical"`` otherwise.

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    optimize_until
    """
    stage: OptimizerStage
    if fuse:
        stage = "fused"
    else:
        stage = "simplified-physical"
    return optimize_until(expr, stage)
def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    """Run the optimizer pipeline and stop right after ``stage``.

    The steps are, in order: nothing (``"logical"``), ``simplify``
    (``"simplified-logical"``), the ``"tune"`` rewrite unless the
    ``optimization.tune.active`` config value is falsy (``"tuned-logical"``),
    ``lower_completely`` (``"physical"``), a second ``simplify``
    (``"simplified-physical"``) and ``fuse`` (``"fused"``).

    Raises
    ------
    ValueError
        If ``stage`` is none of the names above; all steps have run by then.
    """
    if stage == "logical":
        return expr

    current = expr.simplify()
    if stage == "simplified-logical":
        return current

    # Manipulate the expression to make it more efficient
    tune_active = dask.config.get("optimization.tune.active", True)
    if tune_active:
        current = current.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return current

    current = current.lower_completely()
    if stage == "physical":
        return current

    current = current.simplify()
    if stage == "simplified-physical":
        return current

    # Final graph-specific optimizations
    current = current.fuse()
    if stage == "fused":
        return current

    raise ValueError(f"Stage {stage!r} not supported.")
class LLGExpr(Expr):
    """Low Level Graph Expression"""

    # Single operand: the low level graph itself
    _parameters = ["dsk"]

    def __dask_keys__(self):
        # Every key of the wrapped graph counts as an output key
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        # The layer is the wrapped graph, converted with ``ensure_dict``
        return ensure_dict(self.operand("dsk"))
class HLGExpr(Expr):
    """Expression wrapping a ``HighLevelGraph``.

    Besides the graph (``dsk``) the operands carry an optional low level
    optimizer, the output keys and the postcompute information.
    """

    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        """The wrapped graph, i.e. the ``dsk`` operand."""
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        """Build an ``HLGExpr`` from a collection.

        A ``PendingDeprecationWarning`` is issued when optimization is
        requested but the collection has no ``__dask_optimize__``.
        """
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when
        # going through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )

        low_level_optimizer = None
        if optimize_graph:
            if hasattr(collection, "__dask_optimize__"):
                low_level_optimizer = collection.__dask_optimize__
            else:
                warnings.warn(
                    f"Collection {type(collection)} does not define a "
                    "`__dask_optimize__` method. In the future this will raise. "
                    "If no optimization is desired, please set this to `None`.",
                    PendingDeprecationWarning,
                )

        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        """Wrap this expression in an ``HLGFinalizeCompute``."""
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        """Collect the layer annotations of the optimized graph per key."""
        # optimization has to be called (and cached) since blockwise fusion
        # can alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        graph = self._optimized_dsk
        collected: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in graph.layers.values():
            if not layer.annotations:
                continue
            for kind, value in layer.annotations.items():
                # Callable annotations are evaluated once for every key
                resolved = {k: (value(k) if callable(value) else value) for k in layer}
                collected[kind].update(resolved)
        return dict(collected)

    def __dask_keys__(self):
        """Return the output keys, deriving and storing them when unset."""
        keys = self.operand("output_keys")
        if keys is not None:
            return keys

        # Note: This will materialize
        dependencies = self.hlg.get_all_dependencies()
        # Keys that no other key depends on
        leafs = set(dependencies).difference(*dependencies.values())
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        """The graph after applying the low level optimizer, if there is one."""
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        optimizer = self.low_level_optimizer
        if optimizer is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        """A random ``uuid4`` hex string, generated once and then kept."""
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        """The optimized graph as a plain dictionary."""
        return ensure_dict(self._optimized_dsk)
class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned:
    # ``_HLGExprSequence.__dask_keys__`` extends its result with the keys of a
    # group instead of appending them as one nested entry.
    pass
class _HLGExprSequence(Expr):
    """A sequence of HLG based expressions that are computed together.

    The operands are the individual expressions; the class declares no named
    parameters.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        # This class has no ``_parameters``, so the operands cannot be looked
        # up by name (``self.operand(...)`` would raise ``ValueError``); show
        # the repr of every operand instead.
        return [repr(op) for op in self.operands]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        """Finalize every operand and wrap the results in a new sequence."""
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        """Merge the ``HLGExpr`` operands that share a low level optimizer.

        Returns ``None`` (no change) when there is a single operand or when
        no two operands fall into the same group.
        """
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]
                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        """Optimize the graph of every operand and merge the results."""
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # ``_tune_down`` merges the HLGExprs that share an optimizer, so each
        # operand is optimized on its own here.
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)
        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the
        # HLGs are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        """Collect the layer annotations of the optimized graph per key.

        Values that are ``_AnnotationsTombstone`` instances are dropped, and
        so are annotation types that end up without any key.
        """
        # optimization has to be called (and cached) since blockwise fusion
        # can alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = [
                        (k, (value(k) if callable(value) else value)) for k in layer
                    ]
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        """Keys of all operands; the keys of a group are spliced in flat."""
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys
class _ExprSequence(Expr):
    """A sequence of expressions

    This makes it possible to optimize several collections together, e.g.
    when they are computed simultaneously with
    ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        """Merge the layers of all operands into one dictionary."""
        return toolz.merge([op._layer() for op in self.operands])

    def __dask_keys__(self) -> list:
        """One list of keys per operand."""
        return [list(op.__dask_keys__()) for op in self.operands]

    def __repr__(self):
        inner = ", ".join(repr(op) for op in self.operands)
        return f"ExprSequence({inner})"

    __str__ = __repr__

    def finalize_compute(self):
        """Finalize every operand and wrap the results in a new sequence."""
        finalized = [op.finalize_compute() for op in self.operands]
        return _ExprSequence(*finalized)

    def __dask_annotations__(self):
        """Merge the annotations of all operands, grouped by annotation type."""
        merged = {}
        for op in self.operands:
            for annot_type, per_key in op.__dask_annotations__().items():
                merged.setdefault(annot_type, {}).update(per_key)
        return merged

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        """Convert to an ``_HLGExprSequence`` when graph based operands exist.

        Nothing happens (``None``) unless at least one operand is an
        ``HLGExpr``, an ``HLGFinalizeCompute`` or a ``dict``.  Otherwise
        dicts are wrapped, and every remaining operand is optimized and
        materialized, which triggers a ``UserWarning``.
        """
        from dask.highlevelgraph import HighLevelGraph

        graph_backed = (HLGExpr, HLGFinalizeCompute)
        if not any(isinstance(op, (*graph_backed, dict)) for op in self.operands):
            return None

        def _wrap(name, graph):
            return HLGExpr(
                dsk=HighLevelGraph.from_collections(name, graph, dependencies=())
            )

        materialized_expression = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, graph_backed):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(_wrap(str(id(op)), op))
            else:
                materialized_expression = True
                opt = op.optimize()
                hlgs.append(_wrap(opt._name, opt.__dask_graph__()))

        if materialized_expression:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the dask."
                "compute calls if necessary.",
                UserWarning,
            )
        return _HLGExprSequence(*hlgs)
# Marker type: annotation values that are instances of this class are dropped
# by ``_HLGExprSequence.__dask_annotations__``.
class _AnnotationsTombstone: ...
class FinalizeCompute(Expr):
    """Placeholder that simplifies to ``expr.finalize_compute()``."""

    _parameters = ["expr"]

    def _simplify_down(self):
        # Replace this node by whatever the wrapped expression finalizes to
        return self.expr.finalize_compute()
def _convert_dask_keys(keys):
    """Turn a (possibly nested) list of keys into a ``List`` of ``TaskRef``.

    Nested lists are converted recursively; every other entry is wrapped in
    a ``TaskRef``.  ``keys`` itself must be a list.
    """
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    converted = [
        _convert_dask_keys(key) if isinstance(key, list) else TaskRef(key)
        for key in keys
    ]
    return List(*converted)
class HLGFinalizeCompute(HLGExpr):
    """Adds the postcompute step(s) of a wrapped ``HLGExpr`` to its graph.

    The ``dsk`` operand of this class is the wrapped expression.
    """

    def _simplify_down(self):
        """Drop the finalization when there is nothing to finalize."""
        wrapped = self.dsk
        if not self.postcompute:
            return wrapped

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        delayed_postcompute = Delayed.__dask_postcompute__(wrapped)
        if wrapped.postcompute == delayed_postcompute:
            return wrapped
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The base class __dask_graph__ will not just materialize this layer
        # but also that of its dependencies, i.e. it will render the
        # finalized and the non-finalized graph and combine them. We only
        # want the finalized one, hence the override.
        # This is an artifact generated because the wrapped expression is
        # identified automatically as a dependency but HLG expressions are
        # not working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        """The wrapped graph extended by one layer per postcompute task."""
        wrapped = self.operand("dsk")
        layers = wrapped.dsk.layers.copy()
        deps = wrapped.dsk.dependencies.copy()
        keys = wrapped.__dask_keys__()

        postcomputes = wrapped.postcompute
        if not isinstance(postcomputes, list):
            postcomputes = [postcomputes]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]

        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        # Layers that no other layer depends on
        leafs = set(deps)
        for required in deps.values():
            leafs -= required

        for task in tasks:
            layers[task.key] = MaterializedLayer({task.key: task})
            deps[task.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]
class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        """Random hex string appended to every key; generated once."""
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        """Append the suffix to a key.

        Lists are handled element-wise; for tuples only the first element
        is modified.
        """
        if isinstance(k, list):
            return [self._modify_keys(item) for item in k]
        if isinstance(k, tuple):
            return (self._modify_keys(k[0]), *k[1:])
        base = str(k) if isinstance(k, (int, float)) else k
        return f"{base}-{self._suffix}"

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        allowed = tuple(self._ALLOWED_TYPES)
        if isinstance(self.expr, allowed):
            return None
        return self.expr

    def __dask_graph__(self):
        """Graph of the wrapped expression plus one renamed alias per key.

        For every original key a task under the suffixed key is added that
        passes the (key-substituted) original task through ``_identity``.
        If a ``P2PBarrierTask`` is found, a ``UserWarning`` is issued and
        the converted graph is returned without any renaming.
        """
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)

        original = convert_legacy_graph(self.expr.__dask_graph__())
        renames = {key: self._modify_keys(key) for key in original}
        renamed = {}
        for old_key, new_key in renames.items():
            task = original[old_key]
            if isinstance(task, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return original
            renamed[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                task.substitute(renames),
            )
        renamed.update(original)
        return renamed

    _layer = __dask_graph__