# Source code for dask.rewrite

from __future__ import annotations

from collections import deque

from dask.core import istask, subs

def head(task):
    """Return the top level node of a task.

    * For a task (as recognised by ``istask``) this is its first element.
    * For a list it is the ``list`` type itself, so a list is traversed as
      ``list`` applied to its elements.
    * Any other object is a leaf and is its own head.
    """
    if istask(task):
        return task[0]
    if isinstance(task, list):
        return list
    return task

def args(task):
    """Get the arguments for the current task.

    Returns everything after the head for a task, the elements themselves
    for a list, and an empty tuple for leaves (they have no subterms).
    """
    if istask(task):
        return task[1:]
    if isinstance(task, list):
        return task
    return ()

class Traverser:
    """Traverser interface for tasks.

    Class for storing the state while performing a preorder-traversal of a
    task.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Pending sibling subterms still to be visited, with ``END`` at the
        bottom. Mainly used by ``copy``; when omitted (or empty) a fresh
        ``deque([END])`` is used.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply `head`
        applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        if not stack:
            # ``END`` sits at the bottom of the stack, so popping it marks the
            # end of the traversal.
            self._stack = deque([END])
        else:
            self._stack = stack

    def __iter__(self):
        """Yield the head of every term in preorder, advancing the traversal."""
        while self.current is not END:
            yield self.current
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        This allows the traversal to be pushed onto a stack, for easy
        backtracking. The pending-term stack is copied, so advancing either
        traverser does not affect the other.
        """
        return Traverser(self.term, deque(self._stack))

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if not subterms:
            # No subterms, pop off stack
            self.term = self._stack.pop()
        else:
            # Descend into the first subterm; push the remaining siblings in
            # reverse so they are popped left-to-right.
            self.term = subterms[0]
            self._stack.extend(reversed(subterms[1:]))

    @property
    def current(self):
        """The head of the current term."""
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()

class Token:
    """A token object.

    Used to express certain objects in the traversal of a task or pattern.
    Tokens compare by identity, so two tokens with the same name are still
    distinct objects.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name
# Sentinel standing in for *any* variable in a discrimination net: every rule
# variable is stored under this single edge rather than under its own name.
VAR = Token("?")
# Marks the end of the traversal of an expression. It is the bottom entry of a
# Traverser's stack and is compared by identity (``is``). We can't use
# ``None``, ``False``, etc. here, as anything may be an argument to a function.
END = Token("end")

class Node(tuple):
    """A Discrimination Net node.

    Stored as an immutable 2-tuple ``(edges, patterns)``. The contained dict
    and list are mutable, so the net is grown in place.
    """

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        # Fresh containers per node; never share a default dict/list.
        edges = edges if edges else {}
        patterns = patterns if patterns else []
        return tuple.__new__(cls, (edges, patterns))

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]

class RewriteRule:
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the
        variables in `lhs`. If it's a function, the function will be called
        with a dict of such matches.
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there
        are no variables, this can be omitted.

    Raises
    ------
    TypeError
        If `vars` is not a tuple.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> import dask.rewrite as dr
    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = dr.RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = dr.RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        if callable(rhs):
            self.subs = rhs
        else:
            self.subs = self._apply
        self.rhs = rhs
        # Every occurrence of a variable, in preorder-traversal order; used to
        # pair variables with the subterms collected during matching.
        self._varlist = [t for t in Traverser(lhs) if t in vars]
        # Reduce vars down to just variables found in lhs
        try:
            self.vars = tuple(sorted(set(self._varlist)))
        except TypeError:
            # Variables may be any hashable object; mixed types (e.g. 'x' and
            # 1) cannot be ordered, so fall back to first-appearance order.
            self.vars = tuple(dict.fromkeys(self._varlist))

    def _apply(self, sub_dict):
        """Substitute every matched variable into the task `rhs`."""
        term = self.rhs
        for key, val in sub_dict.items():
            term = subs(term, key, val)
        return term

    def __str__(self):
        return f"RewriteRule({self.lhs}, {self.rhs}, {self.vars})"

    def __repr__(self):
        return str(self)

class RuleSet:
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Examples
    --------

    >>> import dask.rewrite as dr
    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add

    >>> rs = dr.RuleSet(
    ...         dr.RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         dr.RewriteRule((f, (g, 'x'), 'y'),
    ...                        (h, 'x', 'y'),
    ...                        ('x', 'y')))

    >>> rs.rewrite((add, 2, 0))
    2

    >>> rs.rewrite((f, (g, 'a', 3)))    # doctest: +ELLIPSIS
    (<function h at ...>, 'a', 3)

    >>> dsk = {'a': (add, 2, 0),
    ...        'b': (f, (g, 'a', 3))}

    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)         # doctest: +ELLIPSIS
    {'a': 2, 'b': (<function h at ...>, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for p in rules:
            self.add(p)

    def add(self, rule):
        """Add a rule to the RuleSet.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If `rule` is not a `RewriteRule`.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        rule_vars = rule.vars
        curr_node = self._net
        ind = len(self.rules)
        # Walk the preorder traversal (POT) of the lhs, creating net nodes as
        # needed. Every variable is collapsed onto the shared ``VAR`` edge.
        for t in Traverser(rule.lhs):
            if t in rule_vars:
                t = VAR
            if t not in curr_node.edges:
                curr_node.edges[t] = Node()
            curr_node = curr_node.edges[t]
        # We've reached a leaf node. Add the rule index to this leaf.
        curr_node.patterns.append(ind)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        S = Traverser(term)
        for m, syms in _match(S, self._net):
            for i in m:
                rule = self.rules[i]
                sub_dict = _process_match(rule, syms)
                if sub_dict is not None:
                    yield rule, sub_dict

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term"""
        for rule, sd in self.iter_matches(term):
            # Only the first match is applied; stop the lazy search here.
            return rule.subs(sd)
        return term

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        This applies the most specific matching rule in the RuleSet to the
        task, using the provided strategy.

        Parameters
        ----------
        task : a task
            The task to be rewritten
        strategy : str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level".

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        return strategies[strategy](self, task)

def _top_level(net, term):
    """Rewrite only the top level of `term`."""
    return net._rewrite(term)


def _bottom_up(net, term):
    """Rewrite `term` from the leaves upwards.

    Subterms of tasks and lists are rewritten first, then the rebuilt term
    itself is rewritten at the top level.
    """
    if istask(term):
        term = (head(term),) + tuple(_bottom_up(net, t) for t in args(term))
    elif isinstance(term, list):
        term = [_bottom_up(net, t) for t in args(term)]
    return net._rewrite(term)


strategies = {"top_level": _top_level, "bottom_up": _bottom_up}


def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Yields ``(patterns, matches)`` each time the traversal of `S` reaches
    ``END``. Here ``patterns`` is the list of rule indices stored at the
    reached node, and ``matches`` is the tuple of subterms bound to ``VAR``
    edges along the way. Exact edges are tried before ``VAR`` edges;
    alternatives are explored by backtracking.
    """
    stack = deque()
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
            # Nothing (not even a variable) can match past END: the traverser's
            # stack is exhausted, so skipping would pop from an empty deque.
            # Go straight to backtracking.
        else:
            try:
                # Catch hashing errors from un-hashable types. This allows for
                # variables to be matched with un-hashable objects.
                n = N.edges.get(S.current, None)
            except TypeError:
                n = None
            if n and not restore_state_flag:
                # Remember this state so the VAR alternative can be tried later.
                stack.append((S.copy(), N, matches))
                N = n
                S.next()
                continue
            n = N.edges.get(VAR, None)
            if n:
                restore_state_flag = False
                # A variable swallows the whole current subterm.
                matches = matches + (S.term,)
                S.skip()
                N = n
                continue
        # Backtrack here
        try:
            S, N, matches = stack.pop()
        except IndexError:
            return
        # The popped state already tried its exact edge; try VAR next.
        restore_state_flag = True


def _process_match(rule, syms):
    """Process a match to determine if it is correct, and to find the correct
    substitution that will convert the term into the pattern.

    Parameters
    ----------
    rule : RewriteRule
    syms : iterable
        Iterable of subterms that match a corresponding variable.

    Returns
    -------
    A dictionary of {vars : subterms} describing the substitution to make the
    pattern equivalent with the term. Returns `None` if the match is
    invalid (the same variable bound to two different subterms).

    Raises
    ------
    RuntimeError
        If the number of matched subterms differs from the number of
        variable occurrences in the rule.
    """
    sub_dict = {}
    varlist = rule._varlist
    if len(varlist) != len(syms):
        raise RuntimeError("length of varlist doesn't match length of syms.")
    for v, s in zip(varlist, syms):
        if v in sub_dict and sub_dict[v] != s:
            return None
        sub_dict[v] = s
    return sub_dict