from __future__ import annotations
import logging
from collections.abc import Hashable
from datetime import timedelta
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
from tornado.ioloop import IOLoop
import dask.config
from dask.utils import parse_timedelta
from distributed.compatibility import PeriodicCallback
from distributed.core import Status
from distributed.deploy.adaptive_core import AdaptiveCore
from distributed.protocol import pickle
from distributed.utils import log_errors
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from distributed.deploy.cluster import Cluster
from distributed.scheduler import WorkerState
logger = logging.getLogger(__name__)
AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]
class Adaptive(AdaptiveCore):
'''
Adaptively allocate workers based on scheduler load. A superclass.
Contains logic to dynamically resize a Dask cluster based on current use.
This class needs to be paired with a system that can create and destroy
Dask workers using a cluster resource manager. Typically it is built into
already existing solutions, rather than used directly by users.
It is most commonly used from the ``.adapt(...)`` method of various Dask
cluster classes.
Parameters
----------
cluster: object
Must have scale and scale_down methods/coroutines
interval : timedelta or str, default "1000 ms"
Time between checks
wait_count: int, default 3
Number of consecutive times that a worker should be suggested for
removal before we remove it.
target_duration: timedelta or str, default "5s"
Amount of time we want a computation to take.
This affects how aggressively we scale up.
worker_key: Callable[[WorkerState], Hashable]
Function to group workers together when scaling down
See Scheduler.workers_to_close for more information
minimum: int
Minimum number of workers to keep around
maximum: int
Maximum number of workers to keep around
**kwargs:
Extra parameters to pass to Scheduler.workers_to_close
Examples
--------
This is commonly used from existing Dask classes, like KubeCluster
>>> from dask_kubernetes import KubeCluster
>>> cluster = KubeCluster()
>>> cluster.adapt(minimum=10, maximum=100)
Alternatively you can use it from your own Cluster class by subclassing
from Dask's Cluster superclass
>>> from distributed.deploy import Cluster
>>> class MyCluster(Cluster):
... def scale_up(self, n):
... """ Bring worker count up to n """
... def scale_down(self, workers):
... """ Remove worker addresses from cluster """
>>> cluster = MyCluster()
>>> cluster.adapt(minimum=10, maximum=100)
Notes
-----
Subclasses can override :meth:`Adaptive.target` and
:meth:`Adaptive.workers_to_close` to control when the cluster should be
resized. The default implementation checks if there are too many tasks
per worker or too little memory available (see
:meth:`distributed.Scheduler.adaptive_target`).
The values for ``interval``, ``minimum``, ``maximum``, ``wait_count`` and
``target_duration`` can be specified in the dask config under the
``distributed.adaptive`` key.
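For example, these defaults can be adjusted through the dask config before
calling ``.adapt(...)`` (illustrative sketch; the values shown here are
arbitrary):
>>> import dask
>>> dask.config.set({"distributed.adaptive.interval": "500ms",
...                  "distributed.adaptive.wait-count": 5})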
'''
interval: float | None
periodic_callback: PeriodicCallback | None
#: The current state of this adaptive strategy: starting, running, stopped, or inactive
state: AdaptiveStateState
def __init__(
self,
cluster: Cluster,
interval: str | float | timedelta | None = None,
minimum: int | None = None,
maximum: int | float | None = None,
wait_count: int | None = None,
target_duration: str | float | timedelta | None = None,
worker_key: Callable[[WorkerState], Hashable] | None = None,
**kwargs: Any,
):
self.cluster = cluster
self.worker_key = worker_key
self._workers_to_close_kwargs = kwargs
if interval is None:
interval = dask.config.get("distributed.adaptive.interval")
if minimum is None:
minimum = cast(int, dask.config.get("distributed.adaptive.minimum"))
if maximum is None:
maximum = cast(float, dask.config.get("distributed.adaptive.maximum"))
if wait_count is None:
wait_count = cast(int, dask.config.get("distributed.adaptive.wait-count"))
if target_duration is None:
target_duration = cast(
str, dask.config.get("distributed.adaptive.target-duration")
)
self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None
if self.interval and self.cluster:
import weakref
self_ref = weakref.ref(self)
async def _adapt():
adaptive = self_ref()
if not adaptive or adaptive.state != "running":
return
if adaptive.cluster.status != Status.running:
adaptive.stop(reason="cluster-not-running")
return
try:
await adaptive.adapt()
except Exception:
logger.warning(
"Adaptive encountered an error while adapting", exc_info=True
)
self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self.state = "starting"
self.loop.add_callback(self._start)
else:
self.state = "inactive"
self.target_duration = parse_timedelta(target_duration)
super().__init__(minimum=minimum, maximum=maximum, wait_count=wait_count)
def _start(self) -> None:
if self.state != "starting":
return
assert self.periodic_callback is not None
self.periodic_callback.start()
self.state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)
def stop(self, reason: str = "unknown") -> None:
if self.state in ("inactive", "stopped"):
return
if self.state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s. Reason: %s",
self.minimum,
self.maximum,
reason,
)
self.periodic_callback = None
self.state = "stopped"
@property
def scheduler(self):
return self.cluster.scheduler_comm
@property
def plan(self):
return self.cluster.plan
@property
def requested(self):
return self.cluster.requested
@property
def observed(self):
return self.cluster.observed
async def target(self):
"""
Determine target number of workers that should exist.
Notes
-----
``Adaptive.target`` dispatches to Scheduler.adaptive_target(),
but may be overridden in subclasses.
Returns
-------
Target number of workers
See Also
--------
Scheduler.adaptive_target
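Examples
--------
A subclass may post-process the scheduler's recommendation, for example
capping it (illustrative sketch; ``CappedAdaptive`` and the cap of 100
are hypothetical):
>>> class CappedAdaptive(Adaptive):
...     async def target(self):
...         n = await super().target()
...         return min(n, 100)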
"""
return await self.scheduler.adaptive_target(
target_duration=self.target_duration
)
async def recommendations(self, target: int) -> dict:
if len(self.plan) != len(self.requested):
# Ensure that the number of planned and requested workers
# are in sync before making recommendations.
await self.cluster
return await super().recommendations(target)
async def workers_to_close(self, target: int) -> list[str]:
"""
Determine which, if any, workers should potentially be removed from
the cluster.
Notes
-----
``Adaptive.workers_to_close`` dispatches to Scheduler.workers_to_close(),
but may be overridden in subclasses.
Returns
-------
List of worker names to close, if any
See Also
--------
Scheduler.workers_to_close
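Examples
--------
A subclass could exclude workers that should never be retired
(illustrative sketch; ``ProtectedAdaptive`` and ``"worker-0"`` are
hypothetical):
>>> class ProtectedAdaptive(Adaptive):
...     async def workers_to_close(self, target):
...         names = await super().workers_to_close(target)
...         return [w for w in names if w != "worker-0"]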
"""
return await self.scheduler.workers_to_close(
target=target,
key=pickle.dumps(self.worker_key) if self.worker_key else None,
attribute="name",
**self._workers_to_close_kwargs,
)
@log_errors
async def scale_down(self, workers):
if not workers:
return
logger.info("Retiring workers %s", workers)
# Ask scheduler to cleanly retire workers
await self.scheduler.retire_workers(
names=workers,
remove=True,
close_workers=True,
)
# close workers more forcefully
f = self.cluster.scale_down(workers)
if isawaitable(f):
await f
async def scale_up(self, n):
f = self.cluster.scale(n)
if isawaitable(f):
await f
@property
def loop(self) -> IOLoop:
"""Override Adaptive.loop"""
if self.cluster:
return self.cluster.loop # type: ignore[return-value]
else:
return IOLoop.current()
def __del__(self):
self.stop(reason="adaptive-deleted")