# Source code for distributed.deploy.cluster

from __future__ import annotations

import asyncio
import datetime
import logging
import uuid
import warnings
from contextlib import suppress
from typing import Any

from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop

import dask.config
from dask.utils import _deprecated, format_bytes, parse_timedelta, typename
from dask.widgets import get_template

from distributed.compatibility import PeriodicCallback
from distributed.core import Status
from distributed.deploy.adaptive import Adaptive
from distributed.metrics import time
from distributed.objects import SchedulerInfo
from distributed.utils import (
    Log,
    Logs,
    LoopRunner,
    NoOpAwaitable,
    SyncMethodMixin,
    format_dashboard_link,
    log_errors,
)

# Module-wide logger; used below e.g. to warn about repeated failures to
# sync cluster metadata with the scheduler.
logger = logging.getLogger(name=__name__)

class Cluster(SyncMethodMixin):
    """Superclass for cluster objects

    This class contains common functionality for Dask Cluster manager classes.

    To implement this class, you must provide

    1.  A ``scheduler_comm`` attribute, which is a connection to the scheduler
        following the ``distributed.core.rpc`` API.
    2.  Implement ``scale``, which takes an integer and scales the cluster to
        that many workers, or else set ``_supports_scaling`` to False

    For that, you should get the following:

    1.  A standard ``__repr__``
    2.  A live IPython widget
    3.  Adaptive scaling
    4.  Integration with dask-labextension
    5.  A ``scheduler_info`` attribute which contains an up-to-date copy of
        ``Scheduler.identity()``, which is used for much of the above
    6.  Methods to gather logs
    """

    _supports_scaling = True
    __loop: IOLoop | None = None

    def __init__(
        self,
        asynchronous=False,
        loop=None,
        quiet=False,
        name=None,
        scheduler_sync_interval=1,
    ):
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.__asynchronous = asynchronous

        self.scheduler_info = {"workers": {}}
        self.periodic_callbacks = {}
        self._watch_worker_status_comm = None
        self._watch_worker_status_task = None
        self._cluster_manager_logs = []
        self.quiet = quiet
        self.scheduler_comm = None
        self._adaptive = None
        # Bare numbers are interpreted as seconds.
        self._sync_interval = parse_timedelta(
            scheduler_sync_interval, default="seconds"
        )
        self._sync_cluster_info_task = None

        if name is None:
            # Short random identifier for anonymous clusters.
            name = str(uuid.uuid4())[:8]

        # Pushed to the scheduler under "cluster-manager-info" by
        # ``_sync_cluster_info``.
        self._cluster_info = {
            "name": name,
            "type": typename(type(self)),
        }
        self.status = Status.created

    @property
    def loop(self) -> IOLoop | None:
        loop = self.__loop
        if loop is None:
            # If the loop is not running when this is called, the LoopRunner.loop
            # property will raise a DeprecationWarning
            # However subsequent calls might occur - eg atexit, where a stopped
            # loop is still acceptable - so we cache access to the loop.
            self.__loop = loop = self._loop_runner.loop
        return loop

    @loop.setter
    def loop(self, value: IOLoop) -> None:
        warnings.warn(
            "setting the loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        if value is None:
            raise ValueError("expected an IOLoop, got None")
        self.__loop = value

    @property
    def called_from_running_loop(self):
        """True when the caller runs on this cluster's event loop.

        Falls back to the ``asynchronous`` flag given at construction when
        there is no running asyncio loop in the current thread.
        """
        try:
            return (
                getattr(self.loop, "asyncio_loop", None) is asyncio.get_running_loop()
            )
        except RuntimeError:
            return self.__asynchronous

    @property
    def name(self):
        return self._cluster_info["name"]

    @name.setter
    def name(self, name):
        self._cluster_info["name"] = name

    async def _start(self):
        """Subscribe to worker status and start background syncing.

        Opens a dedicated comm to the scheduler, subscribes it to worker
        status updates, seeds ``scheduler_info`` from the first reply,
        merges any ``cluster-manager-info`` metadata already stored on the
        scheduler into the local cluster info, and starts all periodic
        callbacks.
        """
        comm = await self.scheduler_comm.live_comm()
        comm.name = "Cluster worker status"
        await comm.write({"op": "subscribe_worker_status"})
        self.scheduler_info = SchedulerInfo(await comm.read())
        self._watch_worker_status_comm = comm
        self._watch_worker_status_task = asyncio.ensure_future(
            self._watch_worker_status(comm)
        )

        info = await self.scheduler_comm.get_metadata(
            keys=["cluster-manager-info"], default={}
        )
        self._cluster_info.update(info)

        # Start a background task for syncing cluster info with the scheduler
        self._sync_cluster_info_task = asyncio.ensure_future(self._sync_cluster_info())

        for pc in self.periodic_callbacks.values():
            pc.start()
        self.status = Status.running

    async def _sync_cluster_info(self):
        """Periodically push ``_cluster_info`` to the scheduler's metadata.

        Failures back off exponentially (base 1.5) from the sync interval up
        to ten times that interval, and a warning is logged once per run of
        ``warn_at`` consecutive failures.
        """
        err_count = 0
        warn_at = 5
        max_interval = 10 * self._sync_interval
        # Loop until the cluster is shutting down. We shouldn't really need
        # this check (the `CancelledError` should be enough), but something
        # deep in the comms code is silencing `CancelledError`s _some_ of the
        # time, resulting in a cancellation not always bubbling back up to
        # here. Relying on the status is fine though, not worth changing.
        while self.status == Status.running:
            try:
                await self.scheduler_comm.set_metadata(
                    keys=["cluster-manager-info"],
                    value=self._cluster_info.copy(),
                )
                err_count = 0
            except Exception:
                err_count += 1
                # Only warn if multiple subsequent attempts fail, and only once
                # per set of subsequent failed attempts. This way we're not
                # excessively noisy during a connection blip, but we also don't
                # silently fail.
                if err_count == warn_at:
                    logger.warning(
                        "Failed to sync cluster info multiple times - perhaps "
                        "there's a connection issue? Error:",
                        exc_info=True,
                    )
            # Sleep, with error backoff
            interval = _exponential_backoff(
                err_count, self._sync_interval, 1.5, max_interval
            )
            await asyncio.sleep(interval)

    async def _close(self):
        """Stop adaptivity, background tasks, comms and periodic callbacks."""
        if self.status == Status.closed:
            return

        self.status = Status.closing

        # ``_adaptive`` is None unless ``adapt`` was called.
        with suppress(AttributeError):
            self._adaptive.stop()

        if self._watch_worker_status_comm:
            await self._watch_worker_status_comm.close()
        if self._watch_worker_status_task:
            await self._watch_worker_status_task

        if self._sync_cluster_info_task:
            self._sync_cluster_info_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._sync_cluster_info_task

        if self.scheduler_comm:
            await self.scheduler_comm.close_rpc()

        for pc in self.periodic_callbacks.values():
            pc.stop()

        self.status = Status.closed

    def close(self, timeout: float | None = None) -> Any:
        """Close the cluster.

        Returns an awaitable in asynchronous mode, otherwise blocks (up to
        ``timeout`` seconds if given) and returns ``None``.
        """
        # If the cluster is already closed, we're already done
        if self.status == Status.closed:
            if self.asynchronous:
                return NoOpAwaitable()
            return None

        try:
            return self.sync(self._close, callback_timeout=timeout)
        except RuntimeError:  # loop closed during process shutdown
            return None

    def __del__(self, _warn=warnings.warn):
        # ``_warn`` is bound at definition time so the warning can still be
        # emitted during interpreter shutdown.
        if getattr(self, "status", Status.closed) != Status.closed:
            try:
                self_r = repr(self)
            except Exception:
                self_r = f"with a broken __repr__ {object.__repr__(self)}"
            _warn(f"unclosed cluster {self_r}", ResourceWarning, source=self)

    async def _watch_worker_status(self, comm):
        """Listen to scheduler for updates on adding and removing workers

        Each message read from ``comm`` is iterated as ``(op, msg)`` pairs
        and applied via ``_update_worker_status``. The loop stops on
        ``OSError`` (presumably raised once the comm is closed, e.g. by
        ``_close``), after which the comm is closed.
        """
        while True:
            try:
                msgs = await comm.read()
            except OSError:
                break

            with log_errors():
                for op, msg in msgs:
                    self._update_worker_status(op, msg)
        await comm.close()

    def _update_worker_status(self, op, msg):
        """Apply one worker status update to ``scheduler_info``.

        ``"add"``: ``msg`` is a dict whose ``"workers"`` entry is merged into
        the known workers and whose remaining keys update ``scheduler_info``.
        ``"remove"``: ``msg`` is the key of the worker to drop.
        """
        if op == "add":
            workers = msg.pop("workers")
            self.scheduler_info["workers"].update(workers)
            self.scheduler_info.update(msg)
        elif op == "remove":
            del self.scheduler_info["workers"][msg]
        else:  # pragma: no cover
            raise ValueError("Invalid op", op, msg)

    def adapt(self, Adaptive: type[Adaptive] = Adaptive, **kwargs: Any) -> Adaptive:
        """Turn on adaptivity

        For keyword arguments see dask.distributed.Adaptive

        Options accumulate across calls: keyword arguments are merged into
        the options from previous ``adapt`` calls before a new adaptive
        instance replaces (and stops) the previous one.

        Examples
        --------
        >>> cluster.adapt(minimum=0, maximum=10, interval='500ms')
        """
        with suppress(AttributeError):
            self._adaptive.stop()
        if not hasattr(self, "_adaptive_options"):
            self._adaptive_options = {}
        self._adaptive_options.update(kwargs)
        self._adaptive = Adaptive(self, **self._adaptive_options)
        return self._adaptive

    def scale(self, n: int) -> None:
        """Scale cluster to n workers

        Parameters
        ----------
        n : int
            Target number of workers

        Examples
        --------
        >>> cluster.scale(10)  # scale cluster to ten workers
        """
        raise NotImplementedError()

    def _log(self, log):
        """Log a message.

        Output a message to the user and also store for future retrieval.

        For use in subclasses where initialisation may take a while and it would
        be beneficial to feed back to the user.

        Messages are stored as ``(timestamp, message)`` tuples.

        Examples
        --------
        >>> self._log("Submitted job X to batch scheduler")
        """
        self._cluster_manager_logs.append((datetime.datetime.now(), log))
        if not self.quiet:
            print(log)

    async def _get_logs(self, cluster=True, scheduler=True, workers=True):
        logs = Logs()

        if cluster:
            # Entries are (timestamp, message); only the message is shown.
            logs["Cluster"] = Log(
                "\n".join(line[1] for line in self._cluster_manager_logs)
            )

        if scheduler:
            scheduler_logs = await self.scheduler_comm.get_logs()
            logs["Scheduler"] = Log(
                "\n".join(line for level, line in scheduler_logs)
            )

        if workers:
            # ``True`` means "all workers", which the scheduler spells None.
            if workers is True:
                workers = None
            d = await self.scheduler_comm.worker_logs(workers=workers)
            for k, v in d.items():
                logs[k] = Log("\n".join(line for level, line in v))

        return logs

    def get_logs(self, cluster=True, scheduler=True, workers=True):
        """Return logs for the cluster, scheduler and workers

        Parameters
        ----------
        cluster : boolean
            Whether or not to collect logs for the cluster manager
        scheduler : boolean
            Whether or not to collect logs for the scheduler
        workers : boolean or Iterable[str], optional
            A list of worker addresses to select.
            Defaults to all workers if `True` or no workers if `False`

        Returns
        -------
        logs: Dict[str]
            A dictionary of logs, with one item for the scheduler and one for
            each worker
        """
        return self.sync(
            self._get_logs, cluster=cluster, scheduler=scheduler, workers=workers
        )

    @_deprecated(use_instead="get_logs")
    def logs(self, *args, **kwargs):
        return self.get_logs(*args, **kwargs)

    def get_client(self):
        """Return client for the cluster

        If a client has already been initialized for the cluster, return that
        otherwise initialize a new client object.
        """
        from distributed.client import Client

        try:
            current_client = Client.current()
            if current_client and current_client.cluster == self:
                return current_client
        except ValueError:
            # No current client.
            pass
        return Client(self)

    @property
    def dashboard_link(self):
        """Dashboard URL, or ``""`` if the scheduler reports no dashboard."""
        try:
            port = self.scheduler_info["services"]["dashboard"]
        except KeyError:
            return ""
        else:
            # Strip protocol, path and port from the scheduler address.
            host = self.scheduler_address.split("://")[1].split("/")[0].split(":")[0]
            return format_dashboard_link(host, port)

    def _scaling_status(self):
        """HTML table describing the scaling mode and worker counts."""
        if self._adaptive and self._adaptive.periodic_callback:
            mode = "Adaptive"
        else:
            mode = "Manual"
        workers = len(self.scheduler_info["workers"])
        if hasattr(self, "worker_spec"):
            # A spec entry with a "group" key stands for several workers.
            requested = sum(
                1 if "group" not in each else len(each["group"])
                for each in self.worker_spec.values()
            )
        elif hasattr(self, "workers"):
            requested = len(self.workers)
        else:
            requested = workers

        worker_count = workers if workers == requested else f"{workers} / {requested}"
        return f"""
        <table>
            <tr><td style="text-align: left;">Scaling mode: {mode}</td></tr>
            <tr><td style="text-align: left;">Workers: {worker_count}</td></tr>
        </table>
        """

    def _widget(self):
        """Create IPython widget for display within a notebook

        The widget is cached on the instance. ``None`` is returned (and
        cached) when ipywidgets is not installed.
        """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        try:
            from ipywidgets import (
                HTML,
                Accordion,
                Button,
                HBox,
                IntText,
                Layout,
                Tab,
                VBox,
            )
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        status = HTML(self._repr_html_())

        if self._supports_scaling:
            request = IntText(0, description="Workers", layout=layout)
            scale = Button(description="Scale", layout=layout)

            minimum = IntText(0, description="Minimum", layout=layout)
            maximum = IntText(0, description="Maximum", layout=layout)
            adapt = Button(description="Adapt", layout=layout)

            accordion = Accordion(
                [HBox([request, scale]), HBox([minimum, maximum, adapt])],
                layout=Layout(min_width="500px"),
            )
            accordion.selected_index = None
            accordion.set_title(0, "Manual Scaling")
            accordion.set_title(1, "Adaptive Scaling")

            @log_errors
            def adapt_cb(b):
                self.adapt(minimum=minimum.value, maximum=maximum.value)
                update()

            adapt.on_click(adapt_cb)

            @log_errors
            def scale_cb(b):
                n = request.value
                # Manual scaling switches adaptivity off.
                with suppress(AttributeError):
                    self._adaptive.stop()
                self.scale(n)
                update()

            scale.on_click(scale_cb)
        else:  # pragma: no cover
            accordion = HTML("")

        scale_status = HTML(self._scaling_status())

        tab = Tab()
        tab.children = [status, VBox([scale_status, accordion])]
        tab.set_title(0, "Status")
        tab.set_title(1, "Scaling")

        self._cached_widget = tab

        def update():
            status.value = self._repr_html_()
            scale_status.value = self._scaling_status()

        # ``default="ms"`` belongs to parse_timedelta (bare numbers in the
        # config mean milliseconds), not to dask.config.get, where it would
        # be returned as the config value itself when the key is missing.
        cluster_repr_interval = parse_timedelta(
            dask.config.get("distributed.deploy.cluster-repr-interval"),
            default="ms",
        )

        def install():
            # PeriodicCallback takes milliseconds; parse_timedelta returns seconds.
            pc = PeriodicCallback(update, cluster_repr_interval * 1000)
            self.periodic_callbacks["cluster-repr"] = pc
            pc.start()

        self.loop.add_callback(install)

        return tab

    def _repr_html_(self, cluster_status=None):
        try:
            scheduler_info_repr = self.scheduler_info._repr_html_()
        except AttributeError:
            # Still the plain dict set in __init__, i.e. not started yet.
            scheduler_info_repr = "Scheduler not started yet."

        return get_template("cluster.html.j2").render(
            type=type(self).__name__,
            name=self.name,
            workers=self.scheduler_info["workers"],
            dashboard_link=self.dashboard_link,
            scheduler_info_repr=scheduler_info_repr,
            cluster_status=cluster_status,
        )

    def _ipython_display_(self, **kwargs):
        """Display the cluster rich IPython repr"""
        # Note: it would be simpler to just implement _repr_mimebundle_,
        # but we cannot do that until we drop ipywidgets 7 support, as
        # it does not provide a public way to get the mimebundle for a
        # widget. So instead we fall back on the more customizable _ipython_display_
        # and display as a side-effect.
        from IPython.display import display

        widget = self._widget()
        if widget:
            import ipywidgets

            if parse_version(ipywidgets.__version__) >= parse_version("8.0.0"):
                mimebundle = widget._repr_mimebundle_(**kwargs) or {}
                mimebundle["text/plain"] = repr(self)
                mimebundle["text/html"] = self._repr_html_()
                display(mimebundle, raw=True)
            else:
                display(widget, **kwargs)
        else:
            mimebundle = {"text/plain": repr(self), "text/html": self._repr_html_()}
            display(mimebundle, raw=True)

    def __enter__(self):
        if self.asynchronous:
            raise TypeError(
                "Used 'with' with asynchronous class; please use 'async with'"
            )
        return self.sync(self.__aenter__)

    def __exit__(self, exc_type, exc_value, traceback):
        # In synchronous mode close() blocks and returns None.
        aw = self.close()
        assert aw is None, aw

    def __await__(self):
        # Generator that finishes immediately; subclasses may override to
        # perform real startup work.
        return self
        yield

    async def __aenter__(self):
        await self
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    @property
    def scheduler_address(self) -> str:
        if not self.scheduler_comm:
            return "<Not Connected>"
        return self.scheduler_comm.address

    @property
    def _cluster_class_name(self):
        return getattr(self, "_name", type(self).__name__)

    def __repr__(self):
        workers = self.scheduler_info["workers"]
        text = "%s(%s, %r, workers=%d, threads=%d" % (
            self._cluster_class_name,
            self.name,
            self.scheduler_address,
            len(workers),
            sum(w["nthreads"] for w in workers.values()),
        )

        # Only report memory when every worker has a (non-zero) limit.
        memory = [w["memory_limit"] for w in workers.values()]
        if all(memory):
            text += ", memory=" + format_bytes(sum(memory))

        text += ")"
        return text

    @property
    def plan(self):
        return set(self.workers)

    @property
    def requested(self):
        return set(self.workers)

    @property
    def observed(self):
        return {d["name"] for d in self.scheduler_info["workers"].values()}

    def __eq__(self, other):
        return type(other) == type(self) and self.name == other.name

    def __hash__(self):
        # NOTE(review): identity-based, so two clusters that compare equal
        # by name may hash differently. Hashing on ``name`` is not a drop-in
        # fix because ``name`` is mutable via its setter.
        return id(self)

    async def _wait_for_workers(self, n_workers=0, timeout=None):
        """Poll the scheduler every 0.1s until ``n_workers`` are running."""
        self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
        if timeout:
            deadline = time() + parse_timedelta(timeout)
        else:
            deadline = None

        def running_workers(info):
            return len(
                [
                    ws
                    for ws in info["workers"].values()
                    if ws["status"] == Status.running.name
                ]
            )

        while n_workers and running_workers(self.scheduler_info) < n_workers:
            if deadline and time() > deadline:
                raise TimeoutError(
                    "Only %d/%d workers arrived after %s"
                    % (running_workers(self.scheduler_info), n_workers, timeout)
                )
            await asyncio.sleep(0.1)

            self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())

    def wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None:
        """Blocking call to wait for n workers before continuing

        Parameters
        ----------
        n_workers : int
            The number of workers
        timeout : number, optional
            Time in seconds after which to raise a
            ``dask.distributed.TimeoutError``

        Raises
        ------
        ValueError
            If ``n_workers`` is not a positive integer.
        """
        if not isinstance(n_workers, int) or n_workers < 1:
            raise ValueError(
                f"`n_workers` must be a positive integer. Instead got {n_workers}."
            )
        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
def _exponential_backoff( attempt: int, multiplier: float, exponential_base: float, max_interval: float ) -> float: """Calculate the duration of an exponential backoff""" try: interval = multiplier * exponential_base**attempt except OverflowError: return max_interval return min(max_interval, interval)