Skip to content

Commit b44272f

Browse files
authored
Make type annotations for NumPy arrays more specific (#1358)
* Make type annotations for NumPy arrays more specific * Add np.generic to nd.dtype annotations * Use Any instead of np.generic
1 parent 515802c commit b44272f

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

rustworkx/__init__.pyi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import sys
1313
import numpy as np
14+
import numpy.typing as npt
1415

1516
from typing import Generic, Any, Callable, overload
1617
from collections.abc import Iterator, Sequence
@@ -289,7 +290,7 @@ def distance_matrix(
289290
parallel_threshold: int = ...,
290291
as_undirected: bool = ...,
291292
null_value: float = ...,
292-
) -> np.ndarray: ...
293+
) -> npt.NDArray[np.float64]: ...
293294
def unweighted_average_shortest_path_length(
294295
graph: PyGraph | PyDiGraph,
295296
parallel_threshold: int = ...,
@@ -300,7 +301,7 @@ def adjacency_matrix(
300301
weight_fn: Callable[[_T], float] | None = ...,
301302
default_weight: float = ...,
302303
null_value: float = ...,
303-
) -> np.ndarray: ...
304+
) -> npt.NDArray[np.float64]: ...
304305
def all_simple_paths(
305306
graph: PyGraph | PyDiGraph,
306307
from_: int,
@@ -319,13 +320,13 @@ def floyd_warshall_numpy(
319320
weight_fn: Callable[[_T], float] | None = ...,
320321
default_weight: float = ...,
321322
parallel_threshold: int = ...,
322-
) -> np.ndarray: ...
323+
) -> npt.NDArray[np.float64]: ...
323324
def floyd_warshall_successor_and_distance(
324325
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
325326
weight_fn: Callable[[_T], float] | None = ...,
326327
default_weight: float | None = ...,
327328
parallel_threshold: int | None = ...,
328-
) -> tuple[np.ndarray, np.ndarray]: ...
329+
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
329330
def astar_shortest_path(
330331
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
331332
node: int,

rustworkx/rustworkx.pyi

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ from rustworkx import generators # noqa
3535
from typing_extensions import Self
3636

3737
import numpy as np
38+
import numpy.typing as npt
3839
import sys
3940

4041
if sys.version_info >= (3, 13):
@@ -206,15 +207,15 @@ def digraph_adjacency_matrix(
206207
default_weight: float = ...,
207208
null_value: float = ...,
208209
parallel_edge: str = ...,
209-
) -> np.ndarray: ...
210+
) -> npt.NDArray[np.float64]: ...
210211
def graph_adjacency_matrix(
211212
graph: PyGraph[_S, _T],
212213
/,
213214
weight_fn: Callable[[_T], float] | None = ...,
214215
default_weight: float = ...,
215216
null_value: float = ...,
216217
parallel_edge: str = ...,
217-
) -> np.ndarray: ...
218+
) -> npt.NDArray[np.float64]: ...
218219
def cycle_basis(graph: PyGraph, /, root: int | None = ...) -> list[list[int]]: ...
219220
def articulation_points(graph: PyGraph, /) -> set[int]: ...
220221
def bridges(graph: PyGraph, /) -> set[tuple[int]]: ...
@@ -595,14 +596,14 @@ def undirected_gnp_random_graph(
595596
) -> PyGraph: ...
596597
def directed_sbm_random_graph(
597598
sizes: list[int],
598-
probabilities: np.ndarray,
599+
probabilities: npt.NDArray[np.float64],
599600
loops: bool,
600601
/,
601602
seed: int | None = ...,
602603
) -> PyDiGraph: ...
603604
def undirected_sbm_random_graph(
604605
sizes: list[int],
605-
probabilities: np.ndarray,
606+
probabilities: npt.NDArray[np.float64],
606607
loops: bool,
607608
/,
608609
seed: int | None = ...,
@@ -863,13 +864,13 @@ def digraph_distance_matrix(
863864
parallel_threshold: int | None = ...,
864865
as_undirected: bool | None = ...,
865866
null_value: float | None = ...,
866-
) -> np.ndarray: ...
867+
) -> npt.NDArray[np.float64]: ...
867868
def graph_distance_matrix(
868869
graph: PyGraph,
869870
/,
870871
parallel_threshold: int | None = ...,
871872
null_value: float | None = ...,
872-
) -> np.ndarray: ...
873+
) -> npt.NDArray[np.float64]: ...
873874
def digraph_floyd_warshall(
874875
graph: PyDiGraph[_S, _T],
875876
/,
@@ -892,29 +893,29 @@ def digraph_floyd_warshall_numpy(
892893
as_undirected: bool | None = ...,
893894
default_weight: float | None = ...,
894895
parallel_threshold: int | None = ...,
895-
) -> np.ndarray: ...
896+
) -> npt.NDArray[np.float64]: ...
896897
def graph_floyd_warshall_numpy(
897898
graph: PyGraph[_S, _T],
898899
/,
899900
weight_fn: Callable[[_T], float] | None = ...,
900901
default_weight: float | None = ...,
901902
parallel_threshold: int | None = ...,
902-
) -> np.ndarray: ...
903+
) -> npt.NDArray[np.float64]: ...
903904
def digraph_floyd_warshall_successor_and_distance(
904905
graph: PyDiGraph[_S, _T],
905906
/,
906907
weight_fn: Callable[[_T], float] | None = ...,
907908
as_undirected: bool | None = ...,
908909
default_weight: float | None = ...,
909910
parallel_threshold: int | None = ...,
910-
) -> tuple[np.ndarray, np.ndarray]: ...
911+
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
911912
def graph_floyd_warshall_successor_and_distance(
912913
graph: PyGraph[_S, _T],
913914
/,
914915
weight_fn: Callable[[_T], float] | None = ...,
915916
default_weight: float | None = ...,
916917
parallel_threshold: int | None = ...,
917-
) -> tuple[np.ndarray, np.ndarray]: ...
918+
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ...
918919
def find_negative_cycle(
919920
graph: PyDiGraph[_S, _T],
920921
edge_cost_fn: Callable[[_T], float],
@@ -1079,7 +1080,9 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
10791080
def __len__(self) -> int: ...
10801081
def __ne__(self, other: object) -> bool: ...
10811082
def __setstate__(self, state: Sequence[_T_co]) -> None: ...
1082-
def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ...
1083+
def __array__(
1084+
self, dtype: np.dtype[Any] | None = ..., copy: bool | None = ...
1085+
) -> npt.NDArray[Any]: ...
10831086
def __iter__(self) -> Iterator[_T_co]: ...
10841087
def __reversed__(self) -> Iterator[_T_co]: ...
10851088

@@ -1235,11 +1238,11 @@ class PyGraph(Generic[_S, _T]):
12351238
) -> int | None: ...
12361239
@staticmethod
12371240
def from_adjacency_matrix(
1238-
matrix: np.ndarray, /, null_value: float = ...
1241+
matrix: npt.NDArray[np.float64], /, null_value: float = ...
12391242
) -> PyGraph[int, float]: ...
12401243
@staticmethod
12411244
def from_complex_adjacency_matrix(
1242-
matrix: np.ndarray, /, null_value: complex = ...
1245+
matrix: npt.NDArray[np.complex64], /, null_value: complex = ...
12431246
) -> PyGraph[int, complex]: ...
12441247
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[_T]: ...
12451248
def get_edge_data(self, node_a: int, node_b: int, /) -> _T: ...
@@ -1400,11 +1403,11 @@ class PyDiGraph(Generic[_S, _T]):
14001403
) -> list[_S]: ...
14011404
@staticmethod
14021405
def from_adjacency_matrix(
1403-
matrix: np.ndarray, /, null_value: float = ...
1406+
matrix: npt.NDArray[np.float64], /, null_value: float = ...
14041407
) -> PyDiGraph[int, float]: ...
14051408
@staticmethod
14061409
def from_complex_adjacency_matrix(
1407-
matrix: np.ndarray, /, null_value: complex = ...
1410+
matrix: npt.NDArray[np.complex64], /, null_value: complex = ...
14081411
) -> PyDiGraph[int, complex]: ...
14091412
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[_T]: ...
14101413
def get_edge_data(self, node_a: int, node_b: int, /) -> _T: ...

0 commit comments

Comments
 (0)