Skip to content

Commit 9700b94

Browse files
Fixed several issues with tensor.arange found during integration with dpnp
dpt.arange(0, stop=10, step=None) raised, works in numpy dpt.arange(9.7, stop=10) gave empty array, gives 1 element array in numpy dpt.arange(0,stop=2, dtype='bool') raised, works in numpy First two were just bugs, and got fixed. The last one now works through special-casing bools. It works by constructing int8 temporary and casting it into bool array only if the resulting sequence has length 0, 1, or 2. Aligned with behavior of np.arange in computation of the step. To this end changed the logic of determining step argument for the call to `_linspace_step` routine. We now compute first and second element of the array of given type, and determine the step as a the difference of these. To avoid possible overflow message when subtracting unsigned integers, cast first and second element to int64, subtract, and cast to the target type.
1 parent 59980a2 commit 9700b94

File tree

1 file changed

+37
-9
lines changed

1 file changed

+37
-9
lines changed

dpctl/tensor/_ctors.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,21 +483,22 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
483483

484484
def _round_for_arange(tmp):
485485
k = int(tmp)
486-
if k > 0 and float(k) < tmp:
486+
if k >= 0 and float(k) < tmp:
487487
tmp = tmp + 1
488488
return tmp
489489

490490

491491
def _get_arange_length(start, stop, step):
492492
"Compute length of arange sequence"
493493
span = stop - start
494-
if type(step) in [int, float] and type(span) in [int, float]:
494+
if hasattr(step, "__float__") and hasattr(span, "__float__"):
495495
return _round_for_arange(span / step)
496496
tmp = span / step
497-
if type(tmp) is complex and tmp.imag == 0:
497+
if hasattr(tmp, "__complex__"):
498+
tmp = complex(tmp)
498499
tmp = tmp.real
499500
else:
500-
return tmp
501+
tmp = float(tmp)
501502
return _round_for_arange(tmp)
502503

503504

@@ -536,13 +537,18 @@ def arange(
536537
if stop is None:
537538
stop = start
538539
start = 0
540+
if step is None:
541+
step = 1
539542
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
540543
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
541-
(start, stop, step,), dt = _coerce_and_infer_dt(
544+
is_bool = False
545+
if dtype:
546+
is_bool = (dtype is bool) or (dpt.dtype(dtype) == dpt.bool)
547+
(start_, stop_, step_), dt = _coerce_and_infer_dt(
542548
start,
543549
stop,
544550
step,
545-
dt=dtype,
551+
dt=dpt.int8 if is_bool else dtype,
546552
sycl_queue=sycl_queue,
547553
err_msg="start, stop, and step must be Python scalars",
548554
allow_bool=False,
@@ -554,18 +560,40 @@ def arange(
554560
sh = 0
555561
except TypeError:
556562
sh = 0
563+
if is_bool and sh > 2:
564+
raise ValueError("no fill-function for boolean data type")
557565
res = dpt.usm_ndarray(
558566
(sh,),
559567
dtype=dt,
560568
buffer=usm_type,
561569
order="C",
562570
buffer_ctor_kwargs={"queue": sycl_queue},
563571
)
564-
_step = (start + step) - start
565-
_step = dt.type(_step)
566-
_start = dt.type(start)
572+
sc_ty = dt.type
573+
_first = sc_ty(start)
574+
if sh > 1:
575+
_second = sc_ty(start + step)
576+
if dt in [dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64]:
577+
int64_ty = dpt.int64.type
578+
_step = int64_ty(_second) - int64_ty(_first)
579+
else:
580+
_step = _second - _first
581+
_step = sc_ty(_step)
582+
else:
583+
_step = sc_ty(1)
584+
_start = _first
567585
hev, _ = ti._linspace_step(_start, _step, res, sycl_queue)
568586
hev.wait()
587+
if is_bool:
588+
res_out = dpt.usm_ndarray(
589+
(sh,),
590+
dtype=dpt.bool,
591+
buffer=usm_type,
592+
order="C",
593+
buffer_ctor_kwargs={"queue": sycl_queue},
594+
)
595+
res_out[:] = res
596+
res = res_out
569597
return res
570598

571599

0 commit comments

Comments
 (0)