diff --git a/pandas-stubs/_libs/interval.pyi b/pandas-stubs/_libs/interval.pyi index fdb0a398c..f22a12778 100644 --- a/pandas-stubs/_libs/interval.pyi +++ b/pandas-stubs/_libs/interval.pyi @@ -21,7 +21,7 @@ from pandas.core.series import ( from pandas._typing import ( IntervalClosedType, IntervalT, - np_ndarray_bool, + np_1darray, npt, ) @@ -170,7 +170,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]): @overload def __gt__(self, other: Interval[_OrderableT]) -> bool: ... @overload - def __gt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __gt__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __gt__( self, @@ -179,7 +181,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]): @overload def __lt__(self, other: Interval[_OrderableT]) -> bool: ... @overload - def __lt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __lt__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __lt__( self, @@ -188,7 +192,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]): @overload def __ge__(self, other: Interval[_OrderableT]) -> bool: ... @overload - def __ge__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __ge__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __ge__( self, @@ -197,11 +203,15 @@ class Interval(IntervalMixin, Generic[_OrderableT]): @overload def __le__(self, other: Interval[_OrderableT]) -> bool: ... @overload - def __le__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __le__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __eq__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload - def __eq__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __eq__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __eq__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap] @overload @@ -209,7 +219,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]): @overload def __ne__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload - def __ne__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ... + def __ne__( + self: IntervalT, other: IntervalIndex[IntervalT] + ) -> np_1darray[np.bool]: ... @overload def __ne__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap] @overload diff --git a/pandas-stubs/_libs/tslibs/timestamps.pyi b/pandas-stubs/_libs/tslibs/timestamps.pyi index 11302a12c..e97ceb449 100644 --- a/pandas-stubs/_libs/tslibs/timestamps.pyi +++ b/pandas-stubs/_libs/tslibs/timestamps.pyi @@ -40,9 +40,11 @@ from pandas._libs.tslibs import ( Timedelta, ) from pandas._typing import ( + ShapeT, TimestampNonexistent, TimeUnit, - np_ndarray_bool, + np_1darray, + np_ndarray, npt, ) @@ -180,40 +182,48 @@ class Timestamp(datetime, SupportsIndex): @overload # type: ignore[override] def __le__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc] @overload + def __le__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ... + @overload def __le__( - self, other: DatetimeIndex | npt.NDArray[np.datetime64] - ) -> np_ndarray_bool: ... + self, other: np_ndarray[ShapeT, np.datetime64] + ) -> np_ndarray[ShapeT, np.bool]: ... @overload def __le__(self, other: TimestampSeries) -> Series[bool]: ... @overload # type: ignore[override] def __lt__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc] @overload + def __lt__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ... + @overload def __lt__( - self, other: DatetimeIndex | npt.NDArray[np.datetime64] - ) -> np_ndarray_bool: ... + self, other: np_ndarray[ShapeT, np.datetime64] + ) -> np_ndarray[ShapeT, np.bool]: ... @overload def __lt__(self, other: TimestampSeries) -> Series[bool]: ... @overload # type: ignore[override] def __ge__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc] @overload + def __ge__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ... + @overload def __ge__( - self, other: DatetimeIndex | npt.NDArray[np.datetime64] - ) -> np_ndarray_bool: ... + self, other: np_ndarray[ShapeT, np.datetime64] + ) -> np_ndarray[ShapeT, np.bool]: ... @overload def __ge__(self, other: TimestampSeries) -> Series[bool]: ... @overload # type: ignore[override] def __gt__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc] @overload + def __gt__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ... + @overload def __gt__( - self, other: DatetimeIndex | npt.NDArray[np.datetime64] - ) -> np_ndarray_bool: ... + self, other: np_ndarray[ShapeT, np.datetime64] + ) -> np_ndarray[ShapeT, np.bool]: ... @overload def __gt__(self, other: TimestampSeries) -> Series[bool]: ... # error: Signature of "__add__" incompatible with supertype "date"/"datetime" @overload # type: ignore[override] def __add__( - self, other: npt.NDArray[np.timedelta64] - ) -> npt.NDArray[np.datetime64]: ... + self, other: np_ndarray[ShapeT, np.timedelta64] + ) -> np_ndarray[ShapeT, np.datetime64]: ... @overload def __add__(self, other: timedelta | np.timedelta64 | Tick) -> Self: ... @overload @@ -226,8 +236,8 @@ class Timestamp(datetime, SupportsIndex): def __radd__(self, other: TimedeltaIndex) -> DatetimeIndex: ... @overload def __radd__( - self, other: npt.NDArray[np.timedelta64] - ) -> npt.NDArray[np.datetime64]: ... + self, other: np_ndarray[ShapeT, np.timedelta64] + ) -> np_ndarray[ShapeT, np.datetime64]: ... # TODO: test dt64 @overload # type: ignore[override] def __sub__(self, other: Timestamp | datetime | np.datetime64) -> Timedelta: ... @@ -241,14 +251,16 @@ class Timestamp(datetime, SupportsIndex): def __sub__(self, other: TimestampSeries) -> TimedeltaSeries: ... @overload def __sub__( - self, other: npt.NDArray[np.timedelta64] - ) -> npt.NDArray[np.datetime64]: ... + self, other: np_ndarray[ShapeT, np.timedelta64] + ) -> np_ndarray[ShapeT, np.datetime64]: ... @overload def __eq__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __eq__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap] @overload - def __eq__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap] + def __eq__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] + @overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy + def __eq__(self, other: npt.NDArray[np.datetime64]) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] @overload def __eq__(self, other: object) -> Literal[False]: ... @overload @@ -256,7 +268,9 @@ class Timestamp(datetime, SupportsIndex): @overload def __ne__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap] @overload - def __ne__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap] + def __ne__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] + @overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy + def __ne__(self, other: npt.NDArray[np.datetime64]) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] @overload def __ne__(self, other: object) -> Literal[True]: ... def __hash__(self) -> int: ... diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index f14c393fe..16eedcd4c 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -819,6 +819,15 @@ np_ndarray_complex: TypeAlias = npt.NDArray[np.complexfloating] np_ndarray_bool: TypeAlias = npt.NDArray[np.bool_] np_ndarray_str: TypeAlias = npt.NDArray[np.str_] +# Define shape and generic type variables with defaults similar to numpy +GenericT = TypeVar("GenericT", bound=np.generic, default=Any) +ShapeT = TypeVar("ShapeT", bound=tuple[int, ...], default=tuple[Any, ...]) +# Numpy ndarray with more ergonomic typevar +np_ndarray: TypeAlias = np.ndarray[ShapeT, np.dtype[GenericT]] +# Numpy arrays with known shape (Do not use as argument types, only as return types) +np_1darray: TypeAlias = np.ndarray[tuple[int], np.dtype[GenericT]] +np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]] + IndexType: TypeAlias = slice | np_ndarray_anyint | Index | list[int] | Series[int] MaskType: TypeAlias = Series[bool] | np_ndarray_bool | list[bool] diff --git a/pandas-stubs/core/algorithms.pyi b/pandas-stubs/core/algorithms.pyi index 803878ba8..a89c4d829 100644 --- a/pandas-stubs/core/algorithms.pyi +++ b/pandas-stubs/core/algorithms.pyi @@ -18,6 +18,7 @@ from pandas._typing import ( AnyArrayLike, IntervalT, TakeIndexer, + np_1darray, ) # These are type: ignored because the Index types overlap due to inheritance but indices @@ -54,14 +55,14 @@ def factorize( sort: bool = ..., use_na_sentinel: bool = ..., size_hint: int | None = ..., -) -> tuple[np.ndarray, Index]: ... +) -> tuple[np_1darray, Index]: ... @overload def factorize( values: Categorical, sort: bool = ..., use_na_sentinel: bool = ..., size_hint: int | None = ..., -) -> tuple[np.ndarray, Categorical]: ... +) -> tuple[np_1darray, Categorical]: ... def value_counts( values: AnyArrayLike | list | tuple, sort: bool = True, diff --git a/pandas-stubs/core/arrays/base.pyi b/pandas-stubs/core/arrays/base.pyi index b5c0ef7cc..b59ef8639 100644 --- a/pandas-stubs/core/arrays/base.pyi +++ b/pandas-stubs/core/arrays/base.pyi @@ -12,6 +12,7 @@ from pandas._typing import ( ScalarIndexer, SequenceIndexer, TakeIndexer, + np_1darray, npt, ) @@ -31,7 +32,7 @@ class ExtensionArray: dtype: npt.DTypeLike | None = ..., copy: bool = False, na_value: Scalar = ..., - ) -> np.ndarray: ... + ) -> np_1darray: ... @property def dtype(self) -> ExtensionDtype: ... @property @@ -44,13 +45,13 @@ class ExtensionArray: def isna(self) -> ArrayLike: ... def argsort( self, *, ascending: bool = ..., kind: str = ..., **kwargs - ) -> np.ndarray: ... + ) -> np_1darray: ... def fillna(self, value=..., method=None, limit=None): ... def dropna(self): ... def shift(self, periods: int = 1, fill_value: object = ...) -> Self: ... def unique(self): ... def searchsorted(self, value, side: str = ..., sorter=...): ... - def factorize(self, use_na_sentinel: bool = True) -> tuple[np.ndarray, Self]: ... + def factorize(self, use_na_sentinel: bool = True) -> tuple[np_1darray, Self]: ... def repeat(self, repeats, axis=...): ... def take( self, @@ -60,7 +61,7 @@ class ExtensionArray: fill_value=..., ) -> Self: ... def copy(self) -> Self: ... - def view(self, dtype=...) -> Self | np.ndarray: ... + def view(self, dtype=...) -> Self | np_1darray: ... def ravel(self, order="C") -> Self: ... def tolist(self) -> list: ... def _reduce( diff --git a/pandas-stubs/core/arrays/categorical.pyi b/pandas-stubs/core/arrays/categorical.pyi index 57f075e21..a0ef49f7b 100644 --- a/pandas-stubs/core/arrays/categorical.pyi +++ b/pandas-stubs/core/arrays/categorical.pyi @@ -25,8 +25,7 @@ from pandas._typing import ( ScalarIndexer, SequenceIndexer, TakeIndexer, - np_ndarray_bool, - np_ndarray_int, + np_1darray, ) from pandas.core.dtypes.dtypes import CategoricalDtype as CategoricalDtype @@ -63,7 +62,7 @@ class Categorical(ExtensionArray): fastpath: bool = ..., ) -> Categorical: ... @property - def codes(self) -> np_ndarray_int: ... + def codes(self) -> np_1darray[np.signedinteger]: ... def set_ordered(self, value) -> Categorical: ... def as_ordered(self) -> Categorical: ... def as_unordered(self) -> Categorical: ... @@ -90,7 +89,7 @@ class Categorical(ExtensionArray): @property def shape(self): ... def shift(self, periods=1, fill_value=...): ... - def __array__(self, dtype=...) -> np.ndarray: ... + def __array__(self, dtype=...) -> np_1darray: ... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ... @property def T(self): ... @@ -98,10 +97,10 @@ class Categorical(ExtensionArray): def nbytes(self) -> int: ... def memory_usage(self, deep: bool = ...): ... def searchsorted(self, value, side: str = ..., sorter=...): ... - def isna(self) -> np_ndarray_bool: ... - def isnull(self) -> np_ndarray_bool: ... - def notna(self) -> np_ndarray_bool: ... - def notnull(self) -> np_ndarray_bool: ... + def isna(self) -> np_1darray[np.bool]: ... + def isnull(self) -> np_1darray[np.bool]: ... + def notna(self) -> np_1darray[np.bool]: ... + def notnull(self) -> np_1darray[np.bool]: ... def dropna(self): ... def value_counts(self, dropna: bool = True): ... def check_for_ordered(self, op) -> None: ... diff --git a/pandas-stubs/core/arrays/interval.pyi b/pandas-stubs/core/arrays/interval.pyi index 304cc2960..a7c8d5a6e 100644 --- a/pandas-stubs/core/arrays/interval.pyi +++ b/pandas-stubs/core/arrays/interval.pyi @@ -21,7 +21,7 @@ from pandas._typing import ( ScalarIndexer, SequenceIndexer, TakeIndexer, - np_ndarray_bool, + np_1darray, ) IntervalOrNA: TypeAlias = Interval | float @@ -99,7 +99,7 @@ class IntervalArray(IntervalMixin, ExtensionArray): def mid(self) -> Index: ... @property def is_non_overlapping_monotonic(self) -> bool: ... - def __array__(self, dtype=...) -> np.ndarray: ... + def __array__(self, dtype=...) -> np_1darray: ... def __arrow_array__(self, type=...): ... def to_tuples(self, na_tuple: bool = True): ... def repeat(self, repeats, axis: Axis | None = ...): ... @@ -108,5 +108,5 @@ class IntervalArray(IntervalMixin, ExtensionArray): @overload def contains( self, other: Scalar | ExtensionArray | Index | np.ndarray - ) -> np_ndarray_bool: ... + ) -> np_1darray[np.bool]: ... def overlaps(self, other: Interval) -> bool: ... diff --git a/pandas-stubs/core/base.pyi b/pandas-stubs/core/base.pyi index da7b69b74..78180a2ed 100644 --- a/pandas-stubs/core/base.pyi +++ b/pandas-stubs/core/base.pyi @@ -26,6 +26,7 @@ from pandas._typing import ( DropKeep, NDFrameT, Scalar, + np_1darray, npt, ) from pandas.util._decorators import cache_readonly @@ -63,7 +64,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1]): copy: bool = False, na_value: Scalar = ..., **kwargs, - ) -> np.ndarray: ... + ) -> np_1darray: ... @property def empty(self) -> bool: ... def max(self, axis=..., skipna: bool = ..., **kwargs): ... @@ -114,7 +115,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1]): def is_monotonic_increasing(self) -> bool: ... def factorize( self, sort: bool = False, use_na_sentinel: bool = True - ) -> tuple[np.ndarray, np.ndarray | Index | Categorical]: ... + ) -> tuple[np_1darray, np_1darray | Index | Categorical]: ... def searchsorted( self, value, side: Literal["left", "right"] = ..., sorter=... ) -> int | list[int]: ... diff --git a/pandas-stubs/core/dtypes/missing.pyi b/pandas-stubs/core/dtypes/missing.pyi index e36496cfd..16c7e12b3 100644 --- a/pandas-stubs/core/dtypes/missing.pyi +++ b/pandas-stubs/core/dtypes/missing.pyi @@ -4,20 +4,23 @@ from typing import ( ) import numpy as np -from numpy import typing as npt from pandas import ( DataFrame, Index, Series, ) +from pandas.core.arrays import ExtensionArray from typing_extensions import TypeIs from pandas._libs.missing import NAType from pandas._libs.tslibs import NaTType from pandas._typing import ( - ArrayLike, Scalar, ScalarT, + ShapeT, + np_1darray, + np_ndarray, + np_ndarray_bool, ) isposinf_scalar = ... @@ -28,7 +31,11 @@ def isna(obj: DataFrame) -> DataFrame: ... @overload def isna(obj: Series) -> Series[bool]: ... @overload -def isna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ... +def isna(obj: Index | ExtensionArray | list[ScalarT]) -> np_1darray[np.bool]: ... +@overload +def isna(obj: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ... +@overload +def isna(obj: list[Any]) -> np_ndarray_bool: ... @overload def isna( obj: Scalar | NaTType | NAType | None, @@ -41,7 +48,11 @@ def notna(obj: DataFrame) -> DataFrame: ... @overload def notna(obj: Series) -> Series[bool]: ... @overload -def notna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ... +def notna(obj: Index | ExtensionArray | list[ScalarT]) -> np_1darray[np.bool]: ... +@overload +def notna(obj: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ... +@overload +def notna(obj: Index | list[Any]) -> np_ndarray_bool: ... @overload def notna(obj: ScalarT | NaTType | NAType | None) -> TypeIs[ScalarT]: ... diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index a33321d75..14d8bfea1 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -165,6 +165,7 @@ from pandas._typing import ( ValueKeyFunc, WriteBuffer, XMLParsers, + np_2darray, npt, num, ) @@ -452,7 +453,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): dtype: npt.DTypeLike | None = ..., copy: bool = False, na_value: Scalar = ..., - ) -> np.ndarray: ... + ) -> np_2darray: ... @overload def to_dict( self, @@ -1766,7 +1767,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @property def size(self) -> int: ... @property - def values(self) -> np.ndarray: ... + def values(self) -> np_2darray: ... # methods @final def abs(self) -> Self: ... diff --git a/pandas-stubs/core/indexes/accessors.pyi b/pandas-stubs/core/indexes/accessors.pyi index 6746c9dd4..e96615dee 100644 --- a/pandas-stubs/core/indexes/accessors.pyi +++ b/pandas-stubs/core/indexes/accessors.pyi @@ -10,7 +10,6 @@ from typing import ( ) import numpy as np -import numpy.typing as npt from pandas import ( DatetimeIndex, Index, @@ -40,6 +39,7 @@ from pandas._typing import ( TimestampConvention, TimeUnit, TimeZones, + np_1darray, np_ndarray_bool, ) @@ -88,7 +88,7 @@ class _DatetimeFieldOps( ): ... _DTBoolOpsReturnType = TypeVar( - "_DTBoolOpsReturnType", bound=Series[bool] | np_ndarray_bool + "_DTBoolOpsReturnType", bound=Series[bool] | np_1darray[np.bool] ) class _IsLeapYearProperty(Generic[_DTBoolOpsReturnType]): @@ -126,10 +126,10 @@ class _DatetimeObjectOps( ): ... _DTOtherOpsDateReturnType = TypeVar( - "_DTOtherOpsDateReturnType", bound=Series[dt.date] | np.ndarray + "_DTOtherOpsDateReturnType", bound=Series[dt.date] | np_1darray[np.object_] ) _DTOtherOpsTimeReturnType = TypeVar( - "_DTOtherOpsTimeReturnType", bound=Series[dt.time] | np.ndarray + "_DTOtherOpsTimeReturnType", bound=Series[dt.time] | np_1darray[np.object_] ) class _DatetimeOtherOps(Generic[_DTOtherOpsDateReturnType, _DTOtherOpsTimeReturnType]): @@ -280,7 +280,7 @@ class DatetimeProperties( _DTToPeriodReturnType, ], ): - def to_pydatetime(self) -> np.ndarray: ... + def to_pydatetime(self) -> np_1darray[np.object_]: ... def isocalendar(self) -> DataFrame: ... @property def unit(self) -> TimeUnit: ... @@ -296,7 +296,7 @@ _TDTotalSecondsReturnType = TypeVar( class _TimedeltaPropertiesNoRounding( Generic[_TDNoRoundingMethodReturnType, _TDTotalSecondsReturnType] ): - def to_pytimedelta(self) -> np.ndarray: ... + def to_pytimedelta(self) -> np_1darray[np.object_]: ... @property def components(self) -> DataFrame: ... @property @@ -401,10 +401,10 @@ class DatetimeIndexProperties( Properties, _DatetimeNoTZProperties[ Index[int], - np_ndarray_bool, + np_1darray[np.bool], DatetimeIndex, - np.ndarray, - np.ndarray, + np_1darray[np.object_], + np_1darray[np.object_], BaseOffset, DatetimeIndex, Index, @@ -416,7 +416,7 @@ class DatetimeIndexProperties( def is_normalized(self) -> bool: ... @property def tzinfo(self) -> _tzinfo | None: ... - def to_pydatetime(self) -> npt.NDArray[np.object_]: ... + def to_pydatetime(self) -> np_1darray[np.object_]: ... def std( self, axis: int | None = ..., ddof: int = ..., skipna: bool = ... ) -> Timedelta: ... diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 0d7b440e6..942e2a180 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -62,8 +62,8 @@ from pandas._typing import ( SliceType, TimedeltaDtypeArg, TimestampDtypeArg, + np_1darray, np_ndarray_anyint, - np_ndarray_bool, np_ndarray_complex, np_ndarray_float, type_t, @@ -274,7 +274,7 @@ class Index(IndexOpsMixin[S1]): ) -> StringMethods[ Self, MultiIndex, - np_ndarray_bool, + np_1darray[np.bool], Index[list[_str]], Index[int], Index[bytes], @@ -286,7 +286,7 @@ class Index(IndexOpsMixin[S1]): def __len__(self) -> int: ... def __array__( self, dtype: _str | np.dtype = ..., copy: bool | None = ... - ) -> np.ndarray: ... + ) -> np_1darray: ... def __array_wrap__(self, result, context=...): ... @property def dtype(self) -> DtypeObj: ... @@ -354,7 +354,7 @@ class Index(IndexOpsMixin[S1]): def dropna(self, how: AnyAll = "any") -> Self: ... def unique(self, level=...) -> Self: ... def drop_duplicates(self, *, keep: DropKeep = ...) -> Self: ... - def duplicated(self, keep: DropKeep = "first") -> np_ndarray_bool: ... + def duplicated(self, keep: DropKeep = "first") -> np_1darray[np.bool]: ... def __and__(self, other: Never) -> Never: ... def __rand__(self, other: Never) -> Never: ... def __or__(self, other: Never) -> Never: ... @@ -378,7 +378,7 @@ class Index(IndexOpsMixin[S1]): result_name: Hashable = ..., sort: bool | None = None, ) -> Self: ... - def get_loc(self, key: Label) -> int | slice | np_ndarray_bool: ... + def get_loc(self, key: Label) -> int | slice | np_1darray[np.bool]: ... def get_indexer( self, target, method: ReindexMethod | None = ..., limit=..., tolerance=... ): ... @@ -400,7 +400,7 @@ class Index(IndexOpsMixin[S1]): sort: bool = ..., ): ... @property - def values(self) -> np.ndarray: ... + def values(self) -> np_1darray: ... @property def array(self) -> ExtensionArray: ... def memory_usage(self, deep: bool = False): ... @@ -445,7 +445,7 @@ class Index(IndexOpsMixin[S1]): @final def groupby(self, values) -> dict[Hashable, np.ndarray]: ... def map(self, mapper, na_action=...) -> Index: ... - def isin(self, values, level=...) -> np_ndarray_bool: ... + def isin(self, values, level=...) -> np_1darray[np.bool]: ... def slice_indexer( self, start: Label | None = None, @@ -462,13 +462,13 @@ class Index(IndexOpsMixin[S1]): @property def shape(self) -> tuple[int, ...]: ... # Extra methods from old stubs - def __eq__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __eq__(self, other: object) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] def __iter__(self) -> Iterator[S1]: ... - def __ne__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def __le__(self, other: Self | S1) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def __ge__(self, other: Self | S1) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def __lt__(self, other: Self | S1) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def __gt__(self, other: Self | S1) -> np_ndarray_bool: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __ne__(self, other: object) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __le__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __ge__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __lt__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def __gt__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # overwrite inherited methods from OpsMixin @overload def __mul__( diff --git a/pandas-stubs/core/indexes/interval.pyi b/pandas-stubs/core/indexes/interval.pyi index b82cdeddb..581ab274f 100644 --- a/pandas-stubs/core/indexes/interval.pyi +++ b/pandas-stubs/core/indexes/interval.pyi @@ -32,6 +32,7 @@ from pandas._typing import ( IntervalT, Label, MaskType, + np_1darray, np_ndarray_anyint, np_ndarray_bool, npt, @@ -221,7 +222,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): def memory_usage(self, deep: bool = False) -> int: ... @property def is_overlapping(self) -> bool: ... - def get_loc(self, key: Label) -> int | slice | npt.NDArray[np.bool_]: ... + def get_loc(self, key: Label) -> int | slice | np_1darray[np.bool]: ... @final def get_indexer( self, @@ -260,7 +261,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): @overload # type: ignore[override] def __gt__( self, other: IntervalT | IntervalIndex[IntervalT] - ) -> np_ndarray_bool: ... + ) -> np_1darray[np.bool]: ... @overload def __gt__( # pyright: ignore[reportIncompatibleMethodOverride] self, other: pd.Series[IntervalT] @@ -268,7 +269,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): @overload # type: ignore[override] def __ge__( self, other: IntervalT | IntervalIndex[IntervalT] - ) -> np_ndarray_bool: ... + ) -> np_1darray[np.bool]: ... @overload def __ge__( # pyright: ignore[reportIncompatibleMethodOverride] self, other: pd.Series[IntervalT] @@ -276,7 +277,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): @overload # type: ignore[override] def __le__( self, other: IntervalT | IntervalIndex[IntervalT] - ) -> np_ndarray_bool: ... + ) -> np_1darray[np.bool]: ... @overload def __le__( # pyright: ignore[reportIncompatibleMethodOverride] self, other: pd.Series[IntervalT] @@ -284,13 +285,13 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): @overload # type: ignore[override] def __lt__( self, other: IntervalT | IntervalIndex[IntervalT] - ) -> np_ndarray_bool: ... + ) -> np_1darray[np.bool]: ... @overload def __lt__( # pyright: ignore[reportIncompatibleMethodOverride] self, other: pd.Series[IntervalT] ) -> pd.Series[bool]: ... @overload # type: ignore[override] - def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __eq__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[overload-overlap] @overload @@ -298,7 +299,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin): self, other: object ) -> Literal[False]: ... @overload # type: ignore[override] - def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __ne__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[overload-overlap] @overload diff --git a/pandas-stubs/core/indexes/multi.pyi b/pandas-stubs/core/indexes/multi.pyi index d51528d9d..a4301ad29 100644 --- a/pandas-stubs/core/indexes/multi.pyi +++ b/pandas-stubs/core/indexes/multi.pyi @@ -25,8 +25,8 @@ from pandas._typing import ( MaskType, NaPosition, SequenceNotStr, + np_1darray, np_ndarray_anyint, - np_ndarray_bool, ) class MultiIndex(Index): @@ -161,4 +161,4 @@ class MultiIndex(Index): def equal_levels(self, other): ... def insert(self, loc, item): ... def delete(self, loc): ... - def isin(self, values, level=...) -> np_ndarray_bool: ... + def isin(self, values, level=...) -> np_1darray[np.bool]: ... diff --git a/pandas-stubs/core/indexes/range.pyi b/pandas-stubs/core/indexes/range.pyi index 5c2d18263..2f4c82c78 100644 --- a/pandas-stubs/core/indexes/range.pyi +++ b/pandas-stubs/core/indexes/range.pyi @@ -13,8 +13,8 @@ from pandas.core.indexes.base import Index from pandas._typing import ( HashableT, MaskType, + np_1darray, np_ndarray_anyint, - npt, ) class RangeIndex(Index[int]): @@ -58,7 +58,7 @@ class RangeIndex(Index[int]): def argsort(self, *args, **kwargs): ... def factorize( self, sort: bool = False, use_na_sentinel: bool = True - ) -> tuple[npt.NDArray[np.intp], RangeIndex]: ... + ) -> tuple[np_1darray[np.intp], RangeIndex]: ... def equals(self, other): ... @final def join( diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 69dfc1b08..d98408bb1 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -176,6 +176,7 @@ from pandas._typing import ( ValueKeyFunc, VoidDtypeArg, WriteBuffer, + np_1darray, np_ndarray_anyint, np_ndarray_bool, np_ndarray_complex, @@ -453,7 +454,7 @@ class Series(IndexOpsMixin[S1], NDFrame): ): ... def __array__( self, dtype: _str | np.dtype = ..., copy: bool | None = ... - ) -> np.ndarray: ... + ) -> np_1darray: ... @property def axes(self) -> list: ... @final @@ -2885,7 +2886,7 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = False, na_value: Scalar = ..., **kwargs: Any, - ) -> np.ndarray: ... + ) -> np_1darray: ... def tolist(self) -> list[S1]: ... def var( self, diff --git a/pandas-stubs/core/strings/accessor.pyi b/pandas-stubs/core/strings/accessor.pyi index bedd8ac7b..a290fa12f 100644 --- a/pandas-stubs/core/strings/accessor.pyi +++ b/pandas-stubs/core/strings/accessor.pyi @@ -31,7 +31,7 @@ from pandas._typing import ( DtypeObj, Scalar, T, - np_ndarray_bool, + np_1darray, ) # Used for the result of str.split with expand=True @@ -39,7 +39,7 @@ _T_EXPANDING = TypeVar("_T_EXPANDING", bound=DataFrame | MultiIndex) # Used for the result of str.split with expand=False _T_LIST_STR = TypeVar("_T_LIST_STR", bound=Series[list[str]] | Index[list[str]]) # Used for the result of str.match -_T_BOOL = TypeVar("_T_BOOL", bound=Series[bool] | np_ndarray_bool) +_T_BOOL = TypeVar("_T_BOOL", bound=Series[bool] | np_1darray[np.bool]) # Used for the result of str.index / str.find _T_INT = TypeVar("_T_INT", bound=Series[int] | Index[int]) # Used for the result of str.encode diff --git a/tests/__init__.py b/tests/__init__.py index c4a94e773..7a95b9f82 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,6 +11,9 @@ TYPE_CHECKING, Final, Literal, + TypeVar, + get_args, + get_origin, ) import numpy as np @@ -40,11 +43,16 @@ TimestampDtypeArg as TimestampDtypeArg, UIntDtypeArg as UIntDtypeArg, VoidDtypeArg as VoidDtypeArg, + np_1darray as np_1darray, + np_2darray as np_2darray, np_ndarray_bool as np_ndarray_bool, np_ndarray_int as np_ndarray_int, ) else: + _G = TypeVar("_G", bound=np.generic) # Separately define here so pytest works + np_1darray = np.ndarray[tuple[int], np.dtype[_G]] + np_2darray = np.ndarray[tuple[int, int], np.dtype[_G]] np_ndarray_bool = npt.NDArray[np.bool_] np_ndarray_int = npt.NDArray[np.signedinteger] @@ -61,8 +69,39 @@ def check( attr: str = "left", index_to_check_for_type: Literal[0, -1] = 0, ) -> T: - if not isinstance(actual, klass): + __tracebackhide__ = True + origin = get_origin(klass) + if not isinstance(actual, origin or klass): raise RuntimeError(f"Expected type '{klass}' but got '{type(actual)}'") + if origin is np.ndarray: + # Check shape and dtype + args = get_args(klass) + shape_type = args[0] if len(args) >= 1 else None + dtype_type = args[1] if len(args) >= 2 else None + if ( + shape_type + and get_origin(shape_type) is tuple + and (tuple_args := get_args(shape_type)) + and ... not in tuple_args # fixed-length tuple + and (arr_ndim := getattr(actual, "ndim")) + != (expected_ndim := len(tuple_args)) + ): + raise RuntimeError( + f"Array has wrong dimension {arr_ndim}, expected {expected_ndim}" + ) + + if ( + dtype_type + and get_origin(dtype_type) is np.dtype + and (dtype_args := get_args(dtype_type)) + and isinstance((expected_dtype := dtype_args[0]), type) + and issubclass(expected_dtype, np.generic) + and (arr_dtype := getattr(actual, "dtype")) != expected_dtype + ): + raise RuntimeError( + f"Array has wrong dtype {arr_dtype}, expected {expected_dtype.__name__}" + ) + if dtype is None: return actual diff --git a/tests/series/test_series.py b/tests/series/test_series.py index 795193782..4070b4b3d 100644 --- a/tests/series/test_series.py +++ b/tests/series/test_series.py @@ -56,6 +56,7 @@ WINDOWS, check, ensure_clean, + np_1darray, pytest_warns_bounded, ) from tests.extension.decimal.array import DecimalDtype @@ -91,7 +92,6 @@ UIntDtypeArg, VoidDtypeArg, ) - from tests import np_ndarray_int # noqa: F401 else: TimedeltaSeries: TypeAlias = pd.Series @@ -1870,7 +1870,7 @@ def test_types_to_dict() -> None: def test_categorical_codes(): # GH-111 cat = pd.Categorical(["a", "b", "a"]) - assert_type(cat.codes, "np_ndarray_int") + check(assert_type(cat.codes, np_1darray[np.signedinteger]), np_1darray[np.int8]) def test_relops() -> None: @@ -2022,12 +2022,12 @@ def test_dtype_type() -> None: def test_types_to_numpy() -> None: s = pd.Series(["a", "b", "c"], dtype=str) - check(assert_type(s.to_numpy(), np.ndarray), np.ndarray) - check(assert_type(s.to_numpy(dtype="str", copy=True), np.ndarray), np.ndarray) - check(assert_type(s.to_numpy(na_value=0), np.ndarray), np.ndarray) - check(assert_type(s.to_numpy(na_value=np.int32(4)), np.ndarray), np.ndarray) - check(assert_type(s.to_numpy(na_value=np.float16(4)), np.ndarray), np.ndarray) - check(assert_type(s.to_numpy(na_value=np.complex128(4, 7)), np.ndarray), np.ndarray) + check(assert_type(s.to_numpy(), np_1darray), np_1darray) + check(assert_type(s.to_numpy(dtype="str", copy=True), np_1darray), np_1darray) + check(assert_type(s.to_numpy(na_value=0), np_1darray), np_1darray) + check(assert_type(s.to_numpy(na_value=np.int32(4)), np_1darray), np_1darray) + check(assert_type(s.to_numpy(na_value=np.float16(4)), np_1darray), np_1darray) + check(assert_type(s.to_numpy(na_value=np.complex128(4, 7)), np_1darray), np_1darray) def test_where() -> None: diff --git a/tests/test_frame.py b/tests/test_frame.py index f70504ef4..cf1b58015 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -57,6 +57,7 @@ TYPE_CHECKING_INVALID_USAGE, check, ensure_clean, + np_2darray, pytest_warns_bounded, ) @@ -1946,18 +1947,18 @@ def test_types_cov() -> None: def test_types_to_numpy() -> None: df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]}) - check(assert_type(df.to_numpy(), np.ndarray), np.ndarray) - check(assert_type(df.to_numpy(dtype="str", copy=True), np.ndarray), np.ndarray) + check(assert_type(df.to_numpy(), np_2darray), np_2darray) + check(assert_type(df.to_numpy(dtype="str", copy=True), np_2darray), np_2darray) # na_value param was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html - check(assert_type(df.to_numpy(na_value=0), np.ndarray), np.ndarray) + check(assert_type(df.to_numpy(na_value=0), np_2darray), np_2darray) df = pd.DataFrame(data={"col1": [1, 1, 2]}, dtype=np.complex128) - check(assert_type(df.to_numpy(na_value=0), np.ndarray), np.ndarray) - check(assert_type(df.to_numpy(na_value=np.int32(4)), np.ndarray), np.ndarray) - check(assert_type(df.to_numpy(na_value=np.float16(3.68)), np.ndarray), np.ndarray) + check(assert_type(df.to_numpy(na_value=0), np_2darray), np_2darray) + check(assert_type(df.to_numpy(na_value=np.int32(4)), np_2darray), np_2darray) + check(assert_type(df.to_numpy(na_value=np.float16(3.68)), np_2darray), np_2darray) check( - assert_type(df.to_numpy(na_value=np.complex128(3.8, -493.2)), np.ndarray), - np.ndarray, + assert_type(df.to_numpy(na_value=np.complex128(3.8, -493.2)), np_2darray), + np_2darray, ) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index f6ad2a270..adf21c01a 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -26,6 +26,7 @@ PD_LTE_23, TYPE_CHECKING_INVALID_USAGE, check, + np_1darray, pytest_warns_bounded, ) @@ -42,13 +43,13 @@ def test_index_duplicated() -> None: df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=pd.Index([1, 2, 3, 2])) ind = df.index duplicated = ind.duplicated("first") - check(assert_type(duplicated, npt.NDArray[np.bool_]), np.ndarray, np.bool_) + check(assert_type(duplicated, np_1darray[np.bool]), np_1darray[np.bool]) def test_index_isin() -> None: ind = pd.Index([1, 2, 3, 4, 5]) isin = ind.isin([2, 4]) - check(assert_type(isin, npt.NDArray[np.bool_]), np.ndarray, np.bool_) + check(assert_type(isin, np_1darray[np.bool]), np_1darray[np.bool]) def test_index_astype() -> None: @@ -202,13 +203,6 @@ def test_str_rsplit() -> None: ) -def test_str_match() -> None: - i = pd.Index( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(i.str.match("pp"), npt.NDArray[np.bool_]), np.ndarray, np.bool_) - - def test_index_rename() -> None: """Test that index rename returns an element of type Index.""" ind = pd.Index([1, 2, 3], name="foo") @@ -244,9 +238,9 @@ def test_index_neg(): def test_types_to_numpy() -> None: idx = pd.Index([1, 2]) - check(assert_type(idx.to_numpy(), np.ndarray), np.ndarray) - check(assert_type(idx.to_numpy(dtype="int", copy=True), np.ndarray), np.ndarray) - check(assert_type(idx.to_numpy(na_value=0), np.ndarray), np.ndarray) + check(assert_type(idx.to_numpy(), np_1darray), np_1darray) + check(assert_type(idx.to_numpy(dtype="int", copy=True), np_1darray), np_1darray) + check(assert_type(idx.to_numpy(na_value=0), np_1darray), np_1darray) def test_index_arithmetic() -> None: @@ -289,10 +283,10 @@ def test_index_relops() -> None: check(assert_type(data[dt_idx > x], pd.DatetimeIndex), pd.DatetimeIndex) ind = pd.Index([1, 2, 3]) - check(assert_type(ind <= 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(ind >= 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(ind < 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(ind > 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_) + check(assert_type(ind <= 2, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ind >= 2, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ind < 2, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ind > 2, np_1darray[np.bool]), np_1darray[np.bool]) def test_range_index_union(): @@ -1299,16 +1293,14 @@ def test_datetime_operators_builtin() -> None: def test_get_loc() -> None: unique_index = pd.Index(list("abc")) check( - assert_type( - unique_index.get_loc("b"), Union[int, slice, npt.NDArray[np.bool_]] - ), + assert_type(unique_index.get_loc("b"), Union[int, slice, np_1darray[np.bool]]), int, ) monotonic_index = pd.Index(list("abbc")) check( assert_type( - monotonic_index.get_loc("b"), Union[int, slice, npt.NDArray[np.bool_]] + monotonic_index.get_loc("b"), Union[int, slice, np_1darray[np.bool]] ), slice, ) @@ -1316,10 +1308,25 @@ def test_get_loc() -> None: non_monotonic_index = pd.Index(list("abcb")) check( assert_type( - non_monotonic_index.get_loc("b"), Union[int, slice, npt.NDArray[np.bool_]] + non_monotonic_index.get_loc("b"), Union[int, slice, np_1darray[np.bool]] + ), + np_1darray[np.bool], + ) + + i1, i2, i3 = pd.Interval(0, 1), pd.Interval(1, 2), pd.Interval(0, 2) + unique_interval_index = pd.IntervalIndex([i1, i2]) + check( + assert_type( + unique_interval_index.get_loc(i1), Union[int, slice, np_1darray[np.bool]] + ), + np.int64, + ) + overlap_interval_index = pd.IntervalIndex([i1, i2, i3]) + check( + assert_type( + overlap_interval_index.get_loc(1), Union[int, slice, np_1darray[np.bool]] ), - np.ndarray, - np.bool_, + np_1darray[np.bool], ) @@ -1336,14 +1343,14 @@ def test_value_counts() -> None: def test_index_factorize() -> None: """Test Index.factorize method.""" codes, idx_uniques = pd.Index(["b", "b", "a", "c", "b"]).factorize() - check(assert_type(codes, np.ndarray), np.ndarray) - check(assert_type(idx_uniques, np.ndarray | Index | Categorical), pd.Index) + check(assert_type(codes, np_1darray), np_1darray) + check(assert_type(idx_uniques, np_1darray | Index | Categorical), pd.Index) codes, idx_uniques = pd.Index(["b", "b", "a", "c", "b"]).factorize( use_na_sentinel=False ) - check(assert_type(codes, np.ndarray), np.ndarray) - check(assert_type(idx_uniques, np.ndarray | Index | Categorical), pd.Index) + check(assert_type(codes, np_1darray), np_1darray) + check(assert_type(idx_uniques, np_1darray | Index | Categorical), pd.Index) def test_disallow_empty_index() -> None: diff --git a/tests/test_interval.py b/tests/test_interval.py index 208255f89..8d9081710 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -1,13 +1,13 @@ from __future__ import annotations import numpy as np -from numpy import typing as npt import pandas as pd from typing_extensions import assert_type from tests import ( TYPE_CHECKING_INVALID_USAGE, check, + np_1darray, ) @@ -127,4 +127,4 @@ def test_interval_array_contains(): ser = pd.Series(obj, index=df.index) arr = ser.array check(assert_type(arr.contains(df["A"]), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(arr.contains(3), npt.NDArray[np.bool_]), np.ndarray) + check(assert_type(arr.contains(3), np_1darray[np.bool]), np_1darray[np.bool]) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index cfa6f6be2..a7633fa8a 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -33,6 +33,9 @@ PD_LTE_23, TYPE_CHECKING_INVALID_USAGE, check, + np_1darray, + np_2darray, + np_ndarray_bool, pytest_warns_bounded, ) @@ -428,20 +431,47 @@ def test_types_json_normalize() -> None: def test_isna() -> None: # https://github.com/pandas-dev/pandas-stubs/issues/264 - s1 = pd.Series([1, np.nan, 3.2]) - check(assert_type(pd.isna(s1), "pd.Series[bool]"), pd.Series, np.bool_) - - s2 = pd.Series([1, 3.2]) - check(assert_type(pd.notna(s2), "pd.Series[bool]"), pd.Series, np.bool_) - - df1 = pd.DataFrame({"a": [1, 2, 1, 2], "b": [1, 1, 2, np.nan]}) - check(assert_type(pd.isna(df1), "pd.DataFrame"), pd.DataFrame) - - idx1 = pd.Index([1, 2, np.nan]) - check(assert_type(pd.isna(idx1), npt.NDArray[np.bool_]), np.ndarray, np.bool_) - - idx2 = pd.Index([1, 2]) - check(assert_type(pd.notna(idx2), npt.NDArray[np.bool_]), np.ndarray, np.bool_) + s = pd.Series([1, np.nan, 3.2]) + check(assert_type(pd.isna(s), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(pd.notna(s), "pd.Series[bool]"), pd.Series, np.bool_) + + df = pd.DataFrame({"a": [1, 2, 1, 2], "b": [1, 1, 2, np.nan]}) + check(assert_type(pd.isna(df), "pd.DataFrame"), pd.DataFrame) + check(assert_type(pd.notna(df), "pd.DataFrame"), pd.DataFrame) + + idx = pd.Index([1, 2, np.nan, float("nan")]) + check(assert_type(pd.isna(idx), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(pd.notna(idx), np_1darray[np.bool]), np_1darray[np.bool]) + + # ExtensionArray + ext_arr = idx.array + check(assert_type(pd.isna(ext_arr), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(pd.notna(ext_arr), np_1darray[np.bool]), np_1darray[np.bool]) + + # 1-D numpy array + arr_1d = idx.to_numpy() + check(assert_type(pd.isna(arr_1d), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(pd.notna(arr_1d), np_1darray[np.bool]), np_1darray[np.bool]) + + # 2-D numpy array + arr_2d = idx.to_numpy().reshape(2, 2) + check(assert_type(pd.isna(arr_2d), np_2darray[np.bool]), np_2darray[np.bool]) + check(assert_type(pd.notna(arr_2d), np_2darray[np.bool]), np_2darray[np.bool]) + + # N-D numpy array + arr_nd = idx.to_numpy().reshape([2, 2]) + check(assert_type(pd.isna(arr_nd), np_ndarray_bool), np_ndarray_bool) + check(assert_type(pd.notna(arr_nd), np_ndarray_bool), np_ndarray_bool) + + # List of scalars + l_sca = [1, 2.5, float("nan")] + check(assert_type(pd.isna(l_sca), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(pd.notna(l_sca), np_1darray[np.bool]), np_1darray[np.bool]) + + # List of unknown members + l_any: list[object] = [arr_1d, ext_arr] + check(assert_type(pd.isna(l_any), np_ndarray_bool), np_ndarray_bool) + check(assert_type(pd.notna(l_any), np_ndarray_bool), np_ndarray_bool) assert check(assert_type(pd.isna(pd.NA), bool), bool) assert not check(assert_type(pd.notna(pd.NA), bool), bool) @@ -909,15 +939,15 @@ def test_factorize() -> None: check(assert_type(uniques, np.ndarray), np.ndarray) codes, cat_uniques = pd.factorize(pd.Categorical(["b", "b", "a", "c", "b"])) - check(assert_type(codes, np.ndarray), np.ndarray) + check(assert_type(codes, np_1darray), np_1darray) check(assert_type(cat_uniques, pd.Categorical), pd.Categorical) codes, idx_uniques = pd.factorize(pd.Index(["b", "b", "a", "c", "b"])) - check(assert_type(codes, np.ndarray), np.ndarray) + check(assert_type(codes, np_1darray), np_1darray) check(assert_type(idx_uniques, pd.Index), pd.Index) codes, idx_uniques = pd.factorize(pd.Series(["b", "b", "a", "c", "b"])) - check(assert_type(codes, np.ndarray), np.ndarray) + check(assert_type(codes, np_1darray), np_1darray) check(assert_type(idx_uniques, pd.Index), pd.Index) codes, uniques = pd.factorize(np.array(list("bbacb"))) diff --git a/tests/test_scalars.py b/tests/test_scalars.py index bfc78ae65..d4cc7d9fa 100644 --- a/tests/test_scalars.py +++ b/tests/test_scalars.py @@ -2,6 +2,8 @@ import datetime import datetime as dt + +# import sys from typing import ( TYPE_CHECKING, Any, @@ -26,6 +28,9 @@ from tests import ( TYPE_CHECKING_INVALID_USAGE, check, + np_1darray, + np_2darray, + np_ndarray_bool, pytest_warns_bounded, ) @@ -48,7 +53,7 @@ PeriodSeries: TypeAlias = pd.Series OffsetSeries: TypeAlias = pd.Series -from tests import np_ndarray_bool +MYPY = False def test_interval() -> None: @@ -314,73 +319,55 @@ def test_interval_cmp(): interval_index_int = pd.IntervalIndex([interval_i]) check( - assert_type(interval_index_int >= interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int >= interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_index_int < interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int < interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_index_int <= interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int <= interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_index_int > interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int > interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_i >= interval_index_int, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_i >= interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_i < interval_index_int, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_i < interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_i <= interval_index_int, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_i <= interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_i > interval_index_int, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_i > interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_index_int == interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int == interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type(interval_index_int != interval_i, np_ndarray_bool), - np.ndarray, - np.bool_, + assert_type(interval_index_int != interval_i, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type( - interval_i == interval_index_int, - np_ndarray_bool, - ), - np.ndarray, - np.bool_, + assert_type(interval_i == interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) check( - assert_type( - interval_i != interval_index_int, - np_ndarray_bool, - ), - np.ndarray, - np.bool_, + assert_type(interval_i != interval_index_int, np_1darray[np.bool]), + np_1darray[np.bool], ) @@ -854,10 +841,10 @@ def test_timedelta_cmp() -> None: check(assert_type(td < c_dt_timedelta, bool), bool) check(assert_type(td < c_timedelta64, bool), bool) check(assert_type(td < c_ndarray_td64, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_timedelta_index < td, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(c_timedelta_index < td, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(c_dt_timedelta < td, bool), bool) check(assert_type(c_ndarray_td64 < td, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_timedelta_index < td, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(c_timedelta_index < td, np_1darray[np.bool]), np_1darray[np.bool]) gt = check(assert_type(td > c_timedelta, bool), bool) le = check(assert_type(td <= c_timedelta, bool), bool) @@ -912,10 +899,10 @@ def test_timedelta_cmp() -> None: assert (gt_a != le_a).all() gt_a = check( - assert_type(c_timedelta_index > td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index > td, np_1darray[np.bool]), np_1darray[np.bool] ) le_a = check( - assert_type(c_timedelta_index <= td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index <= td, np_1darray[np.bool]), np_1darray[np.bool] ) assert (gt_a != le_a).all() @@ -980,10 +967,10 @@ def test_timedelta_cmp() -> None: assert (lt_a != ge_a).all() lt_a = check( - assert_type(c_timedelta_index < td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index < td, np_1darray[np.bool]), np_1darray[np.bool] ) ge_a = check( - assert_type(c_timedelta_index >= td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index >= td, np_1darray[np.bool]), np_1darray[np.bool] ) assert (lt_a != ge_a).all() @@ -1066,10 +1053,10 @@ def test_timedelta_cmp_rhs() -> None: assert (eq_a != ne_a).all() eq_a = check( - assert_type(c_timedelta_index == td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index == td, np_1darray[np.bool]), np_1darray[np.bool] ) ne_a = check( - assert_type(c_timedelta_index != td, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_timedelta_index != td, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_a != ne_a).all() @@ -1256,6 +1243,9 @@ def test_timestamp_cmp() -> None: np_dt64_arr: npt.NDArray[np.datetime64] = np.array( [1, 2, 3], dtype="datetime64[ns]" ) + np_dt64_arr2d: np.ndarray[tuple[int, int], np.dtype[np.datetime64]] = ( + np.arange(6).astype(dtype=np.datetime64).reshape(3, 2) + ) c_timestamp = ts c_np_dt64 = np.datetime64(1, "ns") @@ -1264,6 +1254,7 @@ def test_timestamp_cmp() -> None: # DatetimeIndex, but the type checker detects it to be UnknownIndex. c_unknown_index = pd.DataFrame({"a": [1]}, index=c_datetimeindex).index c_np_ndarray_dt64 = np_dt64_arr + c_np_2darray_dt64 = np_dt64_arr2d c_series_dt64: TimestampSeries = pd.Series([1, 2, 3], dtype="datetime64[ns]") c_series_timestamp = pd.Series(pd.DatetimeIndex(["2000-1-1"])) check(assert_type(c_series_timestamp, TimestampSeries), pd.Series, pd.Timestamp) @@ -1281,13 +1272,17 @@ def test_timestamp_cmp() -> None: lte = check(assert_type(ts <= c_dt_datetime, bool), bool) assert gt != lte - check(assert_type(ts > c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts <= c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts > c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts <= c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(ts > c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts <= c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts > c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts <= c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(ts > c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_) check(assert_type(ts <= c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(ts > c_np_2darray_dt64, np_2darray[np.bool]), np_2darray[np.bool]) + check( + assert_type(ts <= c_np_2darray_dt64, np_2darray[np.bool]), np_2darray[np.bool] + ) check(assert_type(ts > c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(ts <= c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_) @@ -1302,13 +1297,17 @@ def test_timestamp_cmp() -> None: lte = check(assert_type(c_dt_datetime <= ts, bool), bool) assert gt != lte - check(assert_type(c_datetimeindex > ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_datetimeindex <= ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_unknown_index > ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_unknown_index <= ts, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(c_datetimeindex > ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_datetimeindex <= ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_unknown_index > ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_unknown_index <= ts, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(c_np_ndarray_dt64 > ts, np_ndarray_bool), np.ndarray, np.bool_) check(assert_type(c_np_ndarray_dt64 <= ts, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(c_np_2darray_dt64 > ts, np_2darray[np.bool]), np_2darray[np.bool]) + check( + assert_type(c_np_2darray_dt64 <= ts, np_2darray[np.bool]), np_2darray[np.bool] + ) check(assert_type(c_series_dt64 > ts, "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(c_series_dt64 <= ts, "pd.Series[bool]"), pd.Series, np.bool_) @@ -1325,13 +1324,17 @@ def test_timestamp_cmp() -> None: lt = check(assert_type(ts < c_dt_datetime, bool), bool) assert gte != lt - check(assert_type(ts >= c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts < c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts >= c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(ts < c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(ts >= c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts < c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts >= c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(ts < c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(ts >= c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_) check(assert_type(ts < c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_) + check( + assert_type(c_np_2darray_dt64 >= ts, np_2darray[np.bool]), np_2darray[np.bool] + ) + check(assert_type(c_np_2darray_dt64 < ts, np_2darray[np.bool]), np_2darray[np.bool]) check(assert_type(ts >= c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(ts < c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_) @@ -1346,13 +1349,17 @@ def test_timestamp_cmp() -> None: check(assert_type(c_np_dt64 >= ts, np.bool), bool) check(assert_type(c_np_dt64 < ts, np.bool), bool) - check(assert_type(c_datetimeindex >= ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_datetimeindex < ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_unknown_index >= ts, np_ndarray_bool), np.ndarray, np.bool_) - check(assert_type(c_unknown_index < ts, np_ndarray_bool), np.ndarray, np.bool_) + check(assert_type(c_datetimeindex >= ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_datetimeindex < ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_unknown_index >= ts, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(c_unknown_index < ts, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(c_np_ndarray_dt64 >= ts, np_ndarray_bool), np.ndarray, np.bool_) check(assert_type(c_np_ndarray_dt64 < ts, np_ndarray_bool), np.ndarray, np.bool_) + check( + assert_type(c_np_2darray_dt64 >= ts, np_2darray[np.bool]), np_2darray[np.bool] + ) + check(assert_type(c_np_2darray_dt64 < ts, np_2darray[np.bool]), np_2darray[np.bool]) check(assert_type(c_series_dt64 >= ts, "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(c_series_dt64 < ts, "pd.Series[bool]"), pd.Series, np.bool_) @@ -1370,27 +1377,44 @@ def test_timestamp_cmp() -> None: assert eq != ne eq_arr = check( - assert_type(ts == c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(ts == c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool] ) ne_arr = check( - assert_type(ts != c_datetimeindex, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(ts != c_datetimeindex, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_arr != ne_arr).all() eq_arr = check( - assert_type(ts == c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(ts == c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool] ) ne_arr = check( - assert_type(ts != c_unknown_index, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(ts != c_unknown_index, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_arr != ne_arr).all() - eq_arr = check( - assert_type(ts == c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_ - ) - ne_arr = check( - assert_type(ts != c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_ - ) - assert (eq_arr != ne_arr).all() + if True: # sys.version_info >= (3, 11) or not MYPY: + # tests in this block fail with mypy on Python 3.10 in CI only + # I couldn't reproduce the failure locally so skip mypy on Python 3.10 + eq_arr = check( + assert_type(ts == c_np_ndarray_dt64, np_1darray[np.bool]), + np_1darray[np.bool], + np.bool_, + ) + ne_arr = check( + assert_type(ts != c_np_ndarray_dt64, np_1darray[np.bool]), + np_1darray[np.bool], + np.bool_, + ) + assert (eq_arr != ne_arr).all() + # TODO: the following should be 2D-arrays but it doesn't work in mypy + eq_arr = check( + assert_type(ts == c_np_2darray_dt64, np_1darray[np.bool]), + np_1darray[np.bool], + ) + ne_arr = check( + assert_type(ts != c_np_2darray_dt64, np_1darray[np.bool]), + np_1darray[np.bool], + ) + assert (eq_arr != ne_arr).all() eq_s = check( assert_type(ts == c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_ @@ -1437,17 +1461,17 @@ def test_timestamp_eq_ne_rhs() -> None: assert eq != ne eq_arr = check( - assert_type(c_datetimeindex == ts, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_datetimeindex == ts, np_1darray[np.bool]), np_1darray[np.bool] ) ne_arr = check( - assert_type(c_datetimeindex != ts, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_datetimeindex != ts, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_arr != ne_arr).all() eq_arr = check( - assert_type(c_unknown_index == ts, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_unknown_index == ts, np_1darray[np.bool]), np_1darray[np.bool] ) ne_arr = check( - assert_type(c_unknown_index != ts, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_unknown_index != ts, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_arr != ne_arr).all() @@ -1946,10 +1970,10 @@ def test_period_cmp() -> None: assert eq != ne eq_a = check( - assert_type(c_period_index == p, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_period_index == p, np_1darray[np.bool]), np_1darray[np.bool] ) ne_a = check( - assert_type(c_period_index != p, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_period_index != p, np_1darray[np.bool]), np_1darray[np.bool] ) assert (eq_a != ne_a).all() @@ -1983,9 +2007,11 @@ def test_period_cmp() -> None: le = check(assert_type(c_period <= p, bool), bool) assert gt != le - gt_a = check(assert_type(c_period_index > p, np_ndarray_bool), np.ndarray, np.bool_) + gt_a = check( + assert_type(c_period_index > p, np_1darray[np.bool]), np_1darray[np.bool] + ) le_a = check( - assert_type(c_period_index <= p, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_period_index <= p, np_1darray[np.bool]), np_1darray[np.bool] ) assert (gt_a != le_a).all() @@ -2019,9 +2045,11 @@ def test_period_cmp() -> None: ge = check(assert_type(c_period >= p, bool), bool) assert lt != ge - lt_a = check(assert_type(c_period_index < p, np_ndarray_bool), np.ndarray, np.bool_) + lt_a = check( + assert_type(c_period_index < p, np_1darray[np.bool]), np_1darray[np.bool] + ) ge_a = check( - assert_type(c_period_index >= p, np_ndarray_bool), np.ndarray, np.bool_ + assert_type(c_period_index >= p, np_1darray[np.bool]), np_1darray[np.bool] ) assert (lt_a != ge_a).all() diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 5796ee40d..b501f681c 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -10,7 +10,7 @@ PD_LTE_23, TYPE_CHECKING_INVALID_USAGE, check, - np_ndarray_bool, + np_1darray, ) DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] @@ -71,40 +71,42 @@ def test_string_accessors_boolean_series(): def test_string_accessors_boolean_index(): idx = pd.Index(DATA) - _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) - _check(assert_type(idx.str.startswith("a"), np_ndarray_bool)) + _check = functools.partial(check, klass=np_1darray[np.bool]) + _check(assert_type(idx.str.startswith("a"), np_1darray[np.bool])) _check( - assert_type(idx.str.startswith(("a", "b")), np_ndarray_bool), + assert_type(idx.str.startswith(("a", "b")), np_1darray[np.bool]), ) _check( - assert_type(idx.str.contains("a"), np_ndarray_bool), + assert_type(idx.str.contains("a"), np_1darray[np.bool]), ) if PD_LTE_23: # Bug in pandas 3.0 dev https://github.com/pandas-dev/pandas/issues/61942 _check( assert_type( - idx.str.contains(re.compile(r"a"), regex=True), np_ndarray_bool + idx.str.contains(re.compile(r"a"), regex=True), np_1darray[np.bool] ), ) - _check(assert_type(idx.str.endswith("e"), np_ndarray_bool)) - _check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool)) - _check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool)) + _check(assert_type(idx.str.endswith("e"), np_1darray[np.bool])) + _check(assert_type(idx.str.endswith(("e", "f")), np_1darray[np.bool])) + _check(assert_type(idx.str.fullmatch("apple"), np_1darray[np.bool])) if PD_LTE_23: # Bug in 3.0 dev: https://github.com/pandas-dev/pandas/issues/61952 - _check(assert_type(idx.str.fullmatch(re.compile(r"apple")), np_ndarray_bool)) - _check(assert_type(idx.str.isalnum(), np_ndarray_bool)) - _check(assert_type(idx.str.isalpha(), np_ndarray_bool)) - _check(assert_type(idx.str.isdecimal(), np_ndarray_bool)) - _check(assert_type(idx.str.isdigit(), np_ndarray_bool)) - _check(assert_type(idx.str.isnumeric(), np_ndarray_bool)) - _check(assert_type(idx.str.islower(), np_ndarray_bool)) - _check(assert_type(idx.str.isspace(), np_ndarray_bool)) - _check(assert_type(idx.str.istitle(), np_ndarray_bool)) - _check(assert_type(idx.str.isupper(), np_ndarray_bool)) - _check(assert_type(idx.str.match("pp"), np_ndarray_bool)) + _check( + assert_type(idx.str.fullmatch(re.compile(r"apple")), np_1darray[np.bool]) + ) + _check(assert_type(idx.str.isalnum(), np_1darray[np.bool])) + _check(assert_type(idx.str.isalpha(), np_1darray[np.bool])) + _check(assert_type(idx.str.isdecimal(), np_1darray[np.bool])) + _check(assert_type(idx.str.isdigit(), np_1darray[np.bool])) + _check(assert_type(idx.str.isnumeric(), np_1darray[np.bool])) + _check(assert_type(idx.str.islower(), np_1darray[np.bool])) + _check(assert_type(idx.str.isspace(), np_1darray[np.bool])) + _check(assert_type(idx.str.istitle(), np_1darray[np.bool])) + _check(assert_type(idx.str.isupper(), np_1darray[np.bool])) + _check(assert_type(idx.str.match("pp"), np_1darray[np.bool])) if PD_LTE_23: # Bug in 3.0 dev: https://github.com/pandas-dev/pandas/issues/61952 - _check(assert_type(idx.str.match(re.compile(r"pp")), np_ndarray_bool)) + _check(assert_type(idx.str.match(re.compile(r"pp")), np_1darray[np.bool])) def test_string_accessors_integer_series(): diff --git a/tests/test_timefuncs.py b/tests/test_timefuncs.py index 7e39e4a75..22b480ab7 100644 --- a/tests/test_timefuncs.py +++ b/tests/test_timefuncs.py @@ -18,7 +18,6 @@ WE, ) import numpy as np -from numpy import typing as npt import pandas as pd from pandas.api.typing import NaTType from pandas.core.tools.datetimes import FulldatetimeDict @@ -34,6 +33,7 @@ PD_LTE_23, TYPE_CHECKING_INVALID_USAGE, check, + np_1darray, pytest_warns_bounded, ) @@ -61,8 +61,6 @@ else: Pandas4Warning: TypeAlias = FutureWarning # type: ignore[no-redef] -from tests import np_ndarray_bool - def test_types_init() -> None: check(assert_type(pd.Timestamp("2021-03-01T12"), pd.Timestamp), pd.Timestamp) @@ -350,12 +348,12 @@ def test_comparisons_datetimeindex() -> None: # GH 74 dti = pd.date_range("2000-01-01", "2000-01-10") ts = pd.Timestamp("2000-01-05") - check(assert_type((dti < ts), np_ndarray_bool), np.ndarray) - check(assert_type((dti > ts), np_ndarray_bool), np.ndarray) - check(assert_type((dti >= ts), np_ndarray_bool), np.ndarray) - check(assert_type((dti <= ts), np_ndarray_bool), np.ndarray) - check(assert_type((dti == ts), np_ndarray_bool), np.ndarray) - check(assert_type((dti != ts), np_ndarray_bool), np.ndarray) + check(assert_type((dti < ts), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type((dti > ts), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type((dti >= ts), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type((dti <= ts), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type((dti == ts), np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type((dti != ts), np_1darray[np.bool]), np_1darray[np.bool]) def test_to_datetime_nat() -> None: @@ -427,8 +425,8 @@ def test_series_dt_accessors() -> None: upper="2.3.99", ): check( - assert_type(s0.dt.to_pydatetime(), np.ndarray), - np.ndarray if PD_LTE_23 else pd.Series, + assert_type(s0.dt.to_pydatetime(), np_1darray[np.object_]), + np_1darray[np.object_] if PD_LTE_23 else pd.Series, dt.datetime, ) s0_local = s0.dt.tz_localize("UTC") @@ -560,7 +558,11 @@ def test_series_dt_accessors() -> None: upper="3.0.99", ), ): - check(assert_type(s2.dt.to_pytimedelta(), np.ndarray), np.ndarray) + check( + assert_type(s2.dt.to_pytimedelta(), np_1darray[np.object_]), + np_1darray[np.object_], + dt.timedelta, + ) check(assert_type(s2.dt.total_seconds(), "pd.Series[float]"), pd.Series, float) check(assert_type(s2.dt.unit, TimeUnit), str) check(assert_type(s2.dt.as_unit("s"), "TimedeltaSeries"), pd.Series, pd.Timedelta) @@ -601,9 +603,11 @@ def test_datetimeindex_accessors() -> None: i0 = pd.date_range(start="2022-06-01", periods=10) check(assert_type(i0, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) - check(assert_type(i0.date, np.ndarray), np.ndarray, dt.date) - check(assert_type(i0.time, np.ndarray), np.ndarray, dt.time) - check(assert_type(i0.timetz, np.ndarray), np.ndarray, dt.time) + check(assert_type(i0.date, np_1darray[np.object_]), np_1darray[np.object_], dt.date) + check(assert_type(i0.time, np_1darray[np.object_]), np_1darray[np.object_], dt.time) + check( + assert_type(i0.timetz, np_1darray[np.object_]), np_1darray[np.object_], dt.time + ) check(assert_type(i0.year, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.month, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.day, "pd.Index[int]"), pd.Index, np.int32) @@ -618,13 +622,13 @@ def test_datetimeindex_accessors() -> None: check(assert_type(i0.dayofyear, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.day_of_year, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.quarter, "pd.Index[int]"), pd.Index, np.int32) - check(assert_type(i0.is_month_start, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_month_end, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_quarter_start, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_quarter_end, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_year_start, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_year_end, npt.NDArray[np.bool_]), np.ndarray, np.bool_) - check(assert_type(i0.is_leap_year, npt.NDArray[np.bool_]), np.ndarray, np.bool_) + check(assert_type(i0.is_month_start, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_month_end, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_quarter_start, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_quarter_end, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_year_start, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_year_end, np_1darray[np.bool]), np_1darray[np.bool]) + check(assert_type(i0.is_leap_year, np_1darray[np.bool]), np_1darray[np.bool]) check(assert_type(i0.daysinmonth, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.days_in_month, "pd.Index[int]"), pd.Index, np.int32) check(assert_type(i0.tz, Optional[dt.tzinfo]), type(None)) @@ -632,8 +636,8 @@ def test_datetimeindex_accessors() -> None: check(assert_type(i0.isocalendar(), pd.DataFrame), pd.DataFrame) check(assert_type(i0.to_period("D"), pd.PeriodIndex), pd.PeriodIndex, pd.Period) check( - assert_type(i0.to_pydatetime(), npt.NDArray[np.object_]), - np.ndarray, + assert_type(i0.to_pydatetime(), np_1darray[np.object_]), + np_1darray[np.object_], dt.datetime, ) ilocal = i0.tz_localize("UTC") @@ -674,7 +678,11 @@ def test_timedeltaindex_accessors() -> None: check(assert_type(i0.microseconds, pd.Index), pd.Index, np.integer) check(assert_type(i0.nanoseconds, pd.Index), pd.Index, np.integer) check(assert_type(i0.components, pd.DataFrame), pd.DataFrame) - check(assert_type(i0.to_pytimedelta(), np.ndarray), np.ndarray) + check( + assert_type(i0.to_pytimedelta(), np_1darray[np.object_]), + np_1darray[np.object_], + dt.timedelta, + ) check(assert_type(i0.total_seconds(), pd.Index), pd.Index, float) check( assert_type(i0.round("D"), pd.TimedeltaIndex), pd.TimedeltaIndex, pd.Timedelta @@ -877,9 +885,9 @@ def test_timestampseries_offset() -> None: def test_types_to_numpy() -> None: td_s = pd.to_timedelta(pd.Series([10, 20]), "minutes") - check(assert_type(td_s.to_numpy(), np.ndarray), np.ndarray) - check(assert_type(td_s.to_numpy(dtype="int", copy=True), np.ndarray), np.ndarray) - check(assert_type(td_s.to_numpy(na_value=pd.Timedelta(0)), np.ndarray), np.ndarray) + check(assert_type(td_s.to_numpy(), np_1darray), np_1darray) + check(assert_type(td_s.to_numpy(dtype="int", copy=True), np_1darray), np_1darray) + check(assert_type(td_s.to_numpy(na_value=pd.Timedelta(0)), np_1darray), np_1darray) def test_to_timedelta_units() -> None: