From cf0f0db9ee7478e44d1acf01992351a0f6e3e329 Mon Sep 17 00:00:00 2001 From: GUAN MING Date: Sun, 9 Mar 2025 01:23:50 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=A9=B9=20add=20overload=20for=20`ndar?= =?UTF-8?q?ray.=5F=5Fmatmul=5F=5F`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/__init__.pyi | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/numpy-stubs/__init__.pyi b/src/numpy-stubs/__init__.pyi index 08922b3b..c562e0e1 100644 --- a/src/numpy-stubs/__init__.pyi +++ b/src/numpy-stubs/__init__.pyi @@ -27,6 +27,8 @@ from typing import ( ) from typing_extensions import Buffer, CapsuleType, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, override +import numpy as np + from . import ( __config__ as __config__, _array_api_info as _array_api_info, @@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None) _DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True) _TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit) +_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]] + ### # Type Aliases (for internal use only) @@ -2530,9 +2534,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): def __imul__(self: NDArray[complexfloating], rhs: _ArrayLikeComplex_co, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ... @overload def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ... - - # TODO(jorenham): Support the "1d @ 1d -> scalar" case - # https://github.com/numpy/numtype/issues/197 + @overload + def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ... @overload def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ... @overload @@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ... @overload - def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ... + def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... @overload def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... # keep in sync with __matmul__ @overload + def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ... + @overload def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ... @overload def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ... @@ -2604,7 +2609,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ... @overload - def __rmatmul__(self: NDArray[object_], lhs: object, /) -> NDArray[object_]: ... + def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... @overload def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... From bec757d001327fc6a812fd915a1bcd527fdfecf9 Mon Sep 17 00:00:00 2001 From: GUAN MING Date: Sun, 9 Mar 2025 14:42:42 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=A9=B9=20remove=20numpy=20and=20add?= =?UTF-8?q?=20=5FMatmulScalarT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/__init__.pyi | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/numpy-stubs/__init__.pyi b/src/numpy-stubs/__init__.pyi index c562e0e1..4524fb0c 100644 --- a/src/numpy-stubs/__init__.pyi +++ b/src/numpy-stubs/__init__.pyi @@ -27,8 +27,6 @@ from typing import ( ) from typing_extensions import Buffer, CapsuleType, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, override -import numpy as np - from . import ( __config__ as __config__, _array_api_info as _array_api_info, @@ -590,6 +588,7 @@ _IntegerT = TypeVar("_IntegerT", bound=integer) _SignedIntegerT = TypeVar("_SignedIntegerT", bound=signedinteger) _UnsignedIntegerT = TypeVar("_UnsignedIntegerT", bound=unsignedinteger) _CharT = TypeVar("_CharT", bound=character) +_IntegralT = TypeVar("_IntegralT", bound=bool_ | number | object_) _NBitT = TypeVar("_NBitT", bound=NBitBase, default=Any) _NBitT1 = TypeVar("_NBitT1", bound=NBitBase, default=Any) @@ -613,7 +612,7 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None) _DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True) _TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit) -_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]] +_Array1D: TypeAlias = ndarray[tuple[int], dtype[_ScalarT]] ### # Type Aliases (for internal use only) @@ -2534,8 +2533,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): def __imul__(self: NDArray[complexfloating], rhs: _ArrayLikeComplex_co, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ... @overload def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ... + + # @overload - def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ... + def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ... @overload def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ... @overload @@ -2569,14 +2570,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ... @overload - def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... + def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ... @overload def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... # keep in sync with __matmul__ @overload - def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ... - @overload def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ... @overload def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ... @@ -2609,7 +2608,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ... @overload - def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... + def __rmatmul__(self: NDArray[object_], lhs: object, /) -> NDArray[object_]: ... @overload def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ... From 3bb2d1ac33346cfe58530d9dcb9ab87ed7c8e05b Mon Sep 17 00:00:00 2001 From: Wesley Chiu Date: Mon, 10 Mar 2025 11:23:10 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=94=A8=20add=20type=20ignore=20commen?= =?UTF-8?q?ts=20for=20overlapping=20overloads=20in=20`ndarray.=5F=5Fmatmul?= =?UTF-8?q?=5F=5F`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/numpy-stubs/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numpy-stubs/__init__.pyi b/src/numpy-stubs/__init__.pyi index 4524fb0c..07a50741 100644 --- a/src/numpy-stubs/__init__.pyi +++ b/src/numpy-stubs/__init__.pyi @@ -2536,7 +2536,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): # @overload - def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ... + def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ... @overload