from __future__ import annotations
from functools import partial
from itertools import product
import numpy as np
from tlz import curry
from dask.array.backends import array_creation_dispatch
from dask.array.core import Array, normalize_chunks
from dask.array.utils import meta_from_array
from dask.base import tokenize
from dask.blockwise import blockwise as core_blockwise
from dask.layers import ArrayChunkShapeDep
from dask.utils import funcname
def _parse_wrap_args(func, args, kwargs, shape):
    """
    Normalize the arguments shared by the blocked creation wrappers.

    Parameters
    ----------
    func : callable
        Creation function being wrapped. It is called once as
        ``func(shape, *args, **kwargs)`` to infer the dtype when no ``dtype``
        keyword is supplied.
    args : tuple
        Extra positional arguments forwarded to ``func`` for dtype inference
        and included in the generated name token.
    kwargs : dict
        Keyword arguments. ``name``, ``chunks`` and ``dtype`` are popped from
        this dict in place; whatever remains is returned under ``"kwargs"``.
    shape : int, tuple, list or numpy.ndarray
        Requested shape. An ndarray is converted with ``tolist()``; anything
        that is not a tuple or list afterwards is wrapped in a 1-tuple.

    Returns
    -------
    dict
        Mapping with the keys ``"shape"``, ``"dtype"``, ``"kwargs"``,
        ``"chunks"`` and ``"name"``.
    """
    if isinstance(shape, np.ndarray):
        shape = shape.tolist()
    shape = shape if isinstance(shape, (tuple, list)) else (shape,)

    # Strip the wrapper-only keywords before ``func`` ever sees ``kwargs``.
    name = kwargs.pop("name", None)
    requested_chunks = kwargs.pop("chunks", "auto")
    requested_dtype = kwargs.pop("dtype", None)

    if requested_dtype is None:
        # No explicit dtype: ask the wrapped function what it would produce.
        requested_dtype = func(shape, *args, **kwargs).dtype
    dtype = np.dtype(requested_dtype)

    chunks = normalize_chunks(requested_chunks, shape, dtype=dtype)

    if not name:
        token = tokenize(func, shape, chunks, dtype, args, kwargs)
        name = funcname(func) + "-" + token

    return dict(
        shape=shape,
        dtype=dtype,
        kwargs=kwargs,
        chunks=chunks,
        name=name,
    )
def wrap_func_shape_as_first_arg(func, *args, **kwargs):
    """
    Transform np creation function into blocked version

    The shape is taken from the ``shape`` keyword when present, otherwise from
    the first positional argument. ``name``, ``chunks`` and ``dtype`` keywords
    are consumed by ``_parse_wrap_args``; the remaining keywords are bound to
    ``func`` together with the resolved dtype.

    Raises
    ------
    TypeError
        If the shape is a dask ``Array``.
    """
    if "shape" in kwargs:
        shape = kwargs.pop("shape")
    else:
        shape, args = args[0], args[1:]

    if isinstance(shape, Array):
        raise TypeError(
            "Dask array input not supported. "
            "Please use tuple, list, or a 1D numpy array instead."
        )

    spec = _parse_wrap_args(func, args, kwargs, shape)
    dtype, chunks, name = spec["dtype"], spec["chunks"], spec["name"]
    extra_kwargs = spec["kwargs"]

    block_func = partial(func, dtype=dtype, **extra_kwargs)

    # The output and the chunk-shape dependency use the same index tuple,
    # one index per dimension.
    indices = tuple(range(len(spec["shape"])))
    graph = core_blockwise(
        block_func,
        name,
        indices,
        ArrayChunkShapeDep(chunks),
        indices,
        numblocks={},
    )

    return Array(
        graph, name, chunks, dtype=dtype, meta=extra_kwargs.get("meta", None)
    )
def wrap_func_like(func, *args, **kwargs):
    """
    Transform np creation function into blocked version

    The first positional argument is the template array: its meta is used for
    the result and its ``shape`` is the default when no ``shape`` keyword is
    given. ``name``, ``chunks`` and ``dtype`` keywords are consumed by
    ``_parse_wrap_args``. One task is generated per block, each calling
    ``func`` with that block's own ``shape`` keyword followed by ``args``.

    Raises
    ------
    IndexError
        If no positional argument is supplied (there is no template array).
    """
    x = args[0]
    meta = meta_from_array(x)
    shape = kwargs.get("shape", x.shape)

    parsed = _parse_wrap_args(func, args, kwargs, shape)
    dtype = parsed["dtype"]
    chunks = parsed["chunks"]
    name = parsed["name"]
    kwargs = parsed["kwargs"]

    keys = product([name], *[range(len(bd)) for bd in chunks])
    block_shapes = product(*chunks)

    # Every task must get its OWN keyword dict. Sharing a single dict between
    # blocks means the last assigned ``shape`` wins for all of them, which
    # produces wrongly sized blocks whenever the chunks are not uniform.
    vals = (
        (partial(func, dtype=dtype, **dict(kwargs, shape=block_shape)),) + args
        for block_shape in block_shapes
    )

    dsk = dict(zip(keys, vals))
    return Array(dsk, name, chunks, meta=meta.astype(dtype))
@curry
def wrap(wrap_func, func, func_like=None, **kwargs):
    """
    Bind a creation function to a blocked wrapper implementation.

    Parameters
    ----------
    wrap_func : callable
        Wrapper implementation that receives the creation function as its
        first argument.
    func : callable
        Function whose ``__name__`` and ``__doc__`` are used for the result's
        metadata. It is also the function bound to ``wrap_func`` unless
        ``func_like`` is given.
    func_like : callable, optional
        When not ``None``, this function is bound to ``wrap_func`` instead of
        ``func``.
    **kwargs
        Extra keywords bound into the returned partial.

    Returns
    -------
    functools.partial
        ``wrap_func`` with the chosen function and ``kwargs`` bound. Its
        ``__doc__`` and ``__name__`` are only set when ``func`` has a
        docstring.
    """
    target = func if func_like is None else func_like
    f = partial(wrap_func, target, **kwargs)

    template = (
        "\n"
        "Blocked variant of %(name)s\n"
        "Follows the signature of %(name)s exactly except that it also features\n"
        "optional keyword arguments ``chunks: int, tuple, or dict`` and ``name: str``.\n"
        "Original signature follows below.\n"
    )
    if func.__doc__ is not None:
        header = template % {"name": func.__name__}
        f.__doc__ = header + func.__doc__
        f.__name__ = "blocked_" + func.__name__
    return f
# ``wrap`` is curried, so supplying only the wrapper implementation yields a
# factory that still expects the creation function (plus optional keywords).
w = wrap(wrap_func_shape_as_first_arg)
@curry
def _broadcast_trick_inner(func, shape, meta=(), *args, **kwargs):
    """
    Build a minimal array with ``func`` and broadcast it to ``shape``.

    ``func`` is called as ``func(meta, *args, shape=..., **kwargs)`` with a
    placeholder shape, and the result is passed to ``np.broadcast_to``.
    """
    # cupy-specific hack. numpy is happy with hardcoded shape=().
    if shape == ():
        null_shape = ()
    else:
        null_shape = 1
    seed = func(meta, *args, shape=null_shape, **kwargs)
    return np.broadcast_to(seed, shape)
def broadcast_trick(func):
    """
    Provide a decorator to wrap common numpy function with a broadcast trick.

    Dask arrays are currently immutable; thus when we know an array is
    uniform, we can replace the actual data by a single value and have all
    elements point to it, thus reducing the size.

    >>> x = np.broadcast_to(1, (100,100,100))
    >>> x.base.nbytes
    8

    Those arrays are not only more efficient locally, but dask serialisation
    is aware of the _real_ size of those arrays and thus can send them around
    efficiently and schedule accordingly.

    Note that those arrays are read-only and numpy will refuse to assign to
    them, so this should be safe.

    The returned callable carries the ``__doc__`` and ``__name__`` of ``func``.
    """
    curried = _broadcast_trick_inner(func)
    # Copy the identifying metadata of the wrapped function onto the result.
    for attr in ("__doc__", "__name__"):
        setattr(curried, attr, getattr(func, attr))
    return curried
# Blocked ``ones``, ``zeros`` and ``empty``. Each one wraps the matching NumPy
# ``*_like`` function through ``broadcast_trick`` and ``w`` with ``dtype="f8"``
# bound as a keyword, and is passed to ``array_creation_dispatch`` for the
# "numpy" backend under the given name.
ones = array_creation_dispatch.register_inplace(
    backend="numpy",
    name="ones",
)(w(broadcast_trick(np.ones_like), dtype="f8"))

zeros = array_creation_dispatch.register_inplace(
    backend="numpy",
    name="zeros",
)(w(broadcast_trick(np.zeros_like), dtype="f8"))

empty = array_creation_dispatch.register_inplace(
    backend="numpy",
    name="empty",
)(w(broadcast_trick(np.empty_like), dtype="f8"))

# Factory for ``*_like`` wrappers, where the template array is the first
# positional argument of the wrapped call.
w_like = wrap(wrap_func_like)

# Named and documented after ``np.empty``; the function actually bound to the
# wrapper is ``np.empty_like``.
empty_like = w_like(np.empty, func_like=np.empty_like)
# full and full_like require special casing due to argument check on fill_value
# Generate wrapped functions only once
# ``_full`` goes through the broadcast trick around ``np.full_like`` and is
# registered for the "numpy" backend under the name "full".
_full = array_creation_dispatch.register_inplace(
    backend="numpy",
    name="full",
)(w(broadcast_trick(np.full_like)))
# ``_full_like`` is named and documented after ``np.full``; the function bound
# to the wrapper is ``np.full_like``.
_full_like = w_like(np.full, func_like=np.full_like)
# workaround for numpy doctest failure: https://github.com/numpy/numpy/pull/17472
# The inherited NumPy docstring is patched so its examples pass as doctests.
if _full.__doc__ is not None:
    # Older NumPy docstrings print this example output with two spaces after
    # each comma; collapse them to single spaces. The search string has to
    # differ from the replacement, otherwise the call is a no-op.
    _full.__doc__ = _full.__doc__.replace(
        "array([0.1,  0.1,  0.1,  0.1,  0.1,  0.1])",
        "array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])",
    )
    # Let doctest ignore whitespace differences in the multi-line output of
    # this example.
    _full.__doc__ = _full.__doc__.replace(
        ">>> np.full_like(y, [0, 0, 255])",
        ">>> np.full_like(y, [0, 0, 255]) # doctest: +NORMALIZE_WHITESPACE",
    )
def full(shape, fill_value, *args, **kwargs):
    # Validating front-end for ``_full``: rejects non-scalar fill values and
    # fills in a default ``dtype`` before delegating. No docstring is written
    # here because ``full.__doc__`` is assigned from ``_full`` at module level.
    #
    # Raises ValueError when ``fill_value`` has a non-zero number of
    # dimensions.

    # np.isscalar has somewhat strange behavior:
    # https://docs.scipy.org/doc/numpy/reference/generated/numpy.isscalar.html
    if np.ndim(fill_value) != 0:
        raise ValueError(
            f"fill_value must be scalar. Received {type(fill_value).__name__} instead."
        )
    if kwargs.get("dtype", None) is None:
        # Derive the dtype from the fill value itself: prefer its own
        # ``dtype`` attribute, otherwise fall back to its Python type.
        if hasattr(fill_value, "dtype"):
            kwargs["dtype"] = fill_value.dtype
        else:
            kwargs["dtype"] = type(fill_value)
    return _full(*args, shape=shape, fill_value=fill_value, **kwargs)
def full_like(a, fill_value, *args, **kwargs):
    # Validating front-end for ``_full_like``: only scalar fill values are
    # accepted. No docstring is written here because ``full_like.__doc__`` is
    # assigned from ``_full_like`` at module level.
    fill_ndim = np.ndim(fill_value)
    if fill_ndim != 0:
        received = type(fill_value).__name__
        raise ValueError(
            f"fill_value must be scalar. Received {received} instead."
        )
    # ``a`` and ``fill_value`` are forwarded as keywords; only the extra
    # positional arguments are passed positionally.
    return _full_like(*args, a=a, fill_value=fill_value, **kwargs)
# The public validating wrappers take the docstrings of the wrapped
# implementations they delegate to.
full.__doc__ = _full.__doc__
full_like.__doc__ = _full_like.__doc__