Skip to content

Commit efa2986

Browse files
authored
Refactor test utils to use a class (#30)
1 parent 042b108 commit efa2986

File tree

13 files changed

+522
-296
lines changed

13 files changed

+522
-296
lines changed

docs/conf.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,20 @@
3232
"sphinx.ext.autodoc",
3333
"sphinx.ext.autosummary",
3434
"scanpydoc.elegant_typehints",
35+
"sphinx_autofixture",
3536
]
3637

3738
# API documentation when building
3839
nitpicky = True
3940
autosummary_generate = True
4041
autodoc_member_order = "bysource"
42+
autodoc_default_options = {
43+
"special-members": True,
44+
# everything except __call__ really, to avoid having to write autosummary templates
45+
"exclude-members": (
46+
"__setattr__,__delattr__,__repr__,__eq__,__or__,__ror__,__hash__,__weakref__,__init__,__new__"
47+
),
48+
}
4149
napoleon_google_docstring = False
4250
napoleon_numpy_docstring = True
4351
todo_include_todos = False
@@ -55,9 +63,11 @@
5563
"np.dtype": "numpy.dtype",
5664
"np.number": "numpy.number",
5765
"np.integer": "numpy.integer",
66+
"np.random.Generator": "numpy.random.Generator",
5867
"ArrayLike": "numpy.typing.ArrayLike",
5968
"DTypeLike": "numpy.typing.DTypeLike",
6069
"NDArray": "numpy.typing.NDArray",
70+
"_pytest.fixtures.FixtureRequest": "pytest.FixtureRequest",
6171
**{
6272
k: v
6373
for k_plain, v in {
@@ -74,10 +84,17 @@
7484
# If that doesn’t work, ignore them
7585
nitpick_ignore = {
7686
("py:class", "fast_array_utils.types.T_co"),
87+
("py:class", "Arr"),
88+
("py:class", "testing.fast_array_utils._array_type.Arr"),
89+
("py:class", "testing.fast_array_utils._array_type.Inner"),
90+
("py:class", "_DTypeLikeFloat32"),
91+
("py:class", "_DTypeLikeFloat64"),
7792
# sphinx bugs, should be covered by `autodoc_type_aliases` above
93+
("py:class", "Array"),
7894
("py:class", "ArrayLike"),
7995
("py:class", "DTypeLike"),
8096
("py:class", "NDArray"),
97+
("py:class", "_pytest.fixtures.FixtureRequest"),
8198
}
8299

83100
# Options for HTML output

docs/index.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
``fast_array_utils``
22
====================
33

4+
.. toctree::
5+
:hidden:
6+
7+
fast-array-utils <self>
8+
testing
9+
410
.. automodule:: fast_array_utils
511
:members:
612

7-
813
``fast_array_utils.conv``
914
-------------------------
1015

docs/testing.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
``testing.fast_array_utils``
2+
============================
3+
4+
.. automodule:: testing.fast_array_utils
5+
:members:
6+
7+
``testing.fast_array_utils.pytest``
8+
-----------------------------------
9+
10+
.. automodule:: testing.fast_array_utils.pytest
11+
:members:

pyproject.toml

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,33 @@ classifiers = [
2323
]
2424
dynamic = [ "description", "version" ]
2525
dependencies = [ "numba", "numpy" ]
26-
optional-dependencies.doc = [ "furo", "scanpydoc>=0.15.2", "sphinx>=8", "sphinx-autodoc-typehints" ]
26+
optional-dependencies.doc = [
27+
"furo",
28+
"pytest",
29+
"scanpydoc>=0.15.2",
30+
"sphinx>=8",
31+
"sphinx-autodoc-typehints",
32+
"sphinx-autofixture",
33+
]
2734
optional-dependencies.full = [ "dask", "fast-array-utils[sparse]", "h5py", "zarr" ]
2835
optional-dependencies.sparse = [ "scipy>=1.8" ]
2936
optional-dependencies.test = [ "coverage[toml]", "pytest", "pytest-codspeed" ]
3037
urls.'Documentation' = "https://icb-fast-array-utils.readthedocs-hosted.com/"
3138
urls.'Issue Tracker' = "https://github.com/scverse/fast-array-utils/issues"
3239
urls.'Source Code' = "https://github.com/scverse/fast-array-utils"
3340

34-
[tool.hatch.metadata.hooks.docstring-description]
41+
entry_points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest"
3542

3643
[tool.hatch.version]
3744
source = "vcs"
3845
raw-options = { local_scheme = "no-local-version" } # be able to publish dev version
3946

47+
# TODO: support setting main package in the plugin
48+
# [tool.hatch.metadata.hooks.docstring-description]
49+
50+
[tool.hatch.build.targets.wheel]
51+
packages = [ "src/testing", "src/fast_array_utils" ]
52+
4053
[tool.hatch.envs.default]
4154
installer = "uv"
4255

@@ -85,6 +98,8 @@ lint.per-file-ignores."tests/**/test_*.py" = [
8598
"S101", # tests use `assert`
8699
]
87100
lint.allowed-confusables = [ "×", "" ]
101+
lint.flake8-bugbear.extend-immutable-calls = [ "testing.fast_array_utils.Flags" ]
102+
88103
lint.flake8-copyright.notice-rgx = "SPDX-License-Identifier: MPL-2\\.0"
89104
lint.flake8-type-checking.exempt-modules = [ ]
90105
lint.flake8-type-checking.strict = true

src/fast_array_utils/conv/_asarray.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
from __future__ import annotations
33

44
from functools import singledispatch
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import numpy as np
8+
from numpy.typing import NDArray
89

910
from .. import types
1011

1112

1213
if TYPE_CHECKING:
13-
from typing import Any
14-
15-
from numpy.typing import ArrayLike, NDArray
14+
from numpy.typing import ArrayLike
1615

1716

1817
__all__ = ["asarray"]
@@ -64,9 +63,9 @@ def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
6463

6564
@asarray.register(types.CupyArray)
6665
def _(x: types.CupyArray) -> NDArray[Any]:
67-
return x.get() # type: ignore[no-any-return]
66+
return cast(NDArray[Any], x.get())
6867

6968

7069
@asarray.register(types.CupySparseMatrix)
7170
def _(x: types.CupySparseMatrix) -> NDArray[Any]:
72-
return x.toarray().get() # type: ignore[no-any-return]
71+
return cast(NDArray[Any], x.toarray().get())

src/fast_array_utils/stats/_sum.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,27 @@
22
from __future__ import annotations
33

44
from functools import partial, singledispatch
5-
from typing import TYPE_CHECKING, overload
5+
from typing import TYPE_CHECKING, Any, cast, overload
66

77
import numpy as np
8+
from numpy.typing import NDArray
89

910
from .. import types
1011

1112

1213
if TYPE_CHECKING:
13-
from typing import Any, Literal
14+
from typing import Literal
1415

15-
from numpy.typing import ArrayLike, DTypeLike, NDArray
16+
from numpy.typing import ArrayLike, DTypeLike
1617

1718

1819
@overload
1920
def sum(
20-
x: ArrayLike, /, *, axis: None = None, dtype: DTypeLike | None = None
21+
x: ArrayLike | types.ZarrArray, /, *, axis: None = None, dtype: DTypeLike | None = None
2122
) -> np.number[Any]: ...
2223
@overload
2324
def sum(
24-
x: ArrayLike, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
25+
x: ArrayLike | types.ZarrArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
2526
) -> NDArray[Any]: ...
2627
@overload
2728
def sum(
@@ -30,7 +31,11 @@ def sum(
3031

3132

3233
def sum(
33-
x: ArrayLike, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
34+
x: ArrayLike | types.ZarrArray,
35+
/,
36+
*,
37+
axis: Literal[0, 1, None] = None,
38+
dtype: DTypeLike | None = None,
3439
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
3540
"""Sum over both or one axis.
3641
@@ -56,7 +61,7 @@ def _sum(
5661
dtype: DTypeLike | None = None,
5762
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
5863
assert not isinstance(x, types.CSBase | types.DaskArray)
59-
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]
64+
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))
6065

6166

6267
@_sum.register(types.CSBase)
@@ -67,7 +72,7 @@ def _(
6772

6873
if isinstance(x, types.CSMatrix):
6974
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
70-
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]
75+
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))
7176

7277

7378
@_sum.register(types.DaskArray)
@@ -108,11 +113,14 @@ def sum_drop_keepdims(
108113
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
109114
dtype = np.zeros(1, dtype=x.dtype).sum().dtype
110115

111-
return reduction( # type: ignore[no-any-return,no-untyped-call]
112-
x,
113-
sum_drop_keepdims,
114-
partial(np.sum, dtype=dtype),
115-
axis=axis,
116-
dtype=dtype,
117-
meta=np.array([], dtype=dtype),
116+
return cast(
117+
types.DaskArray,
118+
reduction( # type: ignore[no-untyped-call]
119+
x,
120+
sum_drop_keepdims,
121+
partial(np.sum, dtype=dtype),
122+
axis=axis,
123+
dtype=dtype,
124+
meta=np.array([], dtype=dtype),
125+
),
118126
)

0 commit comments

Comments
 (0)