Skip to content

Commit

Permalink
remove circularity and implicit module members (#1396)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Mar 4, 2024
1 parent f0d3a6e commit 551287a
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 98 deletions.
4 changes: 4 additions & 0 deletions anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Annotated multivariate observation data."""

from __future__ import annotations

try: # See https://github.com/maresb/hatch-vcs-footgun-example
Expand Down Expand Up @@ -45,6 +46,9 @@
# Experimental needs to be imported last
from . import experimental # isort: skip

# We use these in tests by attribute access
from . import _io, logging # noqa: F401 isort: skip


def read(*args, **kwargs):
import warnings
Expand Down
104 changes: 104 additions & 0 deletions anndata/_core/aligned_df.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import warnings
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Literal

import pandas as pd
from pandas.api.types import is_string_dtype

from .._warnings import ImplicitModificationWarning

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping


@singledispatch
def _gen_dataframe(
anno: Mapping[str, Any],
index_names: Iterable[str],
*,
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
length: int | None = None,
) -> pd.DataFrame:
if anno is None or len(anno) == 0:
anno = {}

def mk_index(l: int) -> pd.Index:
return pd.RangeIndex(0, l, name=None).astype(str)

for index_name in index_names:
if index_name not in anno:
continue
df = pd.DataFrame(
anno,
index=anno[index_name],
columns=[k for k in anno.keys() if k != index_name],
)
break
else:
df = pd.DataFrame(
anno,
index=None if length is None else mk_index(length),
columns=None if len(anno) else [],
)

if length is None:
df.index = mk_index(len(df))
elif length != len(df):
raise _mk_df_error(source, attr, length, len(df))
return df


@_gen_dataframe.register(pd.DataFrame)
def _gen_dataframe_df(
    anno: pd.DataFrame,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Handle input that already is a :class:`pandas.DataFrame`.

    Works on a shallow copy of ``anno``. A non-string index is converted to
    strings (with an ``ImplicitModificationWarning``), and an empty column
    index is cast to ``str`` as well. ``index_names`` is not used here.

    Raises the exception built by ``_mk_df_error`` when ``length`` is given
    and differs from the number of rows in ``anno``.
    """
    n_rows = len(anno)
    if length is not None and n_rows != length:
        raise _mk_df_error(source, attr, length, n_rows)

    # Shallow copy: index/columns are reassigned below without touching
    # the attributes of the caller's object.
    result = anno.copy(deep=False)

    index_is_str = is_string_dtype(result.index)
    if not index_is_str:
        warnings.warn("Transforming to str index.", ImplicitModificationWarning)
        result.index = result.index.astype(str)

    if len(result.columns) == 0:
        result.columns = result.columns.astype(str)
    return result


@_gen_dataframe.register(pd.Series)
@_gen_dataframe.register(pd.Index)
def _gen_dataframe_1d(
    anno: pd.Series | pd.Index,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Reject one-dimensional pandas objects.

    Registered for :class:`pandas.Series` and :class:`pandas.Index`; always
    raises :class:`ValueError` naming the offending type and the target
    attribute. All other arguments are ignored.
    """
    anno_type = type(anno)
    message = f"Cannot convert {anno_type} to {attr} DataFrame"
    raise ValueError(message)


def _mk_df_error(
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
expected: int,
actual: int,
):
if source == "X":
what = "row" if attr == "obs" else "column"
msg = (
f"Observations annot. `{attr}` must have as many rows as `X` has {what}s "
f"({expected}), but has {actual} rows."
)
else:
msg = (
f"`shape` is inconsistent with `{attr}` "
"({actual} {what}s instead of {expected})"
)
return ValueError(msg)
99 changes: 4 additions & 95 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""\
Main class and helper functions.
"""

from __future__ import annotations

import collections.abc as cabc
Expand All @@ -9,7 +10,7 @@
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from copy import copy, deepcopy
from enum import Enum
from functools import partial, singledispatch
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import ( # Meta # Generic ABCs # Generic
Expand All @@ -23,12 +24,10 @@
import pandas as pd
from natsort import natsorted
from numpy import ma
from pandas.api.types import infer_dtype, is_string_dtype
from pandas.api.types import infer_dtype
from scipy import sparse
from scipy.sparse import issparse

from anndata._warnings import ImplicitModificationWarning

from .. import utils
from .._settings import settings
from ..compat import (
Expand All @@ -42,6 +41,7 @@
from ..logging import anndata_logger as logger
from ..utils import convert_to_dict, deprecated, dim_len, ensure_df_homogeneous
from .access import ElementRef
from .aligned_df import _gen_dataframe
from .aligned_mapping import (
AxisArrays,
AxisArraysView,
Expand Down Expand Up @@ -110,97 +110,6 @@ def _check_2d_shape(X):
)


def _mk_df_error(
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
expected: int,
actual: int,
):
if source == "X":
what = "row" if attr == "obs" else "column"
msg = (
f"Observations annot. `{attr}` must have as many rows as `X` has {what}s "
f"({expected}), but has {actual} rows."
)
else:
msg = (
f"`shape` is inconsistent with `{attr}` "
"({actual} {what}s instead of {expected})"
)
return ValueError(msg)


@singledispatch
def _gen_dataframe(
anno: Mapping[str, Any],
index_names: Iterable[str],
*,
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
length: int | None = None,
) -> pd.DataFrame:
if anno is None or len(anno) == 0:
anno = {}

def mk_index(l: int) -> pd.Index:
return pd.RangeIndex(0, l, name=None).astype(str)

for index_name in index_names:
if index_name not in anno:
continue
df = pd.DataFrame(
anno,
index=anno[index_name],
columns=[k for k in anno.keys() if k != index_name],
)
break
else:
df = pd.DataFrame(
anno,
index=None if length is None else mk_index(length),
columns=None if len(anno) else [],
)

if length is None:
df.index = mk_index(len(df))
elif length != len(df):
raise _mk_df_error(source, attr, length, len(df))
return df


@_gen_dataframe.register(pd.DataFrame)
def _gen_dataframe_df(
    anno: pd.DataFrame,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Handle input that already is a :class:`pandas.DataFrame`.

    Works on a shallow copy of ``anno``. A non-string index is converted to
    strings (with an ``ImplicitModificationWarning``), and an empty column
    index is cast to ``str`` as well. ``index_names`` is not used here.

    Raises the exception built by ``_mk_df_error`` when ``length`` is given
    and differs from the number of rows in ``anno``.
    """
    n_rows = len(anno)
    if length is not None and n_rows != length:
        raise _mk_df_error(source, attr, length, n_rows)

    # Shallow copy: index/columns are reassigned below without touching
    # the attributes of the caller's object.
    result = anno.copy(deep=False)

    index_is_str = is_string_dtype(result.index)
    if not index_is_str:
        warnings.warn("Transforming to str index.", ImplicitModificationWarning)
        result.index = result.index.astype(str)

    if len(result.columns) == 0:
        result.columns = result.columns.astype(str)
    return result


@_gen_dataframe.register(pd.Series)
@_gen_dataframe.register(pd.Index)
def _gen_dataframe_1d(
    anno: pd.Series | pd.Index,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Reject one-dimensional pandas objects.

    Registered for :class:`pandas.Series` and :class:`pandas.Index`; always
    raises :class:`ValueError` naming the offending type and the target
    attribute. All other arguments are ignored.
    """
    anno_type = type(anno)
    message = f"Cannot convert {anno_type} to {attr} DataFrame"
    raise ValueError(message)


class AnnData(metaclass=utils.DeprecationMixinMeta):
"""\
An annotated data matrix.
Expand Down
3 changes: 1 addition & 2 deletions anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scipy.sparse import issparse

from ..compat import CupyArray, CupySparseMatrix
from .aligned_df import _gen_dataframe
from .aligned_mapping import AxisArrays
from .index import _normalize_index, _subset, get_vector, unpack_index
from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset
Expand All @@ -29,8 +30,6 @@ def __init__(
var: pd.DataFrame | Mapping[str, Sequence] | None = None,
varm: AxisArrays | Mapping[str, np.ndarray] | None = None,
):
from .anndata import _gen_dataframe

self._adata = adata
self._n_obs = adata.n_obs
# construct manually
Expand Down
3 changes: 3 additions & 0 deletions anndata/_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def write_zarr(*args, **kw):
return write_zarr(*args, **kw)


# We use this in tests by attribute access
from . import specs # noqa: F401, E402

__all__ = [
"read_csv",
"read_excel",
Expand Down
2 changes: 1 addition & 1 deletion anndata/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import anndata as ad
from anndata._core.anndata import ImplicitModificationWarning
from anndata import ImplicitModificationWarning
from anndata.tests.helpers import GEN_ADATA_DASK_ARGS, assert_equal, gen_adata

# -------------------------------------------------------------------------------
Expand Down

0 comments on commit 551287a

Please sign in to comment.