diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 6677670e73..a775376d95 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -255,7 +255,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): raise ValueError o1_kind_num = _weak_type_num_kind(o1_dtype) o2_kind_num = _strong_dtype_num_kind(o2_dtype) - if o1_kind_num > o2_kind_num: + if o1_kind_num > o2_kind_num or o1_kind_num == 2: if isinstance(o1_dtype, WeakBooleanType): return dpt.bool, o2_dtype if isinstance(o1_dtype, WeakIntegralType): @@ -273,7 +273,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): ): o1_kind_num = _strong_dtype_num_kind(o1_dtype) o2_kind_num = _weak_type_num_kind(o2_dtype) - if o2_kind_num > o1_kind_num: + if o2_kind_num > o1_kind_num or o2_kind_num == 2: if isinstance(o2_dtype, WeakBooleanType): return o1_dtype, dpt.bool if isinstance(o2_dtype, WeakIntegralType): diff --git a/dpctl/tests/elementwise/test_multiply.py b/dpctl/tests/elementwise/test_multiply.py index cd506cd182..1305154021 100644 --- a/dpctl/tests/elementwise/test_multiply.py +++ b/dpctl/tests/elementwise/test_multiply.py @@ -152,3 +152,18 @@ def test_multiply_python_scalar(arr_dt): assert isinstance(R, dpt.usm_ndarray) R = dpt.multiply(sc, X) assert isinstance(R, dpt.usm_ndarray) + + +def test_multiply_python_scalar_gh1219(): + q = get_queue_or_skip() + + X = dpt.ones(4, dtype="f4", sycl_queue=q) + + r = dpt.multiply(X, 2j) + expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q)) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + + # symmetric case + r = dpt.multiply(2j, X) + expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)