Skip to content

Commit ddd8349

Browse files
committed
Python test work group memory
1 parent f964aeb commit ddd8349

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed
Binary file not shown.

dpctl/tests/test_work_group_memory.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the SyclProgram and SyclKernel classes
18+
"""
19+
20+
import os
21+
22+
import pytest
23+
24+
import dpctl
25+
import dpctl.tensor
26+
27+
28+
def get_spirv_abspath(fn):
29+
curr_dir = os.path.dirname(os.path.abspath(__file__))
30+
spirv_file = os.path.join(curr_dir, "input_files", fn)
31+
return spirv_file
32+
33+
34+
def test_submit_work_group_memory():
35+
if not dpctl.experimental.WorkGroupMemory.is_available():
36+
pytest.skip("Work group memory extension not supported")
37+
38+
try:
39+
q = dpctl.SyclQueue("level_zero")
40+
except dpctl.SyclQueueCreationError:
41+
pytest.skip("LevelZero queue could not be created")
42+
spirv_file = get_spirv_abspath("work-group-memory-kernel.spv")
43+
with open(spirv_file, "br") as spv:
44+
spv_bytes = spv.read()
45+
prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
46+
kernel = prog.get_sycl_kernel("__sycl_kernel_local_mem_kernel")
47+
local_size = 16
48+
global_size = local_size * 8
49+
50+
x = dpctl.tensor.ones(global_size, dtype="int32")
51+
y = dpctl.tensor.zeros(global_size, dtype="int32")
52+
x.sycl_queue.wait()
53+
y.sycl_queue.wait()
54+
55+
try:
56+
q.submit(
57+
kernel,
58+
[
59+
x.usm_data,
60+
y.usm_data,
61+
dpctl.experimental.WorkGroupMemory(local_size * x.itemsize),
62+
],
63+
[global_size],
64+
[local_size],
65+
)
66+
q.wait()
67+
except dpctl._sycl_queue.SyclKernelSubmitError:
68+
pytest.skip(f"Kernel submission to {q.sycl_device} failed")
69+
70+
assert dpctl.tensor.all(x == y)

0 commit comments

Comments
 (0)