Skip to content

Commit 1ed9a6d

Browse files
Merge pull request #1051 from IntelPython/change_linspace_int
Changed generating integer output for linspace() function.
2 parents 61430eb + 6e358cf commit 1ed9a6d

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,25 +1119,28 @@ def linspace(
11191119
num = operator.index(num)
11201120
if num < 0:
11211121
raise ValueError("Number of points must be non-negative")
1122-
((start, stop,), dt) = _coerce_and_infer_dt(
1122+
_, dt = _coerce_and_infer_dt(
11231123
start,
11241124
stop,
11251125
dt=dtype,
11261126
sycl_queue=sycl_queue,
11271127
err_msg="start and stop must be Python scalars.",
11281128
allow_bool=True,
11291129
)
1130-
if dtype is None and np.issubdtype(dt, np.integer):
1130+
int_dt = None
1131+
if np.issubdtype(dt, np.integer):
1132+
if dtype is not None:
1133+
int_dt = dt
11311134
dt = ti.default_device_fp_type(sycl_queue)
11321135
dt = dpt.dtype(dt)
11331136
start = float(start)
11341137
stop = float(stop)
1135-
res = dpt.empty(num, dtype=dt, sycl_queue=sycl_queue)
1138+
res = dpt.empty(num, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue)
11361139
hev, _ = ti._linspace_affine(
11371140
start, stop, dst=res, include_endpoint=endpoint, sycl_queue=sycl_queue
11381141
)
11391142
hev.wait()
1140-
return res
1143+
return res if int_dt is None else dpt.astype(res, int_dt)
11411144

11421145

11431146
def eye(

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,13 @@ def test_linspace_fp_max(dtype):
13161316
)
13171317

13181318

1319+
def test_linspace_int():
1320+
q = get_queue_or_skip()
1321+
X = dpt.linspace(0.1, 9.1, 11, endpoint=True, dtype=int, sycl_queue=q)
1322+
Xnp = np.linspace(0.1, 9.1, 11, endpoint=True, dtype=int)
1323+
assert np.array_equal(dpt.asnumpy(X), Xnp)
1324+
1325+
13191326
@pytest.mark.parametrize(
13201327
"dt",
13211328
_all_dtypes,

0 commit comments

Comments
 (0)