Skip to content

Commit 097ecf5

Browse files
Modified sycl_timer example to use dpctl.tensor function
This removes use of dpnp.matmul from the example, making this example self-contained.
1 parent 645044a commit 097ecf5

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

examples/python/sycl_timer.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,27 @@
1515
# limitations under the License.
1616

1717

18-
import dpnp
1918
import numpy as np
2019

2120
import dpctl
2221
import dpctl.tensor as dpt
2322
from dpctl import SyclTimer
2423

25-
n = 4000
24+
25+
def matmul(m1, m2):
    """Naive matrix multiplication via broadcasting and a reduction.

    Both operands are expanded to three dimensions, multiplied
    element-wise, and the shared (inner) axis is summed out.

    Parameters
    ----------
    m1 : array
        Two-dimensional left operand of shape ``(n, k)``.
        Presumably a ``dpctl.tensor.usm_ndarray`` -- confirm against callers.
    m2 : array
        Two-dimensional right operand of shape ``(k, m)``.

    Returns
    -------
    array
        The ``(n, m)`` product, as returned by ``dpt.sum``.

    Raises
    ------
    ValueError
        If either operand is not two-dimensional, or if the inner
        dimensions do not agree.

    Notes
    -----
    The full ``(n, m, k)`` array of pairwise products is formed before the
    reduction, so memory use grows with ``n * m * k``. This is deliberate:
    the function is meant as a simple workload, not an efficient kernel.
    """
    # Explicit checks rather than ``assert``: assertions are removed when
    # Python runs with -O, which would let mismatched inputs fall through
    # to the broadcast multiply below.
    if m1.ndim != 2 or m2.ndim != 2:
        raise ValueError(
            "matmul expects two-dimensional inputs, got "
            f"m1.ndim={m1.ndim} and m2.ndim={m2.ndim}"
        )
    if m1.shape[1] != m2.shape[0]:
        raise ValueError(
            "inner dimensions must agree, got "
            f"m1.shape={m1.shape} and m2.shape={m2.shape}"
        )
    # (n, k) -> (n, 1, k)
    lhs = m1[:, dpt.newaxis, :]
    # (k, m) -> (m, k) -> (1, m, k)
    rhs = dpt.permute_dims(m2, (1, 0))[dpt.newaxis, :, :]
    # m_prods[i, j, k] = m1[i, k] * m2[k, j]
    m_prods = lhs * rhs
    # sum over k
    return dpt.sum(m_prods, axis=-1)
36+
37+
38+
n = 500
2639

2740
try:
2841
q = dpctl.SyclQueue(property="enable_profiling")
@@ -33,32 +46,36 @@
3346
)
3447
exit(0)
3548

36-
a = dpt.reshape(dpt.arange(n * n, dtype=np.float32, sycl_queue=q), (n, n))
37-
b = dpt.reshape(
38-
dpt.asarray(np.random.random(n * n), dtype=np.float32, sycl_queue=q), (n, n)
39-
)
49+
a_flat = dpt.arange(n * n, dtype=dpt.float32, sycl_queue=q)
50+
a = dpt.reshape(a_flat, (n, n))
4051

41-
timer = SyclTimer(time_scale=1)
52+
b_rand = np.random.random(n * n).astype(np.float32)
53+
b_flat = dpt.asarray(b_rand, dtype=dpt.float32, sycl_queue=q)
54+
b = dpt.reshape(b_flat, (n, n))
4255

4356
wall_times = []
4457
device_times = []
58+
4559
print(
46-
f"Performing matrix multiplication of two {n} by {n} matrices "
60+
f"Computing naive matrix multiplication of two {n} by {n} matrices "
4761
f"on {q.sycl_device.name}, repeating 5 times."
4862
)
63+
print()
4964
for _ in range(5):
65+
timer = SyclTimer(time_scale=1)
5066
with timer(q):
51-
a_matmul_b = dpnp.matmul(a, b)
67+
a_matmul_b = matmul(a, b)
5268
host_time, device_time = timer.dt
5369
wall_times.append(host_time)
5470
device_times.append(device_time)
5571

56-
c = dpnp.asnumpy(a_matmul_b)
57-
cc = np.dot(dpnp.asnumpy(a), dpnp.asnumpy(b))
72+
c = dpt.asnumpy(a_matmul_b)
73+
cc = np.dot(dpt.asnumpy(a), dpt.asnumpy(b))
5874

5975
print("Wall time: ", wall_times, "\nDevice time: ", device_times)
76+
print()
6077
print(
6178
"Accuracy test: passed."
6279
if np.allclose(c, cc)
63-
else (f"Accuracy test: failed. Discrepancy {np.max(np.abs(c-cc))}")
80+
else (f"Accuracy test: FAILED. \n Discrepancy = {np.max(np.abs(c-cc))}")
6481
)

0 commit comments

Comments
 (0)