# Source code for dask_expr._rolling

import functools
from collections import namedtuple
from numbers import Integral

import pandas as pd
from dask.utils import derived_from
from pandas.core.window import Rolling as pd_Rolling

from dask_expr._collection import new_collection
from dask_expr._expr import (
    Blockwise,
    Expr,
    MapOverlap,
    Projection,
    determine_column_projection,
    make_meta,
)

# Minimal record type with a single ``iterable`` field.
BlockwiseDep = namedtuple("BlockwiseDep", ["iterable"])


def _rolling_agg(
    frame,
    window,
    kwargs,
    how,
    how_args,
    how_kwargs,
    groupby_kwargs=None,
    groupby_slice=None,
):
    """Apply a single rolling aggregation to a pandas object.

    Parameters
    ----------
    frame : pandas.DataFrame or pandas.Series
        Data to aggregate.
    window
        Window passed to ``.rolling``.
    kwargs : dict
        Extra keyword arguments for ``.rolling``.
    how : str
        Name of the ``Rolling`` method to call, e.g. ``"sum"``.
    how_args : sequence
        Positional arguments for that method.
    how_kwargs : dict or None
        Keyword arguments for that method (``None`` means none).
    groupby_kwargs : dict, optional
        When given, ``frame.groupby(**groupby_kwargs)`` is rolled instead
        of ``frame`` itself.
    groupby_slice : optional
        Column label(s) selected from the groupby object before rolling.

    Returns
    -------
    The aggregation result; grouped results are sorted on the last index
    level.
    """
    if groupby_kwargs is not None:
        frame = frame.groupby(**groupby_kwargs)
        # Compare against None explicitly: a truthiness test would silently
        # ignore falsy labels such as column ``0`` (rolling every column
        # instead) and raises for an Index selection.
        if groupby_slice is not None:
            frame = frame[groupby_slice]
    rolling = frame.rolling(window, **kwargs)
    result = getattr(rolling, how)(*how_args, **(how_kwargs or {}))
    if groupby_kwargs is not None:
        # Grouped rolling output is ordered by group; sorting on the last
        # index level (presumably the original index) restores input order.
        return result.sort_index(level=-1)
    return result


class RollingReduction(Expr):
    """Base expression for a rolling-window aggregation.

    Subclasses set the class attribute ``how`` to the name of the pandas
    ``Rolling`` method to call (e.g. ``"sum"``). The expression lowers either
    to a partition-wise :class:`RollingAggregation` (when no row needs data
    from a neighbouring partition) or to a :class:`MapOverlap` that borrows
    ``before``/``after`` data from adjacent partitions.
    """

    _parameters = [
        "frame",
        "window",
        "kwargs",
        "how_args",
        "how_kwargs",
        "groupby_kwargs",
        "groupby_slice",
    ]
    _defaults = {
        "kwargs": None,
        "how_args": (),
        "how_kwargs": None,
        "groupby_kwargs": None,
        "groupby_slice": None,
    }
    # Name of the pandas ``Rolling`` method; set by each subclass.
    how = None

    @functools.cached_property
    def npartitions(self):
        return self.frame.npartitions

    def _divisions(self):
        # Output partitions line up one-to-one with the input partitions.
        return self.frame.divisions

    @functools.cached_property
    def _meta(self):
        meta = _rolling_agg(
            self.frame._meta,
            window=self.window,
            kwargs=self.kwargs,
            how=self.how,
            how_args=self.how_args,
            how_kwargs=self.how_kwargs,
            groupby_kwargs=self.groupby_kwargs,
            groupby_slice=self.groupby_slice,
        )
        return make_meta(meta)

    @functools.cached_property
    def kwargs(self):
        # Normalise the optional ``kwargs`` operand to a dict.
        return {} if self.operand("kwargs") is None else self.operand("kwargs")

    def _simplify_up(self, parent, dependents):
        """Push a column projection below the rolling operation."""
        if isinstance(parent, Projection):
            by = self.groupby_kwargs.get("by", []) if self.groupby_kwargs else []
            # Grouping keys given as expressions are not columns of ``frame``.
            by_columns = by if not isinstance(by, Expr) else []
            columns = determine_column_projection(self, parent, dependents, by_columns)
            columns = [col for col in self.frame.columns if col in columns]
            if columns == self.frame.columns:
                return
            # Always re-apply the parent projection. ``columns`` is computed
            # from all ``dependents`` and may be wider than what this
            # ``parent`` selects, and collapsing a one-element list to a
            # scalar would turn a ``df[["a"]]`` selection into a Series.
            return type(parent)(
                type(self)(self.frame[columns], *self.operands[1:]),
                *parent.operands[1:],
            )

    @property
    def _is_blockwise_op(self):
        """True when every output row depends only on its own partition."""
        return (
            self.kwargs.get("axis") in (1, "columns")
            or (isinstance(self.window, Integral) and self.window <= 1)
            or self.frame.npartitions == 1
        )

    def _lower(self):
        """Lower to a partition-wise or an overlapping implementation."""
        if self._is_blockwise_op:
            return RollingAggregation(
                self.frame,
                self.window,
                self.kwargs,
                self.how,
                list(self.how_args),
                self.how_kwargs,
                groupby_kwargs=self.groupby_kwargs,
                groupby_slice=self.groupby_slice,
            )

        # ``before``/``after``: how much data each partition borrows from its
        # predecessor/successor -- a row count for integer windows, a time
        # span otherwise.
        if self.kwargs.get("center"):
            before = self.window // 2
            after = self.window - before - 1
        elif not isinstance(self.window, Integral):
            # ``Integral`` (matching ``_is_blockwise_op``) rather than ``int``
            # so NumPy integer windows such as ``np.int64(3)`` are treated as
            # row counts instead of becoming a 3-nanosecond Timedelta.
            before = pd.Timedelta(self.window)
            after = 0
        else:
            before = self.window - 1
            after = 0

        return MapOverlap(
            frame=self.frame,
            func=_rolling_agg,
            before=before,
            after=after,
            meta=self._meta,
            enforce_metadata=True,
            kwargs=dict(
                window=self.window,
                kwargs=self.kwargs,
                how=self.how,
                how_args=self.how_args,
                how_kwargs=self.how_kwargs,
                groupby_kwargs=self.groupby_kwargs,
                groupby_slice=self.groupby_slice,
            ),
        )


class RollingAggregation(Blockwise):
    """Rolling aggregation applied independently to each partition.

    Used when no output row needs data from a neighbouring partition, so
    ``_rolling_agg`` can run on every partition without overlap.
    """

    _parameters = [
        "frame",
        "window",
        "kwargs",
        "how",
        "how_args",
        "how_kwargs",
        "groupby_kwargs",
        "groupby_slice",
    ]

    operation = staticmethod(_rolling_agg)

    @functools.cached_property
    def _meta(self):
        # Derive the metadata from the aggregation itself. The input meta is
        # not the output meta in general: e.g. ``count`` yields floats and a
        # grouped rolling result carries a MultiIndex.
        meta = _rolling_agg(
            self.frame._meta,
            window=self.window,
            kwargs=self.kwargs,
            how=self.how,
            how_args=self.how_args,
            how_kwargs=self.how_kwargs,
            groupby_kwargs=self.groupby_kwargs,
            groupby_slice=self.groupby_slice,
        )
        return make_meta(meta)


class RollingCount(RollingReduction):
    """Rolling count of non-NA observations."""

    how = "count"


class RollingSum(RollingReduction):
    """Rolling sum."""

    how = "sum"


class RollingMean(RollingReduction):
    """Rolling mean."""

    how = "mean"


class RollingMin(RollingReduction):
    """Rolling minimum."""

    how = "min"


class RollingMax(RollingReduction):
    """Rolling maximum."""

    how = "max"


class RollingVar(RollingReduction):
    """Rolling variance."""

    how = "var"


class RollingStd(RollingReduction):
    """Rolling standard deviation."""

    how = "std"


class RollingMedian(RollingReduction):
    """Rolling median."""

    how = "median"


class RollingQuantile(RollingReduction):
    """Rolling quantile; the quantile level is passed via ``how_args``."""

    how = "quantile"


class RollingSkew(RollingReduction):
    """Rolling skewness."""

    how = "skew"


class RollingKurt(RollingReduction):
    """Rolling kurtosis."""

    how = "kurt"


class RollingAgg(RollingReduction):
    """Rolling ``agg`` with user-supplied function(s) held in ``how_args``."""

    how = "agg"

    def _simplify_up(self, parent, dependents):
        # Never push a column projection below ``agg``: the user function
        # may read columns other than the ones selected afterwards.
        return None


class RollingApply(RollingReduction):
    """Rolling ``apply``; the function and its arguments travel in ``how_args``/``how_kwargs``."""

    how = "apply"


class RollingCov(RollingReduction):
    """Rolling covariance."""

    how = "cov"

    def _simplify_up(self, parent, dependents):
        # Do not push column projections below ``cov``. For a DataFrame the
        # result is pairwise, so each output column depends on every input
        # column. Rolling a single projected column would instead compute
        # that column's covariance with itself, with a different index shape.
        return None


class Rolling:
    """Pandas-like rolling-window API for dask-expr collections.

    Each aggregation method builds the matching ``Rolling*`` expression from
    ``obj``, the window and the stored ``.rolling`` keyword arguments, then
    wraps it with ``new_collection``. ``groupby_kwargs``/``groupby_slice``
    are forwarded unchanged so grouped rolling operations share this class.
    """

    def __init__(
        self,
        obj,
        window,
        groupby_kwargs=None,
        groupby_slice=None,
        min_periods=None,
        center=False,
        win_type=None,
    ):
        # Unknown divisions are only accepted for a single partition
        # (at most two division entries).
        if obj.divisions[0] is None and len(obj.divisions) > 2:
            msg = (
                "Can only rolling dataframes with known divisions\n"
                "See https://docs.dask.org/en/latest/dataframe-design.html#partitions\n"
                "for more information."
            )
            raise ValueError(msg)
        self.obj = obj
        self.window = window
        self.groupby_kwargs = groupby_kwargs
        self.groupby_slice = groupby_slice
        self.min_periods = min_periods
        self.center = center
        self.win_type = win_type

        # Allow pandas to raise if appropriate
        obj._meta.rolling(window, **self.kwargs)

    @functools.cached_property
    def kwargs(self):
        # Keyword arguments forwarded to ``.rolling``.
        return dict(
            min_periods=self.min_periods, center=self.center, win_type=self.win_type
        )

    def _single_agg(self, expr_cls, how_args=(), how_kwargs=None):
        # Build ``expr_cls`` over ``self.obj`` and wrap it as a collection.
        return new_collection(
            expr_cls(
                self.obj,
                self.window,
                kwargs=self.kwargs,
                how_args=how_args,
                how_kwargs=how_kwargs,
                groupby_kwargs=self.groupby_kwargs,
                groupby_slice=self.groupby_slice,
            )
        )

    # NOTE: the public methods below intentionally carry no docstrings:
    # ``derived_from`` supplies pandas' documentation for them.

    @derived_from(pd_Rolling)
    def cov(self):
        return self._single_agg(RollingCov)

    @derived_from(pd_Rolling)
    def apply(self, func, *args, **kwargs):
        # ``func`` and its positional arguments travel together in how_args.
        return self._single_agg(
            RollingApply, how_args=(func, *args), how_kwargs=kwargs
        )

    @derived_from(pd_Rolling)
    def count(self):
        return self._single_agg(RollingCount)

    @derived_from(pd_Rolling)
    def sum(self):
        return self._single_agg(RollingSum)

    @derived_from(pd_Rolling)
    def mean(self):
        return self._single_agg(RollingMean)

    @derived_from(pd_Rolling)
    def min(self):
        return self._single_agg(RollingMin)

    @derived_from(pd_Rolling)
    def max(self):
        return self._single_agg(RollingMax)

    @derived_from(pd_Rolling)
    def var(self):
        return self._single_agg(RollingVar)

    @derived_from(pd_Rolling)
    def std(self):
        return self._single_agg(RollingStd)

    @derived_from(pd_Rolling)
    def median(self):
        return self._single_agg(RollingMedian)

    @derived_from(pd_Rolling)
    def quantile(self, q):
        return self._single_agg(RollingQuantile, how_args=(q,))

    @derived_from(pd_Rolling)
    def skew(self):
        return self._single_agg(RollingSkew)

    @derived_from(pd_Rolling)
    def kurt(self):
        return self._single_agg(RollingKurt)

    @derived_from(pd_Rolling)
    def agg(self, func, *args, **kwargs):
        # ``func`` and its positional arguments travel together in how_args.
        return self._single_agg(
            RollingAgg, how_args=(func, *args), how_kwargs=kwargs
        )