Skip to content

Commit 4012039

Browse files
committed
Added finfo_object subclass to np.finfo
- Improves array API conformity
1 parent d7c2e3b commit 4012039

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,29 @@
3131
)
3232

3333

34+
class finfo_object(np.finfo):
35+
"""
36+
numpy.finfo subclass which returns Python floating-point scalars for
37+
eps, max, min, and smallest_normal.
38+
"""
39+
40+
def __init__(self, dtype):
41+
_supported_dtype([dpt.dtype(dtype)])
42+
super().__init__()
43+
44+
self.eps = float(self.eps)
45+
self.max = float(self.max)
46+
self.min = float(self.min)
47+
48+
@property
49+
def smallest_normal(self):
50+
return float(super().smallest_normal)
51+
52+
@property
53+
def tiny(self):
54+
return float(super().tiny)
55+
56+
3457
def _broadcast_strides(X_shape, X_strides, res_ndim):
3558
"""
3659
Broadcasts strides to match the given dimensions;
@@ -495,8 +518,7 @@ def finfo(dtype):
495518
"""
496519
if isinstance(dtype, dpt.usm_ndarray):
497520
raise TypeError("Expected dtype type, got {to}.")
498-
_supported_dtype([dpt.dtype(dtype)])
499-
return np.finfo(dtype)
521+
return finfo_object(dtype)
500522

501523

502524
def _supported_dtype(dtypes):

0 commit comments

Comments
 (0)