Skip to content

Commit 01b6cc1

Browse files
Added array-API dpctl.tensor.linspace
Added tests
1 parent b7b9bcf commit 01b6cc1

File tree

3 files changed

+135
-17
lines changed

3 files changed

+135
-17
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
empty_like,
3030
full,
3131
full_like,
32+
linspace,
3233
ones,
3334
ones_like,
3435
zeros,
@@ -61,6 +62,7 @@
6162
"zeros",
6263
"ones",
6364
"full",
65+
"linspace",
6466
"empty_like",
6567
"zeros_like",
6668
"ones_like",

dpctl/tensor/_ctors.py

Lines changed: 105 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import operator
18+
1719
import numpy as np
1820

1921
import dpctl
@@ -196,9 +198,12 @@ def _asarray_from_numpy_ndarray(
196198
raise TypeError(f"Expected numpy.ndarray, got {type(ary)}")
197199
if usm_type is None:
198200
usm_type = "device"
199-
if dtype is None:
200-
dtype = _get_dtype(dtype, sycl_queue, ref_type=ary.dtype)
201201
copy_q = normalize_queue_device(sycl_queue=None, device=sycl_queue)
202+
if dtype is None:
203+
ary_dtype = ary.dtype
204+
dtype = _get_dtype(dtype, copy_q, ref_type=ary_dtype)
205+
if dtype.itemsize > ary_dtype.itemsize:
206+
dtype = ary_dtype
202207
f_contig = ary.flags["F"]
203208
c_contig = ary.flags["C"]
204209
fc_contig = f_contig or c_contig
@@ -292,7 +297,7 @@ def asarray(
292297
for output array allocation and copying. `sycl_queue` and `device`
293298
are exclusive keywords, i.e. use one or another. If both are
294299
specified, a `TypeError` is raised unless both imply the same
295-
underlying SYCL queue to be used. If both a `None`, the
300+
underlying SYCL queue to be used. If both are `None`, the
296301
`dpctl.SyclQueue()` is used for allocation and copying.
297302
Default: `None`.
298303
"""
@@ -430,7 +435,7 @@ def empty(
430435
for output array allocation and copying. `sycl_queue` and `device`
431436
are exclusive keywords, i.e. use one or another. If both are
432437
specified, a `TypeError` is raised unless both imply the same
433-
underlying SYCL queue to be used. If both a `None`, the
438+
underlying SYCL queue to be used. If both are `None`, the
434439
`dpctl.SyclQueue()` is used for allocation and copying.
435440
Default: `None`.
436441
"""
@@ -453,18 +458,20 @@ def empty(
453458
return res
454459

455460

456-
def _coerce_and_infer_dt(*args, dt, sycl_queue):
461+
def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
457462
"Deduce arange type from sequence spec"
458463
nd, seq_dt, d = _array_info_sequence(args)
459464
if d != _host_set or nd != (len(args),):
460-
raise ValueError("start, stop and step must be Python scalars")
465+
raise ValueError(err_msg)
461466
dt = _get_dtype(dt, sycl_queue, ref_type=seq_dt)
462467
if np.issubdtype(dt, np.integer):
463468
return tuple(int(v) for v in args), dt
464469
elif np.issubdtype(dt, np.floating):
465470
return tuple(float(v) for v in args), dt
466471
elif np.issubdtype(dt, np.complexfloating):
467472
return tuple(complex(v) for v in args), dt
473+
elif allow_bool and dt.char == "?":
474+
return tuple(bool(v) for v in args), dt
468475
else:
469476
raise ValueError(f"Data type {dt} is not supported")
470477

@@ -517,7 +524,7 @@ def arange(
517524
for output array allocation and copying. `sycl_queue` and `device`
518525
are exclusive keywords, i.e. use one or another. If both are
519526
specified, a `TypeError` is raised unless both imply the same
520-
underlying SYCL queue to be used. If both a `None`, the
527+
underlying SYCL queue to be used. If both are `None`, the
521528
`dpctl.SyclQueue()` is used for allocation and copying.
522529
Default: `None`.
523530
"""
@@ -526,12 +533,14 @@ def arange(
526533
start = 0
527534
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
528535
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
529-
(
536+
(start, stop, step,), dt = _coerce_and_infer_dt(
530537
start,
531538
stop,
532539
step,
533-
), dt = _coerce_and_infer_dt(
534-
start, stop, step, dt=dtype, sycl_queue=sycl_queue
540+
dt=dtype,
541+
sycl_queue=sycl_queue,
542+
err_msg="start, stop, and step must be Python scalars",
543+
allow_bool=False,
535544
)
536545
try:
537546
tmp = _get_arange_length(start, stop, step)
@@ -579,7 +588,7 @@ def zeros(
579588
for output array allocation and copying. `sycl_queue` and `device`
580589
are exclusive keywords, i.e. use one or another. If both are
581590
specified, a `TypeError` is raised unless both imply the same
582-
underlying SYCL queue to be used. If both a `None`, the
591+
underlying SYCL queue to be used. If both are `None`, the
583592
`dpctl.SyclQueue()` is used for allocation and copying.
584593
Default: `None`.
585594
"""
@@ -627,7 +636,7 @@ def ones(
627636
for output array allocation and copying. `sycl_queue` and `device`
628637
are exclusive keywords, i.e. use one or another. If both are
629638
specified, a `TypeError` is raised unless both imply the same
630-
underlying SYCL queue to be used. If both a `None`, the
639+
underlying SYCL queue to be used. If both are `None`, the
631640
`dpctl.SyclQueue()` is used for allocation and copying.
632641
Default: `None`.
633642
"""
@@ -683,7 +692,7 @@ def full(
683692
for output array allocation and copying. `sycl_queue` and `device`
684693
are exclusive keywords, i.e. use one or another. If both are
685694
specified, a `TypeError` is raised unless both imply the same
686-
underlying SYCL queue to be used. If both a `None`, the
695+
underlying SYCL queue to be used. If both are `None`, the
687696
`dpctl.SyclQueue()` is used for allocation and copying.
688697
Default: `None`.
689698
"""
@@ -733,7 +742,7 @@ def empty_like(
733742
for output array allocation and copying. `sycl_queue` and `device`
734743
are exclusive keywords, i.e. use one or another. If both are
735744
specified, a `TypeError` is raised unless both imply the same
736-
underlying SYCL queue to be used. If both a `None`, the
745+
underlying SYCL queue to be used. If both are `None`, the
737746
`dpctl.SyclQueue()` is used for allocation and copying.
738747
Default: `None`.
739748
"""
@@ -790,7 +799,7 @@ def zeros_like(
790799
for output array allocation and copying. `sycl_queue` and `device`
791800
are exclusive keywords, i.e. use one or another. If both are
792801
specified, a `TypeError` is raised unless both imply the same
793-
underlying SYCL queue to be used. If both a `None`, the
802+
underlying SYCL queue to be used. If both are `None`, the
794803
`dpctl.SyclQueue()` is used for allocation and copying.
795804
Default: `None`.
796805
"""
@@ -847,7 +856,7 @@ def ones_like(
847856
for output array allocation and copying. `sycl_queue` and `device`
848857
are exclusive keywords, i.e. use one or another. If both are
849858
specified, a `TypeError` is raised unless both imply the same
850-
underlying SYCL queue to be used. If both a `None`, the
859+
underlying SYCL queue to be used. If both are `None`, the
851860
`dpctl.SyclQueue()` is used for allocation and copying.
852861
Default: `None`.
853862
"""
@@ -911,7 +920,7 @@ def full_like(
911920
for output array allocation and copying. `sycl_queue` and `device`
912921
are exclusive keywords, i.e. use one or another. If both are
913922
specified, a `TypeError` is raised unless both imply the same
914-
underlying SYCL queue to be used. If both a `None`, the
923+
underlying SYCL queue to be used. If both are `None`, the
915924
`dpctl.SyclQueue()` is used for allocation and copying.
916925
Default: `None`.
917926
"""
@@ -942,3 +951,82 @@ def full_like(
942951
usm_type=usm_type,
943952
sycl_queue=sycl_queue,
944953
)
954+
955+
956+
def linspace(
957+
start,
958+
stop,
959+
/,
960+
num,
961+
*,
962+
dtype=None,
963+
device=None,
964+
endpoint=True,
965+
sycl_queue=None,
966+
usm_type="device",
967+
):
968+
"""
969+
linspace(start, stop, num, dtype=None, device=None, endpoint=True,
970+
sycl_queue=None, usm_type=None): usm_ndarray
971+
972+
Returns evenly spaced numbers of specified interval.
973+
974+
Args:
975+
start: the start of the interval.
976+
stop: the end of the interval. If the `endpoint` is `False`, the
977+
function must generate `num+1` evenly spaced points starting
978+
with `start` and ending with `stop` and exclude the `stop`
979+
from the returned array such that the returned array consists
980+
of evenly spaced numbers over the half-open interval
981+
`[start, stop)`. If `endpoint` is `True`, the output
982+
array must consist of evenly spaced numbers over the closed
983+
interval `[start, stop]`. Default: `True`.
984+
num: number of samples. Must be a non-negative integer; otherwise,
985+
the function must raise an exception.
986+
dtype: output array data type. Should be a floating data type.
987+
If `dtype` is `None`, the output array must be the default
988+
floating point data type. Default: `None`.
989+
device (optional): array API concept of device where the output array
990+
is created. `device` can be `None`, a oneAPI filter selector string,
991+
an instance of :class:`dpctl.SyclDevice` corresponding to a
992+
non-partitioned SYCL device, an instance of
993+
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
994+
`dpctl.tensor.usm_array.device`. Default: `None`.
995+
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
996+
allocation for the output array. Default: `"device"`.
997+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
998+
for output array allocation and copying. `sycl_queue` and `device`
999+
are exclusive keywords, i.e. use one or another. If both are
1000+
specified, a `TypeError` is raised unless both imply the same
1001+
underlying SYCL queue to be used. If both are `None`, the
1002+
`dpctl.SyclQueue()` is used for allocation and copying.
1003+
Default: `None`.
1004+
endpoint: boolean indicating whether to include `stop` in the
1005+
interval. Default: `True`.
1006+
"""
1007+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1008+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1009+
if endpoint not in [True, False]:
1010+
raise TypeError("endpoint keyword argument must be of boolean type")
1011+
num = operator.index(num)
1012+
if num < 0:
1013+
raise ValueError("Number of points must be non-negative")
1014+
((start, stop,), dt) = _coerce_and_infer_dt(
1015+
start,
1016+
stop,
1017+
dt=dtype,
1018+
sycl_queue=sycl_queue,
1019+
err_msg="start and stop must be Python scalars.",
1020+
allow_bool=True,
1021+
)
1022+
if dtype is None and np.issubdtype(dt, np.integer):
1023+
dt = ti.default_device_fp_type(sycl_queue)
1024+
dt = np.dtype(dt)
1025+
start = float(start)
1026+
stop = float(stop)
1027+
res = dpt.empty(num, dtype=dt, sycl_queue=sycl_queue)
1028+
hev, _ = ti._linspace_affine(
1029+
start, stop, dst=res, include_endpoint=endpoint, sycl_queue=sycl_queue
1030+
)
1031+
hev.wait()
1032+
return res

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,34 @@ def test_arange_fp():
10711071
assert dpt.arange(0, 1, 0.25, dtype="f4", device=q).shape == (4,)
10721072

10731073

1074+
@pytest.mark.parametrize(
1075+
"dt",
1076+
_all_dtypes,
1077+
)
1078+
def test_linspace(dt):
1079+
try:
1080+
q = dpctl.SyclQueue()
1081+
except dpctl.SyclQueueCreationError:
1082+
pytest.skip("Default queue could not be created")
1083+
X = dpt.linspace(0, 1, num=2, dtype=dt, sycl_queue=q)
1084+
assert np.allclose(dpt.asnumpy(X), np.linspace(0, 1, num=2, dtype=dt))
1085+
1086+
1087+
def test_linspace_fp():
1088+
try:
1089+
q = dpctl.SyclQueue()
1090+
except dpctl.SyclQueueCreationError:
1091+
pytest.skip("Default queue could not be created")
1092+
n = 16
1093+
X = dpt.linspace(0, n - 1, num=n, sycl_queue=q)
1094+
if q.sycl_device.has_aspect_fp64:
1095+
assert X.dtype == np.dtype("float64")
1096+
else:
1097+
assert X.dtype == np.dtype("float32")
1098+
assert X.shape == (n,)
1099+
assert X.strides == (1,)
1100+
1101+
10741102
@pytest.mark.parametrize(
10751103
"dt",
10761104
_all_dtypes,

0 commit comments

Comments
 (0)