Source code for dask_expr.io.parquet

from __future__ import annotations

import contextlib
import itertools
import operator
import os
import pickle
import statistics
import warnings
import weakref
from abc import abstractmethod
from collections import defaultdict
from functools import cached_property, partial

import dask
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.fs as pa_fs
import pyarrow.parquet as pq
import tlz as toolz
from dask.base import normalize_token, tokenize
from dask.core import flatten
from dask.dataframe.io.parquet.core import (
    ParquetFunctionWrapper,
    ToParquetFunctionWrapper,
    aggregate_row_groups,
    apply_filters,
    get_engine,
    set_index_columns,
    sorted_columns,
)
from dask.dataframe.io.parquet.utils import _split_user_options
from dask.dataframe.io.utils import _is_local_fs
from dask.delayed import delayed
from dask.utils import apply, funcname, natural_sort_key, parse_bytes, typename
from fsspec.utils import stringify_path
from toolz import identity

from dask_expr._expr import (
    EQ,
    GE,
    GT,
    LE,
    LT,
    NE,
    And,
    Blockwise,
    Expr,
    Filter,
    Index,
    Lengths,
    Literal,
    Or,
    Projection,
    determine_column_projection,
)
from dask_expr._reductions import Len
from dask_expr._util import _convert_to_list, _tokenize_deterministic
from dask_expr.io import BlockwiseIO, PartitionsFiltered
from dask_expr.io.io import FusedParquetIO


@normalize_token.register(pa.fs.FileInfo)
def _tokenize_fileinfo(fileinfo):
    return type(fileinfo).__name__, (
        fileinfo.path,
        fileinfo.size,
        fileinfo.mtime_ns,
    )


_CPU_COUNT_SET = False


def _maybe_adjust_cpu_count():
    global _CPU_COUNT_SET
    if not _CPU_COUNT_SET:
        # Set the number of threads to the number of cores
        # This is a default for pyarrow, but it's not set by default in
        # dask/distributed
        pa.set_cpu_count(os.cpu_count())
        _CPU_COUNT_SET = True


_STATS_CACHE = {}


PYARROW_NULLABLE_DTYPE_MAPPING = {
    pa.int8(): pd.Int8Dtype(),
    pa.int16(): pd.Int16Dtype(),
    pa.int32(): pd.Int32Dtype(),
    pa.int64(): pd.Int64Dtype(),
    pa.uint8(): pd.UInt8Dtype(),
    pa.uint16(): pd.UInt16Dtype(),
    pa.uint32(): pd.UInt32Dtype(),
    pa.uint64(): pd.UInt64Dtype(),
    pa.bool_(): pd.BooleanDtype(),
    pa.string(): pd.StringDtype(),
    pa.float32(): pd.Float32Dtype(),
    pa.float64(): pd.Float64Dtype(),
}
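

# Hedged sketch (not part of the upstream module): the mapping above is meant
# to be handed to pyarrow as a ``types_mapper`` callable, which is how the
# "numpy_nullable" dtype_backend handled in ``_determine_type_mapper`` below
# produces nullable pandas dtypes instead of float64-with-NaN columns.
def _example_nullable_types_mapper():
    table = pa.table({"a": pa.array([1, None, 3], type=pa.int64())})
    df = table.to_pandas(types_mapper=PYARROW_NULLABLE_DTYPE_MAPPING.get)
    return df.dtypes["a"]  # Int64 (nullable) rather than float64
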

NONE_LABEL = "__null_dask_index__"

_CACHED_PLAN_SIZE = 10
_cached_plan = {}


class FragmentWrapper:
    _filesystems = weakref.WeakValueDictionary()

    def __init__(self, fragment=None, file_size=None, fragment_packed=None) -> None:
        """Wrap a pyarrow Fragment to only deserialize when needed."""
        # https://github.com/apache/arrow/issues/40279
        self._fragment = fragment
        self._fragment_packed = fragment_packed
        self._file_size = file_size
        self._fs = None

    def pack(self):
        if self._fragment_packed is None:
            self._fragment_packed = (
                self._fragment.format,
                (
                    self._fragment.path
                    if self._fragment.buffer is None
                    else self._fragment.buffer
                ),
                pickle.dumps(self._fragment.filesystem),
                self._fragment.partition_expression,
                self._file_size,
            )
        self._fs = self._fragment = None

    def unpack(self):
        if self._fragment is None:
            (
                pqformat,
                path_or_buffer,
                fs_raw,
                partition_expression,
                file_size,
            ) = self._fragment_packed
            fs = FragmentWrapper._filesystems.get(fs_raw)
            if fs is None:
                fs = pickle.loads(fs_raw)
                FragmentWrapper._filesystems[fs_raw] = fs
            # arrow doesn't keep the python object alive so if we want to reuse
            # we need to keep a reference
            self._fs = fs
            self._fragment = pqformat.make_fragment(
                path_or_buffer,
                filesystem=fs,
                partition_expression=partition_expression,
                file_size=file_size,
            )
        self._fragment_packed = None

    @property
    def fragment(self):
        self.unpack()
        return self._fragment

    def __dask_tokenize__(self):
        return type(self).__name__, normalize_token(
            (
                self.fragment,
                self._fragment_packed,
                self._file_size,
            )
        )

    def __reduce__(self):
        self.pack()
        return FragmentWrapper, (None, None, self._fragment_packed)
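

# Hedged sketch (not part of the upstream module): pickling a FragmentWrapper
# ships only the packed representation produced by ``pack()`` above, and the
# pyarrow fragment is rebuilt lazily the first time ``.fragment`` is accessed
# on the receiving side.
def _example_fragment_wrapper_roundtrip(fragment):
    wrapped = FragmentWrapper(fragment)
    restored = pickle.loads(pickle.dumps(wrapped))  # transfers the packed form
    return restored.fragment  # unpack() runs here, not during unpickling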


def _control_cached_plan(key):
    if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan:
        key_to_pop = list(_cached_plan.keys())[0]
        _cached_plan.pop(key_to_pop)


@normalize_token.register(pa_ds.Dataset)
def normalize_pa_ds(ds):
    return (ds.files, ds.schema)


@normalize_token.register(pa_ds.FileFormat)
def normalize_pa_file_format(file_format):
    return str(file_format)


@normalize_token.register(pa.Schema)
def normalize_pa_schema(schema):
    return schema.to_string()


@normalize_token.register(pq.ParquetSchema)
def normalize_pq_schema(schema):
    try:
        return hash(schema)
    except TypeError:  # pyarrow version not supporting ParquetSchema hash
        return hash(repr(schema))


@normalize_token.register(pq.FileMetaData)
def normalize_pq_filemetadata(meta):
    try:
        return hash(meta)
    except TypeError:
        # pyarrow version not implementing hash for FileMetaData
        # use same logic as implemented in version that does support hashing
        # https://github.com/apache/arrow/blob/bbe59b35de33a0534fc76c9617aa4746031ce16c/python/pyarrow/_parquet.pyx#L853
        return hash(
            (
                repr(meta.schema),
                meta.num_rows,
                meta.num_row_groups,
                meta.format_version,
                meta.serialized_size,
            )
        )
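

# Hedged sketch (not part of the upstream module): the normalize_token
# registrations above make ``tokenize`` deterministic for pyarrow metadata
# objects, which is what allows dataset plans and file statistics to be keyed
# by token in the caches below.
def _example_schema_tokens_match():
    schema_a = pa.schema([("x", pa.int64()), ("y", pa.string())])
    schema_b = pa.schema([("x", pa.int64()), ("y", pa.string())])
    return tokenize(schema_a) == tokenize(schema_b)  # True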


class ToParquet(Expr):
    _parameters = [
        "frame",
        "path",
        "fs",
        "fmd",
        "engine",
        "offset",
        "partition_on",
        "write_metadata_file",
        "name_function",
        "write_kwargs",
    ]

    @property
    def _meta(self):
        return None

    def _divisions(self):
        return (None, None)

    def _lower(self):
        return ToParquetBarrier(
            ToParquetData(
                *self.operands,
            ),
            *self.operands[1:],
        )


class ToParquetData(Blockwise):
    _parameters = ToParquet._parameters

    @property
    def io_func(self):
        return ToParquetFunctionWrapper(
            self.engine,
            self.path,
            self.fs,
            self.partition_on,
            self.write_metadata_file,
            self.offset,
            self.name_function,
            self.write_kwargs,
        )

    def _divisions(self):
        return (None,) * (self.frame.npartitions + 1)

    def _task(self, index: int):
        return (self.io_func, (self.frame._name, index), (index,))


class ToParquetBarrier(Expr):
    _parameters = ToParquet._parameters

    @property
    def _meta(self):
        return None

    def _divisions(self):
        return (None, None)

    def _layer(self):
        if self.write_metadata_file:
            append = self.write_kwargs.get("append")
            compression = self.write_kwargs.get("compression")
            return {
                (self._name, 0): (
                    apply,
                    self.engine.write_metadata,
                    [
                        self.frame.__dask_keys__(),
                        self.fmd,
                        self.fs,
                        self.path,
                    ],
                    {"append": append, "compression": compression},
                )
            }
        else:
            return {(self._name, 0): (lambda x: None, self.frame.__dask_keys__())}


def to_parquet(
    df,
    path,
    compression="snappy",
    write_index=True,
    append=False,
    overwrite=False,
    ignore_divisions=False,
    partition_on=None,
    storage_options=None,
    custom_metadata=None,
    write_metadata_file=None,
    compute=True,
    compute_kwargs=None,
    schema="infer",
    name_function=None,
    filesystem=None,
    engine=None,
    **kwargs,
):
    from dask_expr._collection import new_collection

    engine = _set_parquet_engine(engine=engine, meta=df._meta)
    compute_kwargs = compute_kwargs or {}

    partition_on = partition_on or []
    if isinstance(partition_on, str):
        partition_on = [partition_on]

    if set(partition_on) - set(df.columns):
        raise ValueError(
            "Partitioning on non-existent column. "
            "partition_on=%s ."
            "columns=%s" % (str(partition_on), str(list(df.columns)))
        )

    if df.columns.inferred_type not in {"string", "empty"}:
        raise ValueError("parquet doesn't support non-string column names")

    if isinstance(engine, str):
        engine = get_engine(engine)

    if hasattr(path, "name"):
        path = stringify_path(path)

    fs, _paths, _, _ = engine.extract_filesystem(
        path,
        filesystem=filesystem,
        dataset_options={},
        open_file_options={},
        storage_options=storage_options,
    )
    assert len(_paths) == 1, "only one path"
    path = _paths[0]

    if overwrite:
        if append:
            raise ValueError("Cannot use both `overwrite=True` and `append=True`!")

        if fs.exists(path) and fs.isdir(path):
            # Check for any previous parquet ops reading from a file in the
            # output directory, since deleting those files now would result in
            # errors or incorrect results.
            for read_op in df.expr.find_operations(ReadParquet):
                read_path_with_slash = str(read_op.path).rstrip("/") + "/"
                write_path_with_slash = path.rstrip("/") + "/"
                if read_path_with_slash.startswith(write_path_with_slash):
                    raise ValueError(
                        "Cannot overwrite a path that you are reading "
                        "from in the same task graph."
                    )

            # Don't remove the directory if it's the current working directory
            if _is_local_fs(fs):
                working_dir = fs.expand_path(".")[0]
                if path.rstrip("/") == working_dir.rstrip("/"):
                    raise ValueError(
                        "Cannot clear the contents of the current working directory!"
                    )

            # It's safe to clear the output directory
            fs.rm(path, recursive=True)

        # Clear read_parquet caches in case we are
        # also reading from the overwritten path
        _cached_plan.clear()

    # Always skip divisions checks if divisions are unknown
    if not df.known_divisions:
        ignore_divisions = True

    # Save divisions and corresponding index name. This is necessary,
    # because we may be resetting the index to write the file
    division_info = {"divisions": df.divisions, "name": df.index.name}
    if division_info["name"] is None:
        # As of 0.24.2, pandas will rename an index with name=None
        # when df.reset_index() is called. The default name is "index",
        # but dask will always change the name to the NONE_LABEL constant
        if NONE_LABEL not in df.columns:
            division_info["name"] = NONE_LABEL
        elif write_index:
            raise ValueError(
                "Index must have a name if __null_dask_index__ is a column."
            )
        else:
            warnings.warn(
                "If read back by Dask, column named __null_dask_index__ "
                "will be set to the index (and renamed to None)."
            )

    # There are some "reserved" names that may be used as the default column
    # name after resetting the index. However, we don't want to treat it as
    # a "special" name if the string is already used as a "real" column name.
    reserved_names = []
    for name in ["index", "level_0"]:
        if name not in df.columns:
            reserved_names.append(name)

    # If write_index==True (default), reset the index and record the
    # name of the original index in `index_cols` (we will set the name
    # to the NONE_LABEL constant if it is originally `None`).
    # `fastparquet` will use `index_cols` to specify the index column(s)
    # in the metadata. `pyarrow` will revert the `reset_index` call
    # below if `index_cols` is populated (because pyarrow will want to handle
    # index preservation itself). For both engines, the column index
    # will be written to "pandas metadata" if write_index=True
    index_cols = []
    if write_index:
        real_cols = set(df.columns)
        none_index = list(df._meta.index.names) == [None]
        df = df.reset_index()
        if none_index:
            rename_columns = {c: NONE_LABEL for c in df.columns if c in reserved_names}
            df = df.rename(columns=rename_columns)
        index_cols = [c for c in set(df.columns) - real_cols]
    else:
        # Not writing index - might as well drop it
        df = df.reset_index(drop=True)

    if custom_metadata and b"pandas" in custom_metadata.keys():
        raise ValueError(
            "User-defined key/value metadata (custom_metadata) can not "
            "contain a b'pandas' key. This key is reserved by Pandas, "
            "and overwriting the corresponding value can render the "
            "entire dataset unreadable."
        )

    # Engine-specific initialization steps to write the dataset.
    # Possibly create parquet metadata, and load existing stuff if appending
    i_offset, fmd, metadata_file_exists, extra_write_kwargs = engine.initialize_write(
        df.to_legacy_dataframe(),
        fs,
        path,
        append=append,
        ignore_divisions=ignore_divisions,
        partition_on=partition_on,
        division_info=division_info,
        index_cols=index_cols,
        schema=schema,
        custom_metadata=custom_metadata,
        **kwargs,
    )

    # By default we only write a metadata file when appending if one already
    # exists
    if append and write_metadata_file is None:
        write_metadata_file = metadata_file_exists

    # Check that custom name_function is valid,
    # and that it will produce unique names
    if name_function is not None:
        if not callable(name_function):
            raise ValueError("``name_function`` must be a callable with one argument.")
        filenames = [name_function(i + i_offset) for i in range(df.npartitions)]
        if len(set(filenames)) < len(filenames):
            raise ValueError("``name_function`` must produce unique filenames.")

    # If we are using a remote filesystem and retries is not set, bump it
    # to be more fault tolerant, as transient transport errors can occur.
    # The specific number 5 isn't hugely motivated: it's less than ten and more
    # than two.
    annotations = dask.config.get("annotations", {})
    if "retries" not in annotations and not _is_local_fs(fs):
        ctx = dask.annotate(retries=5)
    else:
        ctx = contextlib.nullcontext()

    with ctx:
        out = new_collection(
            ToParquet(
                df,
                path,
                fs,
                fmd,
                engine,
                i_offset,
                partition_on,
                write_metadata_file,
                name_function,
                toolz.merge(
                    kwargs,
                    {"compression": compression, "custom_metadata": custom_metadata},
                    extra_write_kwargs,
                ),
            )
        )

    if compute:
        out = out.compute(**compute_kwargs)

    # Invalidate the filesystem listing cache for the output path after write.
    # We do this before returning, even if `compute=False`. This helps ensure
    # that reading files that were just written succeeds.
    fs.invalidate_cache(path)

    return out
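

# Hedged usage sketch (not part of the upstream module): write a small
# collection with ``to_parquet`` and read it back. ``tmpdir`` is a
# hypothetical writable directory; ``from_pandas`` and ``read_parquet`` are
# imported lazily from dask_expr's public API to avoid a circular import at
# module load time.
def _example_to_parquet_roundtrip(tmpdir):
    from dask_expr import from_pandas, read_parquet

    pdf = pd.DataFrame({"x": range(10), "y": list("ab") * 5})
    df = from_pandas(pdf, npartitions=2)
    to_parquet(df, str(tmpdir), write_metadata_file=False)
    return read_parquet(str(tmpdir)).compute()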


def _determine_type_mapper(
    *, user_types_mapper, dtype_backend, pyarrow_strings_enabled
):
    type_mappers = []

    def pyarrow_type_mapper(pyarrow_dtype):
        # Special case pyarrow strings to use more feature complete dtype
        # See https://github.com/pandas-dev/pandas/issues/50074
        if pyarrow_dtype == pa.string():
            return pd.StringDtype("pyarrow")
        else:
            return pd.ArrowDtype(pyarrow_dtype)

    # always use the user-defined mapper first, if available
    if user_types_mapper is not None:
        type_mappers.append(user_types_mapper)

    # next in priority is converting strings
    if pyarrow_strings_enabled:
        type_mappers.append({pa.string(): pd.StringDtype("pyarrow")}.get)
        type_mappers.append({pa.date32(): pd.ArrowDtype(pa.date32())}.get)
        type_mappers.append({pa.date64(): pd.ArrowDtype(pa.date64())}.get)

        def _convert_decimal_type(type):
            if pa.types.is_decimal(type):
                return pd.ArrowDtype(type)
            return None

        type_mappers.append(_convert_decimal_type)

    # and then nullable types
    if dtype_backend == "numpy_nullable":
        type_mappers.append(PYARROW_NULLABLE_DTYPE_MAPPING.get)
    elif dtype_backend == "pyarrow":
        type_mappers.append(pyarrow_type_mapper)

    def default_types_mapper(pyarrow_dtype):
        """Try all type mappers in order, starting from the user type mapper."""
        for type_converter in type_mappers:
            converted_type = type_converter(pyarrow_dtype)
            if converted_type is not None:
                return converted_type

    if len(type_mappers) > 0:
        return default_types_mapper


class ReadParquet(PartitionsFiltered, BlockwiseIO):
    _pq_length_stats = None
    _absorb_projections = True
    _filter_passthrough = False

    def _filter_passthrough_available(self, parent, dependents):
        return (
            super()._filter_passthrough_available(parent, dependents)
            and (isinstance(parent.predicate, (LE, GE, LT, GT, EQ, NE, And, Or)))
            and _DNF.extract_pq_filters(self, parent.predicate)._filters is not None
        )

    def _simplify_up(self, parent, dependents):
        if isinstance(parent, Index):
            # Column projection
            columns = determine_column_projection(self, parent, dependents)
            if set(columns) == set(self.columns):
                return
            columns = [col for col in self.columns if col in columns]
            return Index(
                self.substitute_parameters({"columns": columns, "_series": False})
            )

        if isinstance(parent, Projection):
            return super()._simplify_up(parent, dependents)

        if isinstance(parent, Filter) and self._filter_passthrough_available(
            parent, dependents
        ):
            # Predicate pushdown
            filters = _DNF.extract_pq_filters(self, parent.predicate)
            if filters._filters is not None:
                return self.substitute_parameters(
                    {
                        "filters": filters.combine(
                            self.operand("filters")
                        ).to_list_tuple()
                    }
                )

        if isinstance(parent, Lengths):
            _lengths = self._get_lengths()
            if _lengths:
                return Literal(_lengths)

        if isinstance(parent, Len):
            _lengths = self._get_lengths()
            if _lengths:
                return Literal(sum(_lengths))

    @property
    def columns(self):
        columns_operand = self.operand("columns")
        if columns_operand is None:
            return list(self._meta.columns)
        else:
            return _convert_to_list(columns_operand)

    @cached_property
    def _funcname(self):
        return "read_parquet"

    @cached_property
    def _name(self):
        return (
            self._funcname
            + "-"
            + _tokenize_deterministic(
                funcname(type(self)), self.checksum, *self.operands[:-1]
            )
        )

    @property
    def checksum(self):
        return self._dataset_info["checksum"]

    def _tree_repr_argument_construction(self, i, op, header):
        if self._parameters[i] == "_dataset_info_cache":
            # Don't print this, very ugly
            return header
        return super()._tree_repr_argument_construction(i, op, header)

    @cached_property
    def _meta(self):
        meta = self._dataset_info["base_meta"]
        columns = _convert_to_list(self.operand("columns"))
        if self._series:
            assert len(columns) > 0
            return meta[columns[0]]
        elif columns is not None:
            return meta[columns]
        return meta

    @abstractmethod
    def _divisions(self):
        raise NotImplementedError

    @property
    def _fusion_compression_factor(self):
        if self.operand("columns") is None:
            return 1
        nr_original_columns = max(len(self._dataset_info["schema"].names) - 1, 1)
        return max(
            len(_convert_to_list(self.operand("columns"))) / nr_original_columns, 0.001
        )


class ReadParquetPyarrowFS(ReadParquet):
    _parameters = [
        "path",
        "columns",
        "filters",
        "categories",
        "index",
        "storage_options",
        "filesystem",
        "ignore_metadata_file",
        "calculate_divisions",
        "arrow_to_pandas",
        "pyarrow_strings_enabled",
        "kwargs",
        "_partitions",
        "_series",
        "_dataset_info_cache",
    ]
    _defaults = {
        "columns": None,
        "filters": None,
        "categories": None,
        "index": None,
        "storage_options": None,
        "filesystem": None,
        "ignore_metadata_file": True,
        "calculate_divisions": False,
        "arrow_to_pandas": None,
        "pyarrow_strings_enabled": True,
        "kwargs": None,
        "_partitions": None,
        "_series": False,
        "_dataset_info_cache": None,
    }
    _absorb_projections = True
    _filter_passthrough = True

    @cached_property
    def normalized_path(self):
        return _normalize_and_strip_protocol(self.path)

    @cached_property
    def fs(self):
        fs_input = self.operand("filesystem")
        if isinstance(fs_input, pa.fs.FileSystem):
            return fs_input
        else:
            fs = pa_fs.FileSystem.from_uri(self.path)[0]
            if storage_options := self.storage_options:
                # Use inferred region as the default
                region = {} if "region" in storage_options else {"region": fs.region}
                fs = type(fs)(**region, **storage_options)
            return fs

    def approx_statistics(self) -> dict:
        """Return an approximation of a single file's statistics.

        This is determined by sampling a few files and averaging their
        statistics.

        Fields
        ------
        num_rows: avg
        num_row_groups: avg
        serialized_size: avg
        columns: list
            A list of all column statistics where individual fields are also
            averaged.

        Example
        -------
        {
            'num_rows': 1991129,
            'num_row_groups': 2.3333333333333335,
            'serialized_size': 6256.666666666667,
            'total_byte_size': 118030095,
            'columns': [
                {'total_compressed_size': 6284162.333333333,
                 'total_uncompressed_size': 6347380.333333333,
                 'path_in_schema': 'l_orderkey'},
                {'total_compressed_size': 9423516.333333334,
                 'total_uncompressed_size': 9423063.333333334,
                 'path_in_schema': 'l_partkey'},
                {'total_compressed_size': 9405796.666666666,
                 'total_uncompressed_size': 9405346.666666666,
                 'path_in_schema': 'l_suppkey'},
                ...
            ]
        }

        Returns
        -------
        dict
        """
        idxs = self.sample_statistics()
        files_to_consider = np.array(self._dataset_info["all_files"])[idxs]
        stats = [_STATS_CACHE[tokenize(finfo)] for finfo in files_to_consider]
        return _combine_stats(stats)

    def load_statistics(self, files=None, fragments=None):
        if files is None:
            files = self._dataset_info["all_files"]
        if fragments is None:
            fragments = self.fragments_unsorted
        # Collecting code samples is actually a little expensive (~100ms) and
        # we'd like this thing to be as low overhead as possible
        with dask.config.set({"distributed.diagnostics.computations.nframes": 0}):
            token_stats = flatten(
                dask.compute(_collect_statistics_plan(files, fragments))
            )
            for token, stats in token_stats:
                _STATS_CACHE[token] = stats

    def sample_statistics(self, n=3):
        """Sample statistics from the dataset.

        Sample N file statistics from the dataset. The files are chosen by
        sorting all files based on their binary file size and picking
        equidistant sampling points. In the special case of n=3 this
        corresponds to min/median/max.

        Returns
        -------
        ixs: list[int]
            The indices of files that were sampled
        """
        frags = self.fragments_unsorted
        finfos = np.array(self._dataset_info["all_files"])
        getsize = np.frompyfunc(lambda x: x.size, nin=1, nout=1)
        finfo_size_arr = getsize(finfos)
        finfo_argsort = finfo_size_arr.argsort()
        nfrags = len(frags)
        stepsize = max(nfrags // n, 1)
        finfos_sampled = []
        frags_samples = []
        ixs = []
        for i in range(0, nfrags, stepsize):
            sort_ix = finfo_argsort[i]
            # TODO: This is crude but the most conservative estimate
            sort_ix = sort_ix if sort_ix < nfrags else 0
            ixs.append(sort_ix)
            finfos_sampled.append(finfos[sort_ix])
            frags_samples.append(frags[sort_ix])
        self.load_statistics(finfos_sampled, frags_samples)
        return ixs
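
    @staticmethod
    def _example_size_based_sampling(sizes, n=3):
        # Hedged sketch (not part of the upstream module): a pure-Python
        # illustration of the sampling strategy used by ``sample_statistics``
        # above -- sort files by on-disk size and pick equidistant points
        # (min/median/max when n=3). ``sizes`` is a hypothetical list of file
        # sizes; the return value is a list of positions into that list.
        order = sorted(range(len(sizes)), key=sizes.__getitem__)
        step = max(len(sizes) // n, 1)
        return [order[i] for i in range(0, len(sizes), step)]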

    @cached_property
    def raw_statistics(self):
        """Parquet statistics for every file in the dataset.

        The statistics do not include all the metadata that is stored in the
        file but only a subset. See also `_extract_stats`.
        """
        self.load_statistics()
        return [
            _STATS_CACHE[tokenize(finfo)] for finfo in self._dataset_info["all_files"]
        ]

    @cached_property
    def aggregated_statistics(self):
        """Aggregate statistics for every partition in the dataset.

        These statistics aggregate the row group statistics to partition level
        such that min/max/total_compressed_size/etc. corresponds to the entire
        partition instead of individual row groups.
        """
        return _aggregate_statistics_to_file(self.raw_statistics)

    def _get_lengths(self):
        # TODO: Filters that only filter partition_expr can be used as well
        if not self.filters:
            return tuple(stats["num_rows"] for stats in self.aggregated_statistics)

    @cached_property
    def _dataset_info(self):
        if rv := self.operand("_dataset_info_cache"):
            return rv
        dataset_info = {}
        path_normalized = self.normalized_path

        # We'll first treat the path as if it was a directory since this is the
        # most common case. Only if this fails, we'll treat it as a file. This
        # way, the happy path performs one remote request instead of two if we
        # were to check the type of the path first.
        try:
            # At this point we will post a listbucket request which includes
            # the same data as a HEAD request. The information included here
            # (see pyarrow FileInfo) are size, type, path and modified since
            # timestamps. This isn't free but relatively cheap (200-300ms or
            # less for ~1k files)
            all_files = []
            for path in path_normalized:
                dataset_selector = pa_fs.FileSelector(path, recursive=True)
                all_files.extend(
                    [
                        finfo
                        for finfo in self.fs.get_file_info(dataset_selector)
                        if finfo.type == pa.fs.FileType.File
                    ]
                )
        except (NotADirectoryError, FileNotFoundError):
            all_files = [self.fs.get_file_info(path) for path in path_normalized]

        # TODO: At this point we could verify if we're dealing with a very
        # inhomogeneous dataset already without reading any further data
        metadata_file = False
        checksum = None
        dataset = None
        if not self.ignore_metadata_file:
            all_files = sorted(
                all_files, key=lambda x: x.base_name.endswith("_metadata")
            )
            if all_files[-1].base_name.endswith("_metadata"):
                metadata_file = all_files.pop()
                checksum = tokenize(metadata_file)
                # TODO: dataset kwargs?
                dataset = pa_ds.parquet_dataset(
                    metadata_file.path,
                    filesystem=self.fs,
                )
                dataset_info["using_metadata_file"] = True
                dataset_info["fragments"] = _frags = list(dataset.get_fragments())
                dataset_info["file_sizes"] = [None for fi in _frags]

        if checksum is None:
            checksum = tokenize(all_files)
            dataset_info["file_sizes"] = [fi.size for fi in all_files]

        dataset_info["checksum"] = checksum

        if dataset is None:
            import pyarrow.parquet as pq

            dataset = pq.ParquetDataset(
                # TODO Just pass all_files once
                # https://github.com/apache/arrow/pull/40143 is available to
                # reduce latency
                [fi.path for fi in all_files],
                filesystem=self.fs,
                filters=self.filters,
            )
            dataset_info["using_metadata_file"] = False
            dataset_info["fragments"] = dataset.fragments

        dataset_info["all_files"] = all_files
        dataset_info["dataset"] = dataset
        dataset_info["schema"] = dataset.schema
        dataset_info["base_meta"] = dataset.schema.empty_table().to_pandas()
        self.operands[
            type(self)._parameters.index("_dataset_info_cache")
        ] = dataset_info
        return dataset_info

    @cached_property
    def _division_from_stats(self):
        """If enabled, compute the divisions from the collected statistics.

        If divisions are possible to set, the second argument will be the
        argsort of the fragments such that the divisions are correct.

        Returns
        -------
        divisions
        argsort
        """
        if self.calculate_divisions and self.index is not None:
            index_name = self.index.name
            return _divisions_from_statistics(self.aggregated_statistics, index_name)
        return tuple([None] * (len(self.fragments_unsorted) + 1)), None

    def all_statistics_known(self) -> bool:
        """Whether all statistics have been fetched from remote store"""
        return all(
            tokenize(finfo) in _STATS_CACHE
            for finfo in self._dataset_info["all_files"]
        )

    def _fragment_sort_index(self):
        return self._division_from_stats[1]

    def _divisions(self):
        return self._division_from_stats[0]

    def _tune_up(self, parent):
        if self._fusion_compression_factor >= 1:
            return
        if isinstance(parent, FusedParquetIO):
            return
        return parent.substitute(self, FusedParquetIO(self))

    @cached_property
    def fragments(self):
        """Return all fragments in the dataset after filtering in the order
        as expected by the divisions.

        See also
        --------
        ReadParquetPyarrowFS.fragments_unsorted
        """
        if self._fragment_sort_index() is not None:
            return self.fragments_unsorted[self._fragment_sort_index()]
        return self.fragments_unsorted

    @property
    def fragments_unsorted(self):
        """All fragments in the dataset after filtering.

        No guarantees on ordering. This is ordered as the files are listed.

        See also
        --------
        ReadParquetPyarrowFS.fragments
        """
        if self.filters is not None:
            if self._dataset_info["using_metadata_file"]:
                ds = self._dataset_info["dataset"]
            else:
                ds = self._dataset_info["dataset"]._dataset
            return np.array(
                list(ds.get_fragments(filter=pq.filters_to_expression(self.filters)))
            )
        return np.array(self._dataset_info["fragments"])

    @property
    def _fusion_compression_factor(self):
        approx_stats = self.approx_statistics()
        total_uncompressed = 0
        after_projection = 0
        col_op = self.operand("columns") or self.columns
        for col in approx_stats["columns"]:
            total_uncompressed += col["total_uncompressed_size"]
            if col["path_in_schema"] in col_op:
                after_projection += col["total_uncompressed_size"]

        min_size = dask.config.get("dataframe.parquet.minimum-partition-size")
        total_uncompressed = max(total_uncompressed, min_size)
        return max(after_projection / total_uncompressed, 0.001)

    def _filtered_task(self, index: int):
        columns = self.columns.copy()
        index_name = self.index.name
        if self.index is not None:
            index_name = self.index.name
        schema = self._dataset_info["schema"].remove_metadata()
        if index_name:
            if columns is None:
                columns = list(schema.names)
            columns.append(index_name)
        return (
            ReadParquetPyarrowFS._table_to_pandas,
            (
                ReadParquetPyarrowFS._fragment_to_table,
                FragmentWrapper(self.fragments[index]),
                self.filters,
                columns,
                schema,
            ),
            index_name,
            self.arrow_to_pandas,
            self.kwargs.get("dtype_backend"),
            self.pyarrow_strings_enabled,
        )

    @staticmethod
    def _fragment_to_table(fragment_wrapper, filters, columns, schema):
        _maybe_adjust_cpu_count()
        if isinstance(fragment_wrapper, FragmentWrapper):
            fragment = fragment_wrapper.fragment
        else:
            fragment = fragment_wrapper
        if isinstance(filters, list):
            filters = pq.filters_to_expression(filters)
        return fragment.to_table(
            schema=schema,
            columns=columns,
            filter=filters,
            # Batch size determines how many rows are read at once and will
            # cause the underlying array to be split into chunks of this size
            # (max). We'd like to avoid fragmentation as much as possible and
            # to set this to something like inf, but we have to set a finite,
            # positive number.
            # In the presence of row groups, the underlying array will still
            # be chunked per rowgroup
            batch_size=10_000_000,
            fragment_scan_options=pa.dataset.ParquetFragmentScanOptions(
                pre_buffer=True,
                cache_options=pa.CacheOptions(
                    hole_size_limit=parse_bytes("4 MiB"),
                    range_size_limit=parse_bytes("32.00 MiB"),
                ),
            ),
            # TODO: Reconsider this. The OMP_NUM_THREADS variable makes it
            # harmful to enable this
            use_threads=True,
        )

    @staticmethod
    def _table_to_pandas(
        table, index_name, arrow_to_pandas, dtype_backend, pyarrow_strings_enabled
    ):
        if arrow_to_pandas is None:
            arrow_to_pandas = {}
        else:
            arrow_to_pandas = arrow_to_pandas.copy()
            # This can mess up index setting, etc.
            arrow_to_pandas.pop("ignore_metadata", None)
        df = table.to_pandas(
            types_mapper=_determine_type_mapper(
                user_types_mapper=arrow_to_pandas.pop("types_mapper", None),
                dtype_backend=dtype_backend,
                pyarrow_strings_enabled=pyarrow_strings_enabled,
            ),
            use_threads=arrow_to_pandas.get("use_threads", False),
            self_destruct=arrow_to_pandas.get("self_destruct", True),
            **arrow_to_pandas,
            ignore_metadata=True,
        )
        if index_name is not None:
            df = df.set_index(index_name)
        return df
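

# Hedged sketch (not part of the upstream module): the arithmetic behind
# ``ReadParquetPyarrowFS._fusion_compression_factor`` above. The factor is the
# share of uncompressed bytes that survives the column projection, with the
# denominator clamped to a minimum partition size and the result floored at
# 0.001. The 75MB default below is only an assumed placeholder; the real code
# reads "dataframe.parquet.minimum-partition-size" from the dask config.
def _example_fusion_factor(column_sizes, projected, min_size=parse_bytes("75MB")):
    total = max(sum(column_sizes.values()), min_size)
    kept = sum(size for name, size in column_sizes.items() if name in projected)
    return max(kept / total, 0.001)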


class ReadParquetFSSpec(ReadParquet):
    """Read a parquet dataset"""

    _parameters = [
        "path",
        "columns",
        "filters",
        "categories",
        "index",
        "storage_options",
        "calculate_divisions",
        "ignore_metadata_file",
        "metadata_task_size",
        "split_row_groups",
        "blocksize",
        "aggregate_files",
        "parquet_file_extension",
        "filesystem",
        "engine",
        "kwargs",
        "_partitions",
        "_series",
        "_dataset_info_cache",
    ]
    _defaults = {
        "columns": None,
        "filters": None,
        "categories": None,
        "index": None,
        "storage_options": None,
        "calculate_divisions": False,
        "ignore_metadata_file": False,
        "metadata_task_size": None,
        "split_row_groups": "infer",
        "blocksize": "default",
        "aggregate_files": None,
        "parquet_file_extension": (".parq", ".parquet", ".pq"),
        "filesystem": "fsspec",
        "engine": "pyarrow",
        "kwargs": None,
        "_partitions": None,
        "_series": False,
        "_dataset_info_cache": None,
    }

    @property
    def engine(self):
        _engine = self.operand("engine")
        if isinstance(_engine, str):
            return get_engine(_engine)
        return _engine

    def _divisions(self):
        return self._plan["divisions"]

    @property
    def _dataset_info(self):
        if rv := self.operand("_dataset_info_cache"):
            return rv

        # Process and split user options
        (
            dataset_options,
            read_options,
            open_file_options,
            other_options,
        ) = _split_user_options(**(self.kwargs or {}))

        # Extract global filesystem and paths
        fs, paths, dataset_options, open_file_options = self.engine.extract_filesystem(
            self.path,
            self.filesystem,
            dataset_options,
            open_file_options,
            self.storage_options,
        )
        read_options["open_file_options"] = open_file_options
        paths = sorted(paths, key=natural_sort_key)  # numeric rather than glob ordering

        auto_index_allowed = False
        index_operand = self.operand("index")
        if index_operand is None:
            # User is allowing auto-detected index
            auto_index_allowed = True
        if index_operand and isinstance(index_operand, str):
            index = [index_operand]
        else:
            index = index_operand

        blocksize = self.blocksize
        if self.split_row_groups in ("infer", "adaptive"):
            # Using blocksize to plan partitioning
            if self.blocksize == "default":
                if hasattr(self.engine, "default_blocksize"):
                    blocksize = self.engine.default_blocksize()
                else:
                    blocksize = "128MiB"
        else:
            # Not using blocksize - Set to `None`
            blocksize = None

        # Collect general dataset info
        args = (
            paths,
            fs,
            self.categories,
            index,
            self.calculate_divisions,
            self.filters,
            self.split_row_groups,
            blocksize,
            self.aggregate_files,
            self.ignore_metadata_file,
            self.metadata_task_size,
            self.parquet_file_extension,
            {
                "read": read_options,
                "dataset": dataset_options,
                **other_options,
            },
        )
        dataset_info = self.engine._collect_dataset_info(*args)

        checksum = []
        files_for_checksum = []
        if dataset_info["has_metadata_file"]:
            if isinstance(self.path, list):
                files_for_checksum = [
                    next(path for path in self.path if path.endswith("_metadata"))
                ]
            else:
                files_for_checksum = [self.path + fs.sep + "_metadata"]
        else:
            files_for_checksum = dataset_info["ds"].files

        for file in files_for_checksum:
            # The checksum / file info is usually already cached by the fsspec
            # FileSystem dir_cache since this info was already asked for in
            # _collect_dataset_info
            checksum.append(fs.checksum(file))
        dataset_info["checksum"] = tokenize(checksum)

        # Infer meta, accounting for index and columns arguments.
        meta = self.engine._create_dd_meta(dataset_info)
        index = dataset_info["index"]
        index = [index] if isinstance(index, str) else index
        meta, index, all_columns = set_index_columns(
            meta, index, None, auto_index_allowed
        )
        if meta.index.name == NONE_LABEL:
            meta.index.name = None
        dataset_info["base_meta"] = meta
        dataset_info["index"] = index
        dataset_info["all_columns"] = all_columns
        dataset_info["calculate_divisions"] = self.calculate_divisions

        self.operands[
            type(self)._parameters.index("_dataset_info_cache")
        ] = dataset_info
        return dataset_info

    def _filtered_task(self, index: int):
        tsk = (self._io_func, self._plan["parts"][index])
        if self._series:
            return (operator.getitem, tsk, self.columns[0])
        return tsk

    @property
    def _io_func(self):
        if self._plan["empty"]:
            return identity
        dataset_info = self._dataset_info
        return ParquetFunctionWrapper(
            self.engine,
            dataset_info["fs"],
            dataset_info["base_meta"],
            self.columns,
            dataset_info["index"],
            dataset_info["kwargs"]["dtype_backend"],
            {},  # All kwargs should now be in `common_kwargs`
            self._plan["common_kwargs"],
        )

    @cached_property
    def _plan(self):
        dataset_info = self._dataset_info
        dataset_token = tokenize(dataset_info)
        if dataset_token not in _cached_plan:
            parts, stats, common_kwargs = self.engine._construct_collection_plan(
                dataset_info
            )

            # Make sure parts and stats are aligned
            parts, stats = _align_statistics(parts, stats)

            # Use statistics to aggregate partitions
            parts, stats = _aggregate_row_groups(parts, stats, dataset_info)

            # Drop filtered partitions (aligns with `dask.dataframe` behavior)
            if self.filters and stats:
                parts, stats = apply_filters(parts, stats, self.filters)

            # Use statistics to calculate divisions
            divisions = _calculate_divisions(stats, dataset_info, len(parts))

            empty = False
            if len(divisions) < 2:
                # empty dataframe - just use meta
                divisions = (None, None)
                parts = [self._meta]
                empty = True

            _control_cached_plan(dataset_token)
            _cached_plan[dataset_token] = {
                "empty": empty,
                "parts": parts,
                "statistics": stats,
                "divisions": divisions,
                "common_kwargs": common_kwargs,
            }
        return _cached_plan[dataset_token]

    def _get_lengths(self) -> tuple | None:
        """Return known partition lengths using parquet statistics"""
        if not self.filters:
            self._update_length_statistics()
            return tuple(
                length
                for i, length in enumerate(self._pq_length_stats)
                if not self._filtered or i in self._partitions
            )
        return None

    def _update_length_statistics(self):
        """Ensure that partition-length statistics are up to date"""

        if not self._pq_length_stats:
            if self._plan["statistics"]:
                # Already have statistics from original API call
                self._pq_length_stats = tuple(
                    stat["num-rows"]
                    for i, stat in enumerate(self._plan["statistics"])
                    if not self._filtered or i in self._partitions
                )
            else:
                # Need to go back and collect statistics
                self._pq_length_stats = tuple(
                    stat["num-rows"] for stat in _collect_pq_statistics(self)
                )


#
# Helper functions
#


def _set_parquet_engine(engine=None, meta=None):
    # Use `engine` or `meta` input to set the parquet engine
    if engine == "fastparquet":
        raise NotImplementedError("Fastparquet engine is not supported")

    if engine is None:
        if (
            meta is not None and typename(meta).split(".")[0] == "cudf"
        ) or dask.config.get("dataframe.backend", "pandas") == "cudf":
            from dask_cudf.io.parquet import CudfEngine

            engine = CudfEngine
        else:
            engine = "pyarrow"
    return engine


def _align_statistics(parts, statistics):
    # Make sure parts and statistics are aligned
    # (if statistics is not empty)
    if statistics and len(parts) != len(statistics):
        statistics = []
    if statistics:
        result = list(
            zip(
                *[
                    (part, stats)
                    for part, stats in zip(parts, statistics)
                    if stats["num-rows"] > 0
                ]
            )
        )
        parts, statistics = result or [[], []]
    return parts, statistics


def _aggregate_row_groups(parts, statistics, dataset_info):
    # Aggregate parts/statistics if we are splitting by row-group
    blocksize = (
        dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None
    )
    split_row_groups = dataset_info["split_row_groups"]
    fs = dataset_info["fs"]
    aggregation_depth = dataset_info["aggregation_depth"]

    if statistics:
        if blocksize or (split_row_groups and int(split_row_groups) > 1):
            parts, statistics = aggregate_row_groups(
                parts, statistics, blocksize, split_row_groups, fs, aggregation_depth
            )
    return parts, statistics


def _calculate_divisions(statistics, dataset_info, npartitions):
    # Use statistics to define divisions
    divisions = None
    if statistics and dataset_info.get("gather_statistics", False):
        calculate_divisions = dataset_info.get("calculate_divisions", None)
        index = dataset_info["index"]
        process_columns = index if index and len(index) == 1 else None
        if (calculate_divisions is not False) and process_columns:
            for sorted_column_info in sorted_columns(
                statistics, columns=process_columns
            ):
                if sorted_column_info["name"] in index:
                    divisions = sorted_column_info["divisions"]
                    break

    return divisions or (None,) * (npartitions + 1)


#
# Filtering logic
#


class _DNF:
    """Manage filters in Disjunctive Normal Form (DNF)"""

    class _Or(frozenset):
        """Frozen set of disjunctions"""

        def to_list_tuple(self) -> list:
            # DNF "or" is List[List[Tuple]]
            def _maybe_list(val):
                if isinstance(val, tuple) and val and isinstance(val[0], (tuple, list)):
                    return list(val)
                return [val]

            return [
                (
                    _maybe_list(val.to_list_tuple())
                    if hasattr(val, "to_list_tuple")
                    else _maybe_list(val)
                )
                for val in self
            ]

    class _And(frozenset):
        """Frozen set of conjunctions"""

        def to_list_tuple(self) -> list:
            # DNF "and" is List[Tuple]
            return tuple(
                val.to_list_tuple() if hasattr(val, "to_list_tuple") else val
                for val in self
            )

    _filters: _And | _Or | None  # Underlying filter expression

    def __init__(self, filters: _And | _Or | list | tuple | None) -> _DNF:
        self._filters = self.normalize(filters)

    def to_list_tuple(self) -> list:
        return self._filters.to_list_tuple()

    def __bool__(self) -> bool:
        return bool(self._filters)

    @classmethod
    def normalize(cls, filters: _And | _Or | list | tuple | None):
        """Convert raw filters to the `_Or(_And)` DNF representation"""
        if not filters:
            result = None
        elif isinstance(filters, list):
            conjunctions = filters if isinstance(filters[0], list) else [filters]
            result = cls._Or([cls._And(conjunction) for conjunction in conjunctions])
        elif isinstance(filters, tuple):
            if isinstance(filters[0], tuple):
                raise TypeError("filters must be List[Tuple] or List[List[Tuple]]")
            result = cls._Or((cls._And((filters,)),))
        elif isinstance(filters, cls._Or):
            result = cls._Or(se for e in filters for se in cls.normalize(e))
        elif isinstance(filters, cls._And):
            total = []
            for c in itertools.product(*[cls.normalize(e) for e in filters]):
                total.append(cls._And(se for e in c for se in e))
            result = cls._Or(total)
        else:
            raise TypeError(f"{type(filters)} not a supported type for _DNF")
        return result

    def combine(self, other: _DNF | _And | _Or | list | tuple | None) -> _DNF:
        """Combine with another _DNF object"""
        if not isinstance(other, _DNF):
            other = _DNF(other)
        assert isinstance(other, _DNF)
        if self._filters is None:
            result = other._filters
        elif other._filters is None:
            result = self._filters
        else:
            result = self._And([self._filters, other._filters])
        return _DNF(result)

    @classmethod
    def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF:
        _filters = None
        if isinstance(predicate_expr, (LE, GE, LT, GT, EQ, NE)):
            if (
                not isinstance(predicate_expr.right, Expr)
                and isinstance(predicate_expr.left, Projection)
                and predicate_expr.left.frame._name == pq_expr._name
            ):
                op = predicate_expr._operator_repr
                column = predicate_expr.left.columns[0]
                value = predicate_expr.right
                _filters = (column, op, value)
            elif (
                not isinstance(predicate_expr.left, Expr)
                and isinstance(predicate_expr.right, Projection)
                and predicate_expr.right.frame._name == pq_expr._name
            ):
                # Simple dict to make sure field comes first in filter
                flip = {LE: GE, LT: GT, GE: LE, GT: LT}
                op = predicate_expr
                op = flip.get(op, op)._operator_repr
                column = predicate_expr.right.columns[0]
                value = predicate_expr.left
                _filters = (column, op, value)

        elif isinstance(predicate_expr, (And, Or)):
            left = cls.extract_pq_filters(pq_expr, predicate_expr.left)._filters
            right = cls.extract_pq_filters(pq_expr, predicate_expr.right)._filters
            if left and right:
                if isinstance(predicate_expr, And):
                    _filters = cls._And([left, right])
                else:
                    _filters = cls._Or([left, right])

        return _DNF(_filters)
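

# Hedged sketch (not part of the upstream module): how _DNF composes a
# pushed-down predicate with filters that are already present. Combining two
# filter lists ANDs them together, and ``to_list_tuple()`` returns the
# List[List[Tuple]] form that pyarrow's filter machinery accepts.
def _example_dnf_combination():
    existing = _DNF([("x", ">", 10)])
    combined = existing.combine([("y", "==", "a")])
    # e.g. [[("x", ">", 10), ("y", "==", "a")]]; conjunction order is not
    # guaranteed because the terms live in frozensets
    return combined.to_list_tuple()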


#
# Parquet-statistics handling
#


def _collect_pq_statistics(
    expr: ReadParquet, columns: list | None = None
) -> list[dict] | None:
    """Collect Parquet statistics for dataset paths"""

    # Be strict about columns argument
    if columns:
        if not isinstance(columns, list):
            raise ValueError(f"Expected columns to be a list, got {type(columns)}.")
        allowed = {expr._meta.index.name} | set(expr.columns)
        if not set(columns).issubset(allowed):
            raise ValueError(f"columns={columns} must be a subset of {allowed}")

    if expr._plan["empty"]:
        return []

    # Collect statistics using layer information
    fs = expr._io_func.fs
    parts = [
        part
        for i, part in enumerate(expr._plan["parts"])
        if not expr._filtered or i in expr._partitions
    ]

    # Execute with delayed for large and remote datasets
    parallel = int(False if _is_local_fs(fs) else 16)
    if parallel:
        # Group parts corresponding to the same file.
        # A single task should always parse statistics
        # for all these parts at once (since they will
        # all be in the same footer)
        groups = defaultdict(list)
        for part in parts:
            for p in [part] if isinstance(part, dict) else part:
                path = p.get("piece")[0]
                groups[path].append(p)
        group_keys = list(groups.keys())

        # Compute and return flattened result
        func = delayed(_read_partition_stats_group)
        result = dask.compute(
            [
                func(
                    list(
                        itertools.chain(
                            *[groups[k] for k in group_keys[i : i + parallel]]
                        )
                    ),
                    fs,
                    columns=columns,
                )
                for i in range(0, len(group_keys), parallel)
            ]
        )[0]
        return list(itertools.chain(*result))
    else:
        # Serial computation on client
        return _read_partition_stats_group(parts, fs, columns=columns)


def _read_partition_stats_group(parts, fs, columns=None):
    """Parse the statistics for a group of files"""

    def _read_partition_stats(part, fs, columns=None):
        # Helper function to read Parquet-metadata
        # statistics for a single partition

        if not isinstance(part, list):
            part = [part]

        column_stats = {}
        num_rows = 0
        columns = columns or []
        for p in part:
            piece = p["piece"]
            path = piece[0]
            row_groups = None if piece[1] == [None] else piece[1]
            with fs.open(path, default_cache="none") as f:
                md = pq.ParquetFile(f).metadata
            if row_groups is None:
                row_groups = list(range(md.num_row_groups))
            for rg in row_groups:
                row_group = md.row_group(rg)
                num_rows += row_group.num_rows
                for i in range(row_group.num_columns):
                    col = row_group.column(i)
                    name = col.path_in_schema
                    if name in columns:
                        if col.statistics and col.statistics.has_min_max:
                            if name in column_stats:
                                column_stats[name]["min"] = min(
                                    column_stats[name]["min"], col.statistics.min
                                )
                                column_stats[name]["max"] = max(
                                    column_stats[name]["max"], col.statistics.max
                                )
                            else:
                                column_stats[name] = {
                                    "min": col.statistics.min,
                                    "max": col.statistics.max,
                                }

        # Convert dict-of-dict to list-of-dict to be consistent
        # with current `dd.read_parquet` convention (for now)
        column_stats_list = [
            {
                "name": name,
                "min": column_stats[name]["min"],
                "max": column_stats[name]["max"],
            }
            for name in column_stats.keys()
        ]
        return {"num-rows": num_rows, "columns": column_stats_list}

    # Helper function used by _extract_statistics
    return [_read_partition_stats(part, fs, columns=columns) for part in parts]


def _normalize_and_strip_protocol(path):
    if not isinstance(path, (list, tuple)):
        path = [path]

    result = []
    for p in path:
        protocol_separators = ["://", "::"]
        for sep in protocol_separators:
            split = p.split(sep, 1)
            if len(split) > 1:
                p = split[1]
                break
        result.append(p.rstrip("/"))
    return result


def _divisions_from_statistics(aggregated_stats, index_name):
    col_ix = -1
    peak_rg = aggregated_stats[0]
    for ix, col in enumerate(peak_rg["columns"]):
        if col["path_in_schema"] == index_name:
            col_ix = ix
            break
    else:
        raise ValueError(
            f"Index column {index_name} not found in statistics"  # noqa: E713
        )
    last_max = None
    minmax = []
    for file_stats in aggregated_stats:
        file_min = file_stats["columns"][col_ix]["statistics"]["min"]
        file_max = file_stats["columns"][col_ix]["statistics"]["max"]
        minmax.append((file_min, file_max))
    divisions = []
    minmax = pd.Series(minmax)
    argsort = minmax.argsort()
    sorted_minmax = minmax[argsort]
    if not sorted_minmax.is_monotonic_increasing:
        return tuple([None] * (len(aggregated_stats) + 1)), None
    for file_min, file_max in sorted_minmax:
        divisions.append(file_min)
        last_max = file_max
    divisions.append(last_max)
    return tuple(divisions), argsort


def _extract_stats(original):
    """Take the raw file statistics as returned by pyarrow (as a dict) and
    filter it to what we care about.

    The full stats are a bit too verbose and we don't need all of it."""
    # TODO: dicts are pretty memory inefficient. Move to dataclass?
    file_level_stats = ["num_rows", "num_row_groups", "serialized_size"]
    rg_stats = [
        "num_rows",
        "total_byte_size",
        "sorting_columns",
    ]
    col_meta = [
        "num_values",
        "total_compressed_size",
        "total_uncompressed_size",
        "path_in_schema",
    ]
    col_stats = [
        "min",
        "max",
        "null_count",
        "num_values",
        "distinct_count",
    ]
    out = {}
    for name in file_level_stats:
        out[name] = original[name]
    out["row_groups"] = rgs = []
    for rg in original["row_groups"]:
        rg_out = {}
        rgs.append(rg_out)
        for name in rg_stats:
            rg_out[name] = rg[name]
        rg_out["columns"] = []
        for col in rg["columns"]:
            col_out = {}
            rg_out["columns"].append(col_out)
            for name in col_meta:
                col_out[name] = col[name]
            col_out["statistics"] = {}
            for name in col_stats:
                col_out["statistics"][name] = col["statistics"][name]
    return out


def _agg_dicts(dicts, agg_funcs):
    result = {}
    for d in dicts:
        for k, v in d.items():
            if k not in result:
                result[k] = [v]
            else:
                result[k].append(v)
    result2 = {}
    for k, v in result.items():
        agg = agg_funcs.get(k)
        if agg:
            result2[k] = agg(v)
    return result2


def _aggregate_columns(cols, agg_cols):
    combine = []
    i = 0
    while True:
        inner = []
        combine.append(inner)
        try:
            for col in cols:
                inner.append(col[i])
        except IndexError:
            combine.pop()
            break
        i += 1

    return [_agg_dicts(c, agg_cols) for c in combine]


def _aggregate_statistics_to_file(stats):
    """Aggregate RG information to file level."""
    agg_stats = {
        "min": min,
        "max": max,
    }
    agg_cols = {
        "total_compressed_size": sum,
        "total_uncompressed_size": sum,
        "statistics": partial(_agg_dicts, agg_funcs=agg_stats),
        "path_in_schema": lambda x: set(x).pop(),
    }
    agg_func = {
        "num_rows": sum,
        "total_byte_size": sum,
        "columns": partial(_aggregate_columns, agg_cols=agg_cols),
    }
    aggregated_stats = []
    for file_stat in stats:
        file_stat = file_stat.copy()
        aggregated_stats.append(file_stat)
        file_stat.update(_agg_dicts(file_stat.pop("row_groups"), agg_func))
    return aggregated_stats


@dask.delayed
def _gather_statistics(frags):
    @dask.delayed
    def _collect_statistics(token_fragment):
        return token_fragment[0], _extract_stats(token_fragment[1].metadata.to_dict())

    return dask.compute(
        list(_collect_statistics(frag) for frag in frags), scheduler="threading"
    )[0]


def _collect_statistics_plan(file_infos, fragments):
    """Collect statistics for a list of files and their corresponding fragments"""
    to_collect = []
    for finfo, frag in zip(file_infos, fragments):
        if (token := tokenize(finfo)) not in _STATS_CACHE:
            to_collect.append((token, frag))

    return [
        _gather_statistics(batch)
        for batch in toolz.itertoolz.partition_all(20, to_collect)
    ]


def _combine_stats(stats):
    """Combine multiple file-level statistics into a single dict of metrics
    that represent the average values of the parquet statistics"""
    agg_cols = {
        "total_compressed_size": statistics.mean,
        "total_uncompressed_size": statistics.mean,
        "path_in_schema": lambda x: set(x).pop(),
    }
    return _agg_dicts(
        _aggregate_statistics_to_file(stats),
        {
            "num_rows": statistics.mean,
            "num_row_groups": statistics.mean,
            "serialized_size": statistics.mean,
            "total_byte_size": statistics.mean,
            "columns": partial(_aggregate_columns, agg_cols=agg_cols),
        },
    )
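

# Hedged sketch (not part of the upstream module): the behavior of the
# ``_agg_dicts`` helper that the aggregation code above relies on. Values are
# grouped per key and reduced with the matching aggregation function; keys
# without an aggregation function are dropped from the result.
def _example_agg_dicts():
    dicts = [{"num_rows": 10, "skipped": 1}, {"num_rows": 30, "skipped": 2}]
    return _agg_dicts(dicts, {"num_rows": sum})  # {"num_rows": 40}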