From 551287ae051ca50554505654f1d6942b7cb592f5 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Mon, 4 Mar 2024 10:49:48 +0100 Subject: [PATCH] remove circularity and implicit module members (#1396) --- anndata/__init__.py | 4 ++ anndata/_core/aligned_df.py | 104 ++++++++++++++++++++++++++++++++++++ anndata/_core/anndata.py | 99 ++-------------------------------- anndata/_core/raw.py | 3 +- anndata/_io/__init__.py | 3 ++ anndata/tests/test_raw.py | 2 +- 6 files changed, 117 insertions(+), 98 deletions(-) create mode 100644 anndata/_core/aligned_df.py diff --git a/anndata/__init__.py b/anndata/__init__.py index 6cae971ac..97c48dacd 100644 --- a/anndata/__init__.py +++ b/anndata/__init__.py @@ -1,4 +1,5 @@ """Annotated multivariate observation data.""" + from __future__ import annotations try: # See https://github.com/maresb/hatch-vcs-footgun-example @@ -45,6 +46,9 @@ # Experimental needs to be imported last from . import experimental # isort: skip +# We use these in tests by attribute access +from . import _io, logging # noqa: F401 isort: skip + def read(*args, **kwargs): import warnings diff --git a/anndata/_core/aligned_df.py b/anndata/_core/aligned_df.py new file mode 100644 index 000000000..7e6d57ca2 --- /dev/null +++ b/anndata/_core/aligned_df.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import warnings +from functools import singledispatch +from typing import TYPE_CHECKING, Any, Literal + +import pandas as pd +from pandas.api.types import is_string_dtype + +from .._warnings import ImplicitModificationWarning + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + +@singledispatch +def _gen_dataframe( + anno: Mapping[str, Any], + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +) -> pd.DataFrame: + if anno is None or len(anno) == 0: + anno = {} + + def mk_index(l: int) -> pd.Index: + return pd.RangeIndex(0, l, name=None).astype(str) + + for index_name in index_names: + if index_name not in anno: + continue + df = pd.DataFrame( + anno, + index=anno[index_name], + columns=[k for k in anno.keys() if k != index_name], + ) + break + else: + df = pd.DataFrame( + anno, + index=None if length is None else mk_index(length), + columns=None if len(anno) else [], + ) + + if length is None: + df.index = mk_index(len(df)) + elif length != len(df): + raise _mk_df_error(source, attr, length, len(df)) + return df + + +@_gen_dataframe.register(pd.DataFrame) +def _gen_dataframe_df( + anno: pd.DataFrame, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +): + if length is not None and length != len(anno): + raise _mk_df_error(source, attr, length, len(anno)) + anno = anno.copy(deep=False) + if not is_string_dtype(anno.index): + warnings.warn("Transforming to str index.", ImplicitModificationWarning) + anno.index = anno.index.astype(str) + if not len(anno.columns): + anno.columns = anno.columns.astype(str) + return anno + + +@_gen_dataframe.register(pd.Series) +@_gen_dataframe.register(pd.Index) +def _gen_dataframe_1d( + anno: pd.Series | pd.Index, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +): + raise ValueError(f"Cannot convert {type(anno)} to {attr} DataFrame") + + +def _mk_df_error( + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + expected: int, + actual: int, +): + if source == "X": + what = "row" if attr == "obs" else "column" + msg = ( + f"Observations annot. `{attr}` must have as many rows as `X` has {what}s " + f"({expected}), but has {actual} rows." + ) + else: + msg = ( + f"`shape` is inconsistent with `{attr}` " + "({actual} {what}s instead of {expected})" + ) + return ValueError(msg) diff --git a/anndata/_core/anndata.py b/anndata/_core/anndata.py index 7d4f2e573..bcb5339d3 100644 --- a/anndata/_core/anndata.py +++ b/anndata/_core/anndata.py @@ -1,6 +1,7 @@ """\ Main class and helper functions. """ + from __future__ import annotations import collections.abc as cabc @@ -9,7 +10,7 @@ from collections.abc import Iterable, Mapping, MutableMapping, Sequence from copy import copy, deepcopy from enum import Enum -from functools import partial, singledispatch +from functools import partial from pathlib import Path from textwrap import dedent from typing import ( # Meta # Generic ABCs # Generic @@ -23,12 +24,10 @@ import pandas as pd from natsort import natsorted from numpy import ma -from pandas.api.types import infer_dtype, is_string_dtype +from pandas.api.types import infer_dtype from scipy import sparse from scipy.sparse import issparse -from anndata._warnings import ImplicitModificationWarning - from .. import utils from .._settings import settings from ..compat import ( @@ -42,6 +41,7 @@ from ..logging import anndata_logger as logger from ..utils import convert_to_dict, deprecated, dim_len, ensure_df_homogeneous from .access import ElementRef +from .aligned_df import _gen_dataframe from .aligned_mapping import ( AxisArrays, AxisArraysView, @@ -110,97 +110,6 @@ def _check_2d_shape(X): ) -def _mk_df_error( - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - expected: int, - actual: int, -): - if source == "X": - what = "row" if attr == "obs" else "column" - msg = ( - f"Observations annot. `{attr}` must have as many rows as `X` has {what}s " - f"({expected}), but has {actual} rows." - ) - else: - msg = ( - f"`shape` is inconsistent with `{attr}` " - "({actual} {what}s instead of {expected})" - ) - return ValueError(msg) - - -@singledispatch -def _gen_dataframe( - anno: Mapping[str, Any], - index_names: Iterable[str], - *, - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - length: int | None = None, -) -> pd.DataFrame: - if anno is None or len(anno) == 0: - anno = {} - - def mk_index(l: int) -> pd.Index: - return pd.RangeIndex(0, l, name=None).astype(str) - - for index_name in index_names: - if index_name not in anno: - continue - df = pd.DataFrame( - anno, - index=anno[index_name], - columns=[k for k in anno.keys() if k != index_name], - ) - break - else: - df = pd.DataFrame( - anno, - index=None if length is None else mk_index(length), - columns=None if len(anno) else [], - ) - - if length is None: - df.index = mk_index(len(df)) - elif length != len(df): - raise _mk_df_error(source, attr, length, len(df)) - return df - - -@_gen_dataframe.register(pd.DataFrame) -def _gen_dataframe_df( - anno: pd.DataFrame, - index_names: Iterable[str], - *, - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - length: int | None = None, -): - if length is not None and length != len(anno): - raise _mk_df_error(source, attr, length, len(anno)) - anno = anno.copy(deep=False) - if not is_string_dtype(anno.index): - warnings.warn("Transforming to str index.", ImplicitModificationWarning) - anno.index = anno.index.astype(str) - if not len(anno.columns): - anno.columns = anno.columns.astype(str) - return anno - - -@_gen_dataframe.register(pd.Series) -@_gen_dataframe.register(pd.Index) -def _gen_dataframe_1d( - anno: pd.Series | pd.Index, - index_names: Iterable[str], - *, - source: Literal["X", "shape"], - attr: Literal["obs", "var"], - length: int | None = None, -): - raise ValueError(f"Cannot convert {type(anno)} to {attr} DataFrame") - - class AnnData(metaclass=utils.DeprecationMixinMeta): """\ An annotated data matrix. diff --git a/anndata/_core/raw.py b/anndata/_core/raw.py index 2b2e27277..8d94d2be6 100644 --- a/anndata/_core/raw.py +++ b/anndata/_core/raw.py @@ -8,6 +8,7 @@ from scipy.sparse import issparse from ..compat import CupyArray, CupySparseMatrix +from .aligned_df import _gen_dataframe from .aligned_mapping import AxisArrays from .index import _normalize_index, _subset, get_vector, unpack_index from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset @@ -29,8 +30,6 @@ def __init__( var: pd.DataFrame | Mapping[str, Sequence] | None = None, varm: AxisArrays | Mapping[str, np.ndarray] | None = None, ): - from .anndata import _gen_dataframe - self._adata = adata self._n_obs = adata.n_obs # construct manually diff --git a/anndata/_io/__init__.py b/anndata/_io/__init__.py index 7bb6ba506..9315d3369 100644 --- a/anndata/_io/__init__.py +++ b/anndata/_io/__init__.py @@ -20,6 +20,9 @@ def write_zarr(*args, **kw): return write_zarr(*args, **kw) +# We use this in test by attribute access +from . import specs # noqa: F401, E402 + __all__ = [ "read_csv", "read_excel", diff --git a/anndata/tests/test_raw.py b/anndata/tests/test_raw.py index b51376b9a..5ffd14d2e 100644 --- a/anndata/tests/test_raw.py +++ b/anndata/tests/test_raw.py @@ -4,7 +4,7 @@ import pytest import anndata as ad -from anndata._core.anndata import ImplicitModificationWarning +from anndata import ImplicitModificationWarning from anndata.tests.helpers import GEN_ADATA_DASK_ARGS, assert_equal, gen_adata # -------------------------------------------------------------------------------