Skip to content

Commit 04efec6

Browse files
committed
Add OpenCL-based test
1 parent ed4e9ca commit 04efec6

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

dpctl/tests/test_work_group_memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
"""Defines unit test cases for the SyclProgram and SyclKernel classes"""
17+
"""Defines unit test cases for the work_group_memory in a SYCL kernel"""
1818

1919
import os
2020

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the work_group_memory in an OpenCL kernel"""
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl
23+
import dpctl.tensor
24+
25+
ocl_kernel_src = """
26+
__kernel void local_mem_kernel(__global float *input, __global float *output,
27+
__local float *local_data) {
28+
int gid = get_global_id(0);
29+
int lid = get_local_id(0);
30+
31+
// Load input data into local memory
32+
local_data[lid] = input[gid];
33+
34+
// Store the data in the output array
35+
output[gid] = local_data[lid];
36+
}
37+
"""
38+
39+
40+
def test_submit_work_group_memory_opencl():
41+
if not dpctl.experimental.WorkGroupMemory.is_available():
42+
pytest.skip("Work group memory extension not supported")
43+
44+
try:
45+
q = dpctl.SyclQueue("opencl")
46+
except dpctl.SyclQueueCreationError:
47+
pytest.skip("OpenCL queue could not be created")
48+
49+
prog = dpctl.program.create_program_from_source(q, ocl_kernel_src)
50+
kernel = prog.get_sycl_kernel("local_mem_kernel")
51+
local_size = 16
52+
global_size = local_size * 8
53+
54+
x_dev = dpctl.memory.MemoryUSMDevice(global_size * 4, queue=q)
55+
y_dev = dpctl.memory.MemoryUSMDevice(global_size * 4, queue=q)
56+
57+
x = np.ones(global_size, dtype="float32")
58+
y = np.zeros(global_size, dtype="float32")
59+
q.memcpy(x_dev, x, x_dev.nbytes)
60+
q.memcpy(y_dev, y, y_dev.nbytes)
61+
62+
try:
63+
q.submit(
64+
kernel,
65+
[
66+
x_dev,
67+
y_dev,
68+
dpctl.experimental.WorkGroupMemory(local_size * x.itemsize),
69+
],
70+
[global_size],
71+
[local_size],
72+
)
73+
q.wait()
74+
except dpctl._sycl_queue.SyclKernelSubmitError:
75+
pytest.fail("Foo")
76+
pytest.skip(f"Kernel submission to {q.sycl_device} failed")
77+
78+
q.memcpy(y, y_dev, y_dev.nbytes)
79+
80+
assert np.all(x == y)

0 commit comments

Comments
 (0)