from __future__ import annotations
import contextlib
import functools
import itertools
import math
import numbers
import warnings
import numpy as np
from tlz import concat, frequencies
from dask.array.core import Array
from dask.array.numpy_compat import AxisError
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph
from dask.utils import has_keyword, is_arraylike, is_cupy_type, typename
def normalize_to_array(x):
    """Move device (CuPy) arrays to the host; return anything else unchanged."""
if is_cupy_type(x):
return x.get()
else:
return x
def compute_meta(func, _dtype, *args, **kwargs):
    """Infer the output metadata of `func` by applying it to empty
    (meta) versions of `args`, returning None when the metadata cannot
    be determined."""
with np.errstate(all="ignore"), warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
args_meta = [meta_from_array(x) if is_arraylike(x) else x for x in args]
kwargs_meta = {
k: meta_from_array(v) if is_arraylike(v) else v for k, v in kwargs.items()
}
        # TODO: look for an alternative to this; it causes issues when using
        # map_blocks() with np.vectorize, such as dask.array.routines._isnonzero_vec().
if isinstance(func, np.vectorize):
meta = func(*args_meta)
else:
try:
# some reduction functions need to know they are computing meta
if has_keyword(func, "computing_meta"):
kwargs_meta["computing_meta"] = True
meta = func(*args_meta, **kwargs_meta)
        except TypeError as e:
            # a kwargs-related TypeError means `func` was passed keyword
            # arguments it does not accept; surface that to the user
if any(
s in str(e)
for s in [
"unexpected keyword argument",
"is an invalid keyword for",
"Did not understand the following kwargs",
]
):
raise
else:
return None
except ValueError as e:
            # min/max reductions have no identity; when there is only one
            # input, fall back to using its type
            if (
                len(args_meta) == 1
                and "zero-size array to reduction operation" in str(e)
            ):
meta = args_meta[0]
else:
return None
except Exception:
return None
if _dtype and getattr(meta, "dtype", None) != _dtype:
with contextlib.suppress(AttributeError):
meta = meta.astype(_dtype)
if np.isscalar(meta):
meta = np.array(meta)
return meta
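# A minimal sketch of how `compute_meta` behaves (comments only, not
# executed at import time; values are illustrative):
#
#     >>> meta = compute_meta(np.sum, np.dtype("float64"), np.empty((0, 0)))
#     >>> meta.dtype
#     dtype('float64')
#
# If `func` raises for a reason other than bad keyword arguments, None is
# returned and callers must infer the metadata by other means.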
def allclose(a, b, equal_nan=False, **kwargs):
    """Like `np.allclose`, but with additional support for masked
    arrays, object-dtype arrays, and device (e.g. CuPy) arrays, which
    are moved to the host first."""
a = normalize_to_array(a)
b = normalize_to_array(b)
if getattr(a, "dtype", None) != "O":
if hasattr(a, "mask") or hasattr(b, "mask"):
return np.ma.allclose(a, b, masked_equal=True, **kwargs)
else:
return np.allclose(a, b, equal_nan=equal_nan, **kwargs)
    if equal_nan:
        return a.shape == b.shape and all(
            np.isnan(bb) if np.isnan(aa) else aa == bb
            for (aa, bb) in zip(a.flat, b.flat)
        )
return (a == b).all()
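# Illustrative behavior (comments only; values are made up):
#
#     >>> allclose(np.array([1.0, np.nan]), np.array([1.0, np.nan]),
#     ...          equal_nan=True)
#     True
#     >>> allclose(np.array(["a", 1], dtype=object),
#     ...          np.array(["a", 1], dtype=object))
#     True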
def same_keys(a, b):
    """Check that two dask collections have identical graph keys."""
def key(k):
if isinstance(k, str):
return (k, -1, -1, -1)
else:
return k
return sorted(a.dask, key=key) == sorted(b.dask, key=key)
def _not_empty(x):
    """Whether `x` is non-scalar (ndim > 0) with no zero-length axes."""
    return x.shape and 0 not in x.shape
def _check_dsk(dsk):
"""Check that graph is well named and non-overlapping"""
if not isinstance(dsk, HighLevelGraph):
return
dsk.validate()
assert all(isinstance(k, (tuple, str)) for k in dsk.layers)
freqs = frequencies(concat(dsk.layers.values()))
non_one = {k: v for k, v in freqs.items() if v != 1}
key_collisions = set()
# Allow redundant keys if the values are equivalent
for k in non_one.keys():
for layer in dsk.layers.values():
try:
key_collisions.add(tokenize(layer[k]))
except KeyError:
pass
assert len(key_collisions) < 2, non_one
def assert_eq_shape(a, b, check_ndim=True, check_nan=True):
    """Assert that two shapes match, where NaN entries denote unknown
    dimensions: they must line up when `check_nan` is True and are
    ignored otherwise."""
if check_ndim:
assert len(a) == len(b)
for aa, bb in zip(a, b):
if math.isnan(aa) or math.isnan(bb):
if check_nan:
assert math.isnan(aa) == math.isnan(bb)
else:
assert aa == bb
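# For example (comments only): NaN entries model unknown chunk sizes, so
#
#     >>> assert_eq_shape((2, np.nan), (2, np.nan))               # passes
#     >>> assert_eq_shape((2, np.nan), (2, 5), check_nan=False)   # passes
#
# while assert_eq_shape((2, np.nan), (2, 5)) raises AssertionError.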
def _check_chunks(x, check_ndim=True, scheduler=None):
    """Persist `x` and check that every materialized chunk matches the
    shape and dtype declared by `x.chunks` and `x.dtype`."""
x = x.persist(scheduler=scheduler)
for idx in itertools.product(*(range(len(c)) for c in x.chunks)):
chunk = x.dask[(x.name,) + idx]
if hasattr(chunk, "result"): # it's a future
chunk = chunk.result()
if not hasattr(chunk, "dtype"):
chunk = np.array(chunk, dtype="O")
expected_shape = tuple(c[i] for c, i in zip(x.chunks, idx))
assert_eq_shape(
expected_shape, chunk.shape, check_ndim=check_ndim, check_nan=False
)
assert (
chunk.dtype == x.dtype
), "maybe you forgot to pass the scheduler to `assert_eq`?"
return x
def _get_dt_meta_computed(
x,
check_shape=True,
check_graph=True,
check_chunks=True,
check_ndim=True,
scheduler=None,
):
    """Compute a dask collection, optionally validating its graph and
    chunks, and return the (possibly densified) result along with its
    dtype, its `_meta`, and the raw computed object."""
x_original = x
x_meta = None
x_computed = None
if is_dask_collection(x) and is_arraylike(x):
assert x.dtype is not None
adt = x.dtype
if check_graph:
_check_dsk(x.dask)
x_meta = getattr(x, "_meta", None)
if check_chunks:
# Replace x with persisted version to avoid computing it twice.
x = _check_chunks(x, check_ndim=check_ndim, scheduler=scheduler)
x = x.compute(scheduler=scheduler)
x_computed = x
if hasattr(x, "todense"):
x = x.todense()
if not hasattr(x, "dtype"):
x = np.array(x, dtype="O")
if _not_empty(x):
assert x.dtype == x_original.dtype
if check_shape:
assert_eq_shape(x_original.shape, x.shape, check_nan=False)
else:
if not hasattr(x, "dtype"):
x = np.array(x, dtype="O")
adt = getattr(x, "dtype", None)
return x, adt, x_meta, x_computed
def assert_eq(
a,
b,
check_shape=True,
check_graph=True,
check_meta=True,
check_chunks=True,
check_ndim=True,
check_type=True,
check_dtype=True,
equal_nan=True,
scheduler="sync",
**kwargs,
):
    """Assert that `a` and `b` are equal, computing any dask
    collections and validating their graphs, chunks, metadata, dtypes,
    and types along the way."""
a_original = a
b_original = b
if isinstance(a, (list, int, float)):
a = np.array(a)
if isinstance(b, (list, int, float)):
b = np.array(b)
a, adt, a_meta, a_computed = _get_dt_meta_computed(
a,
check_shape=check_shape,
check_graph=check_graph,
check_chunks=check_chunks,
check_ndim=check_ndim,
scheduler=scheduler,
)
b, bdt, b_meta, b_computed = _get_dt_meta_computed(
b,
check_shape=check_shape,
check_graph=check_graph,
check_chunks=check_chunks,
check_ndim=check_ndim,
scheduler=scheduler,
)
if check_dtype and str(adt) != str(bdt):
raise AssertionError(f"a and b have different dtypes: (a: {adt}, b: {bdt})")
try:
assert (
a.shape == b.shape
), f"a and b have different shapes (a: {a.shape}, b: {b.shape})"
if check_type:
_a = a if a.shape else a.item()
_b = b if b.shape else b.item()
assert type(_a) == type(
_b
), f"a and b have different types (a: {type(_a)}, b: {type(_b)})"
if check_meta:
if hasattr(a, "_meta") and hasattr(b, "_meta"):
assert_eq(a._meta, b._meta)
if hasattr(a_original, "_meta"):
msg = (
f"compute()-ing 'a' changes its number of dimensions "
f"(before: {a_original._meta.ndim}, after: {a.ndim})"
)
assert a_original._meta.ndim == a.ndim, msg
if a_meta is not None:
msg = (
f"compute()-ing 'a' changes its type "
f"(before: {type(a_original._meta)}, after: {type(a_meta)})"
)
assert type(a_original._meta) == type(a_meta), msg
if not (np.isscalar(a_meta) or np.isscalar(a_computed)):
msg = (
f"compute()-ing 'a' results in a different type than implied by its metadata "
f"(meta: {type(a_meta)}, computed: {type(a_computed)})"
)
assert type(a_meta) == type(a_computed), msg
if hasattr(b_original, "_meta"):
msg = (
f"compute()-ing 'b' changes its number of dimensions "
f"(before: {b_original._meta.ndim}, after: {b.ndim})"
)
assert b_original._meta.ndim == b.ndim, msg
if b_meta is not None:
msg = (
f"compute()-ing 'b' changes its type "
f"(before: {type(b_original._meta)}, after: {type(b_meta)})"
)
assert type(b_original._meta) == type(b_meta), msg
if not (np.isscalar(b_meta) or np.isscalar(b_computed)):
msg = (
f"compute()-ing 'b' results in a different type than implied by its metadata "
f"(meta: {type(b_meta)}, computed: {type(b_computed)})"
)
assert type(b_meta) == type(b_computed), msg
msg = "found values in 'a' and 'b' which differ by more than the allowed amount"
assert allclose(a, b, equal_nan=equal_nan, **kwargs), msg
return True
    except TypeError:
        # objects without shape/dtype semantics; fall back to direct
        # equality below
        pass
c = a == b
if isinstance(c, np.ndarray):
assert c.all()
else:
assert c
return True
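# Typical test usage, as a sketch (comments only; assumes dask.array is
# importable):
#
#     >>> import dask.array as da
#     >>> assert_eq(da.ones((4, 4), chunks=(2, 2)), np.ones((4, 4)))
#     True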
def safe_wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS):
"""Like functools.wraps, but safe to use even if wrapped is not a function.
Only needed on Python 2.
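
    Examples
    --------
    A small illustration (`np.sum` is just an example callable):

    >>> @safe_wraps(np.sum)
    ... def my_sum(*args, **kwargs):
    ...     return np.sum(*args, **kwargs)
    >>> my_sum.__name__
    'sum'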
"""
if all(hasattr(wrapped, attr) for attr in assigned):
return functools.wraps(wrapped, assigned=assigned)
else:
return lambda x: x
def _dtype_of(a):
"""Determine dtype of an array-like."""
try:
# Check for the attribute before using asanyarray, because some types
# (notably sparse arrays) don't work with it.
return a.dtype
except AttributeError:
return np.asanyarray(a).dtype
def arange_safe(*args, like, **kwargs):
"""
Use the `like=` from `np.arange` to create a new array dispatching
to the downstream library. If that fails, falls back to the
default NumPy behavior, resulting in a `numpy.ndarray`.
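
    Examples
    --------
    Without a `like` array, plain NumPy behavior is kept (illustrative):

    >>> arange_safe(3, like=None)
    array([0, 1, 2])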
"""
if like is None:
return np.arange(*args, **kwargs)
else:
try:
return np.arange(*args, like=meta_from_array(like), **kwargs)
except TypeError:
return np.arange(*args, **kwargs)
def _array_like_safe(np_func, da_func, a, like, **kwargs):
    """Shared implementation of `array_safe`, `asarray_safe` and
    `asanyarray_safe`: dispatch to `da_func` when `like` is a dask
    `Array`, otherwise to `np_func`, passing `like=` when supported."""
    if like is a and hasattr(a, "__array_function__"):
return a
if isinstance(like, Array):
return da_func(a, **kwargs)
elif isinstance(a, Array):
        if is_cupy_type(a._meta):
            # CuPy doesn't implement __array__, so materialize the dask
            # array on the host before handing it to the NumPy function
a = a.compute(scheduler="sync")
try:
return np_func(a, like=meta_from_array(like), **kwargs)
except TypeError:
return np_func(a, **kwargs)
def array_safe(a, like, **kwargs):
"""
    If `like` is a dask `Array`, return `dask.array.array(a, **kwargs)`,
    otherwise return `np.array(a, like=meta_from_array(like), **kwargs)`,
    dispatching the call to the library that implements the like array.
    Note that when `a` is a `dask.Array` backed by `cupy.ndarray` but
    `like` isn't, this function will call `a.compute(scheduler="sync")`
    before `np.array`, as downstream libraries are unlikely to know how
    to convert a `dask.Array` and CuPy doesn't implement `__array__` to
    prevent implicit copies to host.
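
    Examples
    --------
    An illustrative call with a NumPy `like` array (the `like=` dispatch
    requires NumPy >= 1.20; older versions hit the TypeError fallback):

    >>> array_safe([1, 2, 3], like=np.empty(()))
    array([1, 2, 3])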
"""
from dask.array.routines import array
return _array_like_safe(np.array, array, a, like, **kwargs)
def asarray_safe(a, like, **kwargs):
"""
    If `like` is a dask `Array`, return `dask.array.asarray(a, **kwargs)`,
    otherwise return `np.asarray(a, like=meta_from_array(like), **kwargs)`,
    dispatching the call to the library that implements the like array.
    Note that when `a` is a `dask.Array` backed by `cupy.ndarray` but
    `like` isn't, this function will call `a.compute(scheduler="sync")`
    before `np.asarray`, as downstream libraries are unlikely to know
    how to convert a `dask.Array`.
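
    Examples
    --------
    A dask `like` array produces a dask array (illustrative sketch):

    >>> import dask.array as da
    >>> type(asarray_safe(np.arange(3), like=da.arange(3)))
    <class 'dask.array.core.Array'>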
"""
from dask.array.core import asarray
return _array_like_safe(np.asarray, asarray, a, like, **kwargs)
def asanyarray_safe(a, like, **kwargs):
"""
    If `like` is a dask `Array`, return `dask.array.asanyarray(a, **kwargs)`,
    otherwise return `np.asanyarray(a, like=meta_from_array(like), **kwargs)`,
    dispatching the call to the library that implements the like array.
    Note that when `a` is a `dask.Array` backed by `cupy.ndarray` but
    `like` isn't, this function will call `a.compute(scheduler="sync")`
    before `np.asanyarray`, as downstream libraries are unlikely to know
    how to convert a `dask.Array`.
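
    Examples
    --------
    Subclasses such as masked arrays pass through unchanged, mirroring
    `np.asanyarray` (illustrative):

    >>> m = np.ma.masked_array([1, 2], mask=[False, True])
    >>> type(asanyarray_safe(m, like=m)).__name__
    'MaskedArray'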
"""
from dask.array.core import asanyarray
return _array_like_safe(np.asanyarray, asanyarray, a, like, **kwargs)
def validate_axis(axis, ndim):
"""Validate an input to axis= keywords"""
if isinstance(axis, (tuple, list)):
return tuple(validate_axis(ax, ndim) for ax in axis)
if not isinstance(axis, numbers.Integral):
raise TypeError("Axis value must be an integer, got %s" % axis)
if axis < -ndim or axis >= ndim:
raise AxisError(
"Axis %d is out of bounds for array of dimension %d" % (axis, ndim)
)
if axis < 0:
axis += ndim
return axis
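# For example (comments only):
#
#     >>> validate_axis(-1, 3)
#     2
#     >>> validate_axis((0, -2), 3)
#     (0, 1)
#
# while validate_axis(3, 3) raises AxisError.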
def svd_flip(u, v, u_based_decision=False):
"""Sign correction to ensure deterministic output from SVD.
This function is useful for orienting eigenvectors such that
they all lie in a shared but arbitrary half-space. This makes
it possible to ensure that results are equivalent across SVD
implementations and random number generator states.
Parameters
----------
u : (M, K) array_like
Left singular vectors (in columns)
v : (K, N) array_like
Right singular vectors (in rows)
    u_based_decision : bool
Whether or not to choose signs based
on `u` rather than `v`, by default False
Returns
-------
u : (M, K) array_like
Left singular vectors with corrected sign
    v : (K, N) array_like
Right singular vectors with corrected sign
"""
# Determine half-space in which all singular vectors
# lie relative to an arbitrary vector; summation
# equivalent to dot product with row vector of ones
if u_based_decision:
dtype = u.dtype
signs = np.sum(u, axis=0, keepdims=True)
else:
dtype = v.dtype
signs = np.sum(v, axis=1, keepdims=True).T
signs = 2.0 * ((signs >= 0) - 0.5).astype(dtype)
# Force all singular vectors into same half-space
u, v = u * signs, v * signs.T
return u, v
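# A small numeric sketch (comments only; values are made up). With
#
#     v = np.array([[1.0, 2.0], [-3.0, 1.0]])
#
# the row sums are 3 and -2, so the second right singular vector (and the
# matching column of `u`) is flipped; since the signs are +/-1, the
# product `u @ v` is unchanged.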
def scipy_linalg_safe(func_name, *args, **kwargs):
    """Call `scipy.linalg.<func_name>`, dispatching to
    `cupyx.scipy.linalg` when the first argument is a CuPy array."""
    # evaluate at least the first input array to decide between the CPU
    # (SciPy) and GPU (cupyx) implementations
    a = args[0]
if is_cupy_type(a):
import cupyx.scipy.linalg
func = getattr(cupyx.scipy.linalg, func_name)
else:
import scipy.linalg
func = getattr(scipy.linalg, func_name)
return func(*args, **kwargs)
def solve_triangular_safe(a, b, lower=False):
    """Solve the triangular system `a x = b` on CPU or GPU, depending
    on the input array types."""
return scipy_linalg_safe("solve_triangular", a, b, lower=lower)
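# An illustrative call (comments only; requires SciPy, values made up):
#
#     >>> A = np.array([[2.0, 0.0], [1.0, 1.0]])
#     >>> b = np.array([2.0, 3.0])
#     >>> solve_triangular_safe(A, b, lower=True)
#     array([1., 2.])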
def __getattr__(name):
# Can't use the @_deprecated decorator as it would not work on `except AxisError`
if name == "AxisError":
warnings.warn(
"AxisError was deprecated after version 2021.10.0 and will be removed in a "
f"future release. Please use {typename(AxisError)} instead.",
category=FutureWarning,
stacklevel=2,
)
return AxisError
else:
raise AttributeError(f"module {__name__} has no attribute {name}")