Skip to content

Commit f2c3fc8

Browse files
committed
Address Joren's comments
1 parent ae20262 commit f2c3fc8

File tree

8 files changed

+90
-32
lines changed

8 files changed

+90
-32
lines changed

src/array_api_typing/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from . import signature_types
1515
from ._array import Array
16-
from ._device import Device
17-
from ._dtype import DType
16+
from ._misc_objects import Device, DType
1817
from ._namespace import ArrayNamespace, HasArrayNamespace
1918
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
"""Static typing support for the array API standard."""
1+
"""Static typing support for array API arrays."""
22

33
from typing import Protocol
44

55
from ._namespace import HasArrayNamespace
66

77

88
class Array(HasArrayNamespace, Protocol):
9-
pass
9+
"""An Array API array of homogenously-typed numbers."""
10+
11+
# TODO(https://github.com/data-apis/array-api-typing/issues/23): Populate this
12+
# protocol with methods defined by the Array API specification.

src/array_api_typing/_device.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/array_api_typing/_dtype.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/array_api_typing/_misc_objects.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Static typing support for miscellaneous objects in the array API."""
2+
3+
Device = object # The device on which an Array API array is stored.
4+
DType = object # The type of the numbers contained in an Array API array."""

src/array_api_typing/_namespace.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1-
"""Static typing support for the array API standard."""
1+
"""Static typing support for array API namespaces."""
22

33
from __future__ import annotations
44

55
from typing import TYPE_CHECKING, Protocol
66
from typing_extensions import TypeVar
77

88
if TYPE_CHECKING:
9+
# This condition exists to prevent a circular import: _array imports _namespace for
10+
# HasArrayNamespace. Therefore, _namespace cannot import _array except when
11+
# type-checking. The type variable depends on Array, so we create a dummy type
12+
# variable without the same bounds and default for this case. In Python 3.13, this
13+
# is no longer be necessary.
14+
from collections.abc import Buffer
15+
916
from ._array import Array
10-
from ._device import Device
11-
from ._dtype import DType
12-
from .signature_types import NestedSequence, SupportsBufferProtocol
17+
from ._misc_objects import Device, DType
18+
from .signature_types import NestedSequence
1319

1420
A = TypeVar("A", bound=Array, default=Array) # PEP 696 default
1521
else:
@@ -21,13 +27,12 @@ class ArrayNamespace(Protocol[A]):
2127

2228
def asarray(
2329
self,
24-
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
30+
obj: Array | complex | NestedSequence[complex] | Buffer,
2531
/,
2632
*,
2733
dtype: DType | None = None,
2834
device: Device | None = None,
2935
copy: bool | None = None,
30-
**kwargs: object,
3136
) -> A: ...
3237

3338
def astype(

src/array_api_typing/signature_types/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
__all__ = [
44
"NestedSequence",
5-
"SupportsBufferProtocol",
65
]
76

8-
from ._signature_types import NestedSequence, SupportsBufferProtocol
7+
from ._signature_types import NestedSequence
Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,77 @@
11
from __future__ import annotations
22

3-
from typing import Any, Protocol, TypeAlias, TypeVar
3+
from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Iterator
47

58
_T_co = TypeVar("_T_co", covariant=True)
69

710

11+
@runtime_checkable
812
class NestedSequence(Protocol[_T_co]):
9-
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
10-
def __len__(self, /) -> int: ...
13+
"""A protocol for representing nested sequences.
14+
15+
Warning:
16+
-------
17+
`NestedSequence` currently does not work in combination with type variables,
18+
*e.g.* ``def func(a: NestedSequnce[T]) -> T: ...``.
19+
20+
See Also:
21+
--------
22+
collections.abc.Sequence:
23+
ABCs for read-only and mutable :term:`sequences`.
24+
25+
Examples:
26+
--------
27+
.. code-block:: python
28+
29+
>>> from typing import TYPE_CHECKING
30+
>>> import numpy as np
31+
>>> import array_api_typing as xpt
32+
33+
>>> def get_dtype(seq: xpt.NestedSequence[float]) -> np.dtype[np.float64]:
34+
... return np.asarray(seq).dtype
35+
36+
>>> a = get_dtype([1.0])
37+
>>> b = get_dtype([[1.0]])
38+
>>> c = get_dtype([[[1.0]]])
39+
>>> d = get_dtype([[[[1.0]]]])
40+
41+
>>> if TYPE_CHECKING:
42+
... reveal_locals()
43+
... # note: Revealed local types are:
44+
... # note: a: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
45+
... # note: b: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
46+
... # note: c: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
47+
... # note: d: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
48+
49+
"""
50+
51+
def __len__(self, /) -> int:
52+
"""Implement ``len(self)``."""
53+
raise NotImplementedError
54+
55+
def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]:
56+
"""Implement ``self[x]``."""
57+
raise NotImplementedError
58+
59+
def __contains__(self, x: object, /) -> bool:
60+
"""Implement ``x in self``."""
61+
raise NotImplementedError
62+
63+
def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
64+
"""Implement ``iter(self)``."""
65+
raise NotImplementedError
66+
67+
def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
68+
"""Implement ``reversed(self)``."""
69+
raise NotImplementedError
1170

71+
def count(self, value: object, /) -> int:
72+
"""Return the number of occurrences of `value`."""
73+
raise NotImplementedError
1274

13-
SupportsBufferProtocol: TypeAlias = Any
75+
def index(self, value: object, /) -> int:
76+
"""Return the first index of `value`."""
77+
raise NotImplementedError

0 commit comments

Comments
 (0)