TIFNJK_E41221588/venv/Lib/site-packages/altair/datasets/_readimpl.py

512 lines
15 KiB
Python

"""Individual read functions and situations they support."""
from __future__ import annotations
import sys
from enum import Enum
from functools import partial, wraps
from importlib.util import find_spec
from itertools import chain
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal
from narwhals.stable import v1 as nw
from narwhals.stable.v1.dependencies import get_pandas, get_polars
from narwhals.stable.v1.typing import IntoDataFrameT
from altair.datasets._constraints import (
is_arrow,
is_csv,
is_json,
is_meta,
is_not_tabular,
is_parquet,
is_spatial,
is_topo,
is_tsv,
)
from altair.datasets._exceptions import AltairDatasetsError
if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar
if sys.version_info >= (3, 12):
from typing import TypeAliasType
else:
from typing_extensions import TypeAliasType
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Sequence
from io import IOBase
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from narwhals.stable.v1 import typing as nwt
from altair.datasets._constraints import Items, MetaIs
# Names exported by a star-import of this module.
__all__ = ["is_available", "pa_any", "pd_only", "pd_pyarrow", "pl_only", "read", "scan"]
# Return type of the function wrapped by ``BaseImpl`` (see ``BaseImpl.fn``).
R = TypeVar("R", bound="nwt.IntoFrame", covariant=True)
# Frame type used to parameterize ``Scan``; defaults to ``nw.LazyFrame[Any]``.
IntoFrameT = TypeVar(
"IntoFrameT",
bound="nwt.NativeFrame | nw.DataFrame[Any] | nw.LazyFrame[Any] | nwt.DataFrameLike",
default=nw.LazyFrame[Any],
)
# Both aliases resolve to ``BaseImpl``; they differ only in the type parameter.
Read = TypeAliasType("Read", "BaseImpl[IntoDataFrameT]", type_params=(IntoDataFrameT,))
"""An *eager* file read function."""
Scan = TypeAliasType("Scan", "BaseImpl[IntoFrameT]", type_params=(IntoFrameT,))
"""A *lazy* file read function."""
class Skip(Enum):
    """
    Falsy sentinel.

    The single member ``Skip.skip`` always evaluates to ``False`` in a boolean
    context and has a compact, fixed ``repr``.
    """

    skip = 0

    def __repr__(self) -> Literal["<Skip>"]:
        # Fixed text, independent of the member name/value.
        return "<Skip>"

    def __bool__(self) -> Literal[False]:
        # The sentinel is never truthy.
        return False
class BaseImpl(Generic[R]):
    """
    A function wrapped with dataset support constraints.

    The ``include``, ``exclude`` properties form a `NIMPLY gate`_ (`Material nonimplication`_).

    Instances are immutable: attributes are written once in ``__init__`` and
    any later assignment raises ``TypeError``.

    Examples
    --------
    For some dataset ``D``, we can use ``fn`` if::

        impl: BaseImpl
        impl.include(D) and not impl.exclude(D)

    .. _NIMPLY gate:
        https://en.m.wikipedia.org/wiki/NIMPLY_gate
    .. _Material nonimplication:
        https://en.m.wikipedia.org/wiki/Material_nonimplication#Truth_table
    """

    fn: Callable[..., R]
    """Wrapped read/scan function."""
    include: MetaIs
    """Constraint indicating ``fn`` **supports** reading a dataset."""
    exclude: MetaIs
    """Constraint *subsetting* ``include`` to mark **non-support**."""

    def __init__(
        self,
        fn: Callable[..., R],
        include: MetaIs,
        exclude: MetaIs | None,
        kwds: dict[str, Any],
        /,
    ) -> None:
        """
        Store ``fn`` (with ``kwds`` pre-bound, if any) and its constraints.

        Raises
        ------
        TypeError
            If ``include`` and ``exclude`` share any item.
        """
        if not exclude:
            exclude = self._exclude_none()
        if not include.isdisjoint(exclude):
            overlap = include & exclude
            intersection = ", ".join(f"{k}={v!r}" for k, v in overlap)
            msg = (
                f"Constraints overlap at: `{intersection}`\n"
                f"include={include!r}\n"
                f"exclude={exclude!r}"
            )
            raise TypeError(msg)
        wrapped = partial(fn, **kwds) if kwds else fn
        # ``__setattr__`` is blocked below, so write through ``object`` directly.
        for attr, value in (("fn", wrapped), ("include", include), ("exclude", exclude)):
            object.__setattr__(self, attr, value)

    def unwrap_or_skip(
        self, meta: Items, /
    ) -> Callable[..., R] | type[AltairDatasetsError] | Skip:
        """
        Indicate an action to take for a dataset.

        **Supports** dataset, use this function::

            Callable[..., R]

        Has explicitly marked as **not supported**::

            type[AltairDatasetsError]

        No relevant constraints overlap, safe to check others::

            Skip
        """
        if not self.include.issubset(meta):
            return Skip.skip
        if self.exclude.isdisjoint(meta):
            return self.fn
        return AltairDatasetsError

    @classmethod
    def _exclude_none(cls) -> MetaIs:
        """Represents the empty set."""
        return is_meta()

    def __setattr__(self, name: str, value: Any):
        """Reject every assignment; instances are immutable."""
        cls_name = type(self).__name__
        msg = f"{cls_name!r} is immutable.\nCould not assign self.{name} = {value}"
        raise TypeError(msg)

    @property
    def _inferred_package(self) -> str:
        """Top-level package name of the wrapped function, or ``"UNKNOWN"``."""
        target = _unwrap_partial(self.fn)
        return _root_package_name(target, "UNKNOWN")

    def __repr__(self) -> str:
        return f"{type(self).__name__}[{self._inferred_package}?]({self})"

    def __str__(self) -> str:
        target = self.fn
        if isinstance(target, partial):
            # Show the innermost function together with the pre-bound keywords.
            inner = _unwrap_partial(target)
            bound = ", ".join(f"{k}={v!r}" for k, v in target.keywords.items())
            fn_repr = f"{inner.__name__}(..., {bound})"
        else:
            fn_repr = f"{target.__name__}(...)"
        if self.exclude:
            constraints = f"include={self.include!r}, exclude={self.exclude!r}"
        else:
            constraints = repr(self.include)
        return f"{fn_repr}, {constraints}"

    @property
    def _relevant_columns(self) -> Iterator[str]:
        """Yield the first element of every ``include``, then ``exclude`` item."""
        yield from map(itemgetter(0), chain(self.include, self.exclude))

    @property
    def _include_expr(self) -> nw.Expr:
        """Expression for ``include``, minus ``exclude`` when one is set."""
        if not self.exclude:
            return self.include.to_expr()
        return self.include.to_expr() & ~self.exclude.to_expr()

    @property
    def _exclude_expr(self) -> nw.Expr:
        """
        Expression for items matching both ``include`` and ``exclude``.

        Raises
        ------
        TypeError
            If ``exclude`` is empty.
        """
        if not self.exclude:
            msg = f"Unable to generate an exclude expression without setting exclude\n\n{self!r}"
            raise TypeError(msg)
        return self.include.to_expr() & self.exclude.to_expr()
def read(
    fn: Callable[..., IntoDataFrameT],
    /,
    include: MetaIs,
    exclude: MetaIs | None = None,
    **kwds: Any,
) -> Read[IntoDataFrameT]:
    """
    Wrap an *eager* read function with its dataset constraints.

    Any extra keyword arguments are handed to ``BaseImpl`` as the ``kwds``
    to pre-bind onto ``fn``.
    """
    impl: Read[IntoDataFrameT] = BaseImpl(fn, include, exclude, kwds)
    return impl
def scan(
    fn: Callable[..., IntoFrameT],
    /,
    include: MetaIs,
    exclude: MetaIs | None = None,
    **kwds: Any,
) -> Scan[IntoFrameT]:
    """
    Wrap a *lazy* read function with its dataset constraints.

    Any extra keyword arguments are handed to ``BaseImpl`` as the ``kwds``
    to pre-bind onto ``fn``.
    """
    impl: Scan[IntoFrameT] = BaseImpl(fn, include, exclude, kwds)
    return impl
def into_scan(impl: Read[IntoDataFrameT], /) -> Scan[nw.LazyFrame[IntoDataFrameT]]:
    """
    Derive a lazy implementation from an eager one.

    The eager function is called as-is; its result is passed through
    ``nw.from_native(...).lazy()``. Constraints are carried over unchanged.
    """
    eager_fn = impl.fn

    # Copy metadata from the innermost function, not from a ``partial`` layer.
    @wraps(_unwrap_partial(eager_fn))
    def wrapper(*args: Any, **kwds: Any) -> nw.LazyFrame[IntoDataFrameT]:
        native = eager_fn(*args, **kwds)
        return nw.from_native(native).lazy()

    return scan(wrapper, impl.include, impl.exclude)
def _is_importable(name: str, /) -> bool:
    """Return whether ``name`` can be located by the import system, never raising."""
    try:
        return find_spec(name) is not None
    except ImportError:
        # Raised for a dotted name whose parent package cannot be imported.
        return False
    except ValueError:
        # ``find_spec`` raises this when ``sys.modules[name].__spec__`` is
        # ``None`` or unset; such a module is already loaded.
        return name in sys.modules


def is_available(
    pkg_names: str | Iterable[str], *more_pkg_names: str, require_all: bool = True
) -> bool:
    """
    Check for importable package(s), without raising on failure.

    Dotted names (``"pkg.sub"``) whose parent package is missing are reported
    as unavailable instead of propagating ``ModuleNotFoundError``.

    Parameters
    ----------
    pkg_names, more_pkg_names
        One or more packages.
    require_all
        * ``True`` every package.
        * ``False`` at least one package.
    """
    # A lone string is one package name, not an iterable of characters.
    first = (pkg_names,) if isinstance(pkg_names, str) else pkg_names
    names = chain(first, more_pkg_names)
    combine = all if require_all else any
    return combine(_is_importable(name) for name in names)
def _root_package_name(obj: Any, default: str, /) -> str:
    """
    Return the top-level package name that ``obj`` belongs to.

    Parameters
    ----------
    obj
        Any object; typically a function.
    default
        Returned when no package can be determined.
    """
    # NOTE: Defers importing `inspect`, if we can get the module name
    module_name = getattr(obj, "__module__", None)
    # ``__module__`` may exist but be ``None``; only a string can be split.
    if isinstance(module_name, str):
        return module_name.split(".")[0]
    from inspect import getmodule

    module = getmodule(obj)
    if module and (pkg := module.__package__):
        return pkg.split(".")[0]
    return default
def _unwrap_partial(fn: Any, /) -> Any:
    """Peel off every ``functools.partial`` layer and return what is underneath."""
    # NOTE: ``functools._unwrap_partial``
    inner = fn
    while True:
        if not isinstance(inner, partial):
            return inner
        inner = inner.func
def pl_only() -> tuple[Sequence[Read[pl.DataFrame]], Sequence[Scan[pl.LazyFrame]]]:
    """
    Build the ``polars`` read and scan implementations.

    Returns a pair ``(read_fns, scan_fns)``. When ``polars_st`` is importable,
    its TopoJSON and spatial JSON readers are listed ahead of the generic
    JSON reader.
    """
    import polars as pl

    generic_json = read(_pl_read_json_roundtrip(get_polars()), is_json)
    json_readers: Sequence[Read[pl.DataFrame]] = (generic_json,)
    if is_available("polars_st"):
        json_readers = (
            _pl_read_json_polars_st_topo_impl(),  # TopoJSON files first
            _pl_read_json_polars_st_impl(),  # Then other spatial JSON
            generic_json,
        )

    csv_reader = read(pl.read_csv, is_csv, try_parse_dates=True)
    tsv_reader = read(pl.read_csv, is_tsv, separator="\t", try_parse_dates=True)
    arrow_reader = read(pl.read_ipc, is_arrow)
    parquet_reader = read(pl.read_parquet, is_parquet)

    eager = (csv_reader, *json_readers, tsv_reader, arrow_reader, parquet_reader)
    lazy = (scan(pl.scan_parquet, is_parquet),)
    return eager, lazy
def pd_only() -> Sequence[Read[pd.DataFrame]]:
    """
    Build the ``pandas`` read implementations.

    Arrow/parquet readers are added only when ``pyarrow`` (both) or
    ``fastparquet`` (parquet only) is importable. When ``geopandas`` is
    importable, its reader is listed ahead of the plain JSON reader.
    """
    import pandas as pd

    binary_readers: Sequence[Read[pd.DataFrame]] = ()
    if is_available("pyarrow"):
        binary_readers = (
            read(pd.read_feather, is_arrow),
            read(pd.read_parquet, is_parquet),
        )
    elif is_available("fastparquet"):
        binary_readers = (read(pd.read_parquet, is_parquet),)

    plain_json = read(_pd_read_json(get_pandas()), is_json, exclude=is_spatial)
    json_readers: Sequence[Read[pd.DataFrame]] = (plain_json,)
    if is_available("geopandas"):
        json_readers = (_pd_read_json_geopandas_impl(), plain_json)

    csv_reader = read(pd.read_csv, is_csv)
    tsv_reader = read(pd.read_csv, is_tsv, sep="\t")
    return (csv_reader, *json_readers, tsv_reader, *binary_readers)
def pd_pyarrow() -> Sequence[Read[pd.DataFrame]]:
    """
    Build the ``pandas`` read implementations using ``dtype_backend="pyarrow"``.

    Every reader except the ``geopandas`` one has the backend keyword
    pre-bound. When ``geopandas`` is importable, its reader is listed ahead
    of the plain JSON reader.
    """
    import pandas as pd

    backend: dict[str, Any] = {"dtype_backend": "pyarrow"}

    plain_json = read(
        _pd_read_json(get_pandas()), is_json, exclude=is_spatial, **backend
    )
    json_readers: Sequence[Read[pd.DataFrame]] = (plain_json,)
    if is_available("geopandas"):
        json_readers = (_pd_read_json_geopandas_impl(), plain_json)

    csv_reader = read(pd.read_csv, is_csv, **backend)
    tsv_reader = read(pd.read_csv, is_tsv, sep="\t", **backend)
    arrow_reader = read(pd.read_feather, is_arrow, **backend)
    parquet_reader = read(pd.read_parquet, is_parquet, **backend)
    return (csv_reader, *json_readers, tsv_reader, arrow_reader, parquet_reader)
def pa_any() -> Sequence[Read[pa.Table]]:
    """Build the ``pyarrow`` read implementations (csv, json, tsv, arrow, parquet)."""
    from pyarrow import csv, feather, parquet

    tab_delimited = csv.ParseOptions(delimiter="\t")
    csv_reader = read(csv.read_csv, is_csv)
    json_reader = _pa_read_json_impl()
    tsv_reader = read(csv.read_csv, is_tsv, parse_options=tab_delimited)  # pyright: ignore[reportCallIssue]
    arrow_reader = read(feather.read_table, is_arrow)
    parquet_reader = read(parquet.read_table, is_parquet)
    return (csv_reader, json_reader, tsv_reader, arrow_reader, parquet_reader)
def _pa_read_json_impl() -> Read[pa.Table]:
    """
    Mitigating ``pyarrow``'s `line-delimited`_ JSON requirement.

    Picks the first usable option, in order: ``polars``, ``pandas``, then a
    standard-library fallback.

    .. _line-delimited:
        https://arrow.apache.org/docs/python/json.html#reading-json-files
    """
    if is_available("polars") and (polars_ns := get_polars()) is not None:
        return read(_pl_read_json_roundtrip_to_arrow(polars_ns), is_json)
    if is_available("pandas") and (pandas_ns := get_pandas()) is not None:
        return read(_pd_read_json_to_arrow(pandas_ns), is_json, exclude=is_spatial)
    return read(_stdlib_read_json_to_arrow, is_json, exclude=is_not_tabular)
def _pd_read_json(ns: ModuleType, /) -> Callable[..., pd.DataFrame]:
    """
    Wrap ``ns.read_json`` so its result is passed through ``_pd_fix_dtypes_nw``.

    The same keyword arguments are forwarded to both calls.
    """

    @wraps(ns.read_json)
    def fn(source: Path | Any, /, **kwds: Any) -> pd.DataFrame:
        raw = ns.read_json(source, **kwds)
        fixed = _pd_fix_dtypes_nw(raw, **kwds)
        return fixed.to_native()

    return fn
def _pd_read_json_geopandas_impl() -> Read[pd.DataFrame]:
    """Return a ``geopandas.read_file`` implementation for spatial ``.json`` datasets."""
    import geopandas

    @wraps(geopandas.read_file)
    def fn(source: Path | Any, /, schema: Any = None, **kwds: Any) -> pd.DataFrame:
        # ``schema`` is accepted but not forwarded to ``geopandas``.
        return geopandas.read_file(source, **kwds)

    spatial_json = is_meta(is_spatial=True, suffix=".json")
    return read(fn, spatial_json)
def _pd_fix_dtypes_nw(
    df: pd.DataFrame, /, *, dtype_backend: Any = None, **kwds: Any
) -> nw.DataFrame[pd.DataFrame]:
    """
    Convert dtypes and cast ``Object`` columns to ``String``.

    Only ``dtype_backend`` (when truthy) is forwarded to ``convert_dtypes``;
    any other keyword arguments are accepted and ignored.
    """
    convert_kwds = {"dtype_backend": dtype_backend} if dtype_backend else {}
    converted = df.convert_dtypes(**convert_kwds)
    frame = nw.from_native(converted, eager_only=True)
    object_columns = nw.selectors.by_dtype(nw.Object)
    return frame.with_columns(object_columns.cast(nw.String))
def _pd_read_json_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]:
    """Wrap ``ns.read_json`` so it yields an arrow table via ``_pd_fix_dtypes_nw``."""

    @wraps(ns.read_json)
    def fn(source: Path | Any, /, *, schema: Any = None, **kwds: Any) -> pa.Table:
        """``schema`` is only here to swallow the ``SchemaCache`` if used."""
        frame = ns.read_json(source, **kwds)
        fixed = frame.pipe(_pd_fix_dtypes_nw, dtype_backend="pyarrow")
        return fixed.to_arrow()

    return fn
def _pl_read_json_polars_st_impl() -> Read[pl.DataFrame]:
    """Return a ``polars_st.read_file`` implementation for spatial ``.json`` datasets."""
    import polars_st as st

    @wraps(st.read_file)
    def fn(source: Path | Any, /, schema: Any = None, **kwds: Any) -> pl.DataFrame:
        # ``schema`` is accepted but not forwarded to ``polars_st``.
        return st.read_file(source, **kwds)

    spatial_json = is_meta(is_spatial=True, suffix=".json")
    return read(fn, spatial_json)
def _pl_read_json_polars_st_topo_impl() -> Read[pl.DataFrame]:
    """Return a ``polars_st.read_file`` implementation constrained by ``is_topo``."""
    import polars_st as st

    @wraps(st.read_file)
    def fn(source: Path | Any, /, schema: Any = None, **kwds: Any) -> pl.DataFrame:
        # Add TopoJSON driver prefix for URLs
        is_url = isinstance(source, str) and source.startswith("http")
        target = f"TopoJSON:{source}" if is_url else source
        # ``schema`` is accepted but not forwarded to ``polars_st``.
        return st.read_file(target, **kwds)

    return read(fn, is_topo)
def _pl_read_json_roundtrip(ns: ModuleType, /) -> Callable[..., pl.DataFrame]:
    """
    Try to utilize better date parsing available in `pl.read_csv`_.

    `pl.read_json`_ has few options when compared to `pl.read_csv`_.

    Chaining the two together - *where possible* - is still usually faster than `pandas.read_json`_.

    Frames containing any nested dtype are returned straight from
    ``read_json``, without the CSV round trip.

    .. _pl.read_json:
        https://docs.pola.rs/api/python/stable/reference/api/polars.read_json.html
    .. _pl.read_csv:
        https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html
    .. _pandas.read_json:
        https://pandas.pydata.org/docs/reference/api/pandas.read_json.html
    """
    from io import BytesIO

    # Only these keyword arguments are forwarded from ``read_json`` to ``read_csv``.
    shared_kwds = frozenset(("schema", "schema_overrides", "infer_schema_length"))

    @wraps(ns.read_json)
    def fn(source: Path | IOBase, /, **kwds: Any) -> pl.DataFrame:
        df = ns.read_json(source, **kwds)
        for tp in df.schema.dtypes():
            if tp.is_nested():
                return df
        buf = BytesIO()
        df.write_csv(buf)
        csv_kwds = {k: v for k, v in kwds.items() if k in shared_kwds}
        return ns.read_csv(buf, try_parse_dates=True, **csv_kwds)

    return fn
def _pl_read_json_roundtrip_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]:
    """Like ``_pl_read_json_roundtrip``, but converting the result with ``to_arrow()``."""
    eager = _pl_read_json_roundtrip(ns)

    @wraps(ns.read_json)
    def fn(source: Path | IOBase, /, **kwds: Any) -> pa.Table:
        # Keyword arguments are accepted but not forwarded to the reader.
        frame = eager(source)
        return frame.to_arrow()

    return fn
def _stdlib_read_json(source: Path | Any, /) -> Any:
    """
    Parse JSON from a filesystem path or an open file-like object.

    Parameters
    ----------
    source
        A ``Path``, ``str`` or other ``os.PathLike`` naming a UTF-8 encoded
        file; anything else is handed to ``json.load`` as a readable object.
    """
    import json
    import os

    # Path-like inputs (not only ``Path`` instances) are opened here; a ``str``
    # path would otherwise fail inside ``json.load`` for lacking ``.read``.
    if isinstance(source, (str, os.PathLike)):
        with Path(source).open(encoding="utf-8") as f:
            return json.load(f)
    return json.load(source)
def _stdlib_read_json_to_arrow(source: Path | Any, /, **kwds: Any) -> pa.Table:
    """
    Read JSON with the standard library and build a ``pyarrow`` table.

    ``pa.Table.from_pylist`` is tried first. If it raises ``TypeError``, the
    rows are written out as CSV text (header taken from the first row's keys)
    and parsed with ``pyarrow.csv.read_csv`` instead; ``kwds`` are not used on
    that fallback path.
    """
    import pyarrow as pa

    rows: list[dict[str, Any]] = _stdlib_read_json(source)
    try:
        return pa.Table.from_pylist(rows, **kwds)
    except TypeError:
        import csv
        import io

        from pyarrow import csv as pa_csv

        with io.StringIO() as text_buf:
            header = rows[0].keys()
            writer = csv.DictWriter(text_buf, header, dialect=csv.unix_dialect)
            writer.writeheader()
            writer.writerows(rows)
            encoded = text_buf.getvalue().encode()
        with io.BytesIO(encoded) as bytes_buf:
            return pa_csv.read_csv(bytes_buf)