Skip to content

Commit 2356410

Browse files
committed
Add work group memory to libsyclinterface
1 parent 11a50cd commit 2356410

9 files changed

+194
-8
lines changed

dpctl/_backend.pxd

+16
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,19 @@ cdef extern from "syclinterface/dpctl_sycl_usm_interface.h":
469469
cdef DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(
470470
DPCTLSyclUSMRef MRef,
471471
DPCTLSyclContextRef CRef)
472+
473+
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
474+
cdef struct RawWorkGroupMemoryTy
475+
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory
476+
477+
cdef struct DPCTLOpaqueWorkGroupMemory
478+
ctypedef DPCTLOpaqueWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;
479+
480+
cdef DPCTLSyclWorkGroupMemoryRef DPCTLWorkGroupMemory_Create(size_t nbytes);
481+
482+
cdef void DPCTLWorkGroupMemory_Delete(
483+
DPCTLSyclWorkGroupMemoryRef Ref);
484+
485+
cdef bint DPCTLWorkGroupMemory_Available();
486+
487+

dpctl/_sycl_queue.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ cdef class SyclQueue(_SyclQueue):
835835
kargs[idx]= <void*>(<size_t>arg._pointer)
836836
kargty[idx] = _arg_data_type._VOID_PTR
837837
elif isinstance(arg, WorkGroupMemory):
838-
kargs[idx] = <void*>(<size_t>arg.nbytes)
838+
kargs[idx] = <void*>(<size_t>arg._ref)
839839
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
840840
else:
841841
ret = -1

dpctl/experimental/_work_group_memory.pxd

+10-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
# distutils: language = c++
1818
# cython: language_level=3
1919

20-
cdef public api class WorkGroupMemory [object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType]:
21-
cdef Py_ssize_t nbytes
20+
from .._backend cimport DPCTLSyclWorkGroupMemoryRef
2221

22+
cdef public api class _WorkGroupMemory [
23+
object Py_WorkGroupMemoryObject, type Py_WorkGroupMemoryType
24+
]:
25+
cdef DPCTLSyclWorkGroupMemoryRef _mem_ref
26+
27+
cdef public api class WorkGroupMemory(_WorkGroupMemory) [
28+
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
29+
]:
30+
pass
2331

dpctl/experimental/_work_group_memory.pyx

+24-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818
# cython: language_level=3
1919
# cython: linetrace=True
2020

21+
from .._backend cimport (
22+
DPCTLWorkGroupMemory_Create,
23+
DPCTLWorkGroupMemory_Delete,
24+
DPCTLWorkGroupMemory_Available
25+
)
26+
27+
cdef class _WorkGroupMemory:
28+
def __dealloc__(self):
29+
if(self._mem_ref):
30+
DPCTLWorkGroupMemory_Delete(self._mem_ref)
31+
2132
cdef class WorkGroupMemory:
2233
"""
2334
WorkGroupMemory(nbytes)
@@ -31,11 +42,20 @@ cdef class WorkGroupMemory:
3142
Expected to be positive.
3243
"""
3344
def __cinit__(self, Py_ssize_t nbytes):
34-
self.nbytes = nbytes
45+
if not DPCTLWorkGroupMemory_Available():
46+
raise RuntimeError("Workgroup memory extension not available")
47+
48+
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
49+
50+
@staticmethod
51+
def is_available():
52+
return DPCTLWorkGroupMemory_Available()
3553

36-
property nbytes:
37-
"""Local memory size in bytes."""
54+
property _ref:
55+
"""Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
56+
pointer as a ``size_t``.
57+
"""
3858
def __get__(self):
39-
return self.nbytes
59+
return <size_t>self._mem_ref
4060

4161

dpctl/sycl.pxd

+9
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
6767
"dpctl::syclinterface::wrap<sycl::event>" (const event *)
6868
cdef event * unwrap_event "dpctl::syclinterface::unwrap<sycl::event>" (
6969
dpctl_backend.DPCTLSyclEventRef)
70+
71+
# work group memory extension[
72+
cdef dpctl_backend.DPCTLSyclWorkGroupMemoryRef wrap_work_group_memory \
73+
"dpctl::syclinterface::wrap<RawWorkGroupMemory>" \
74+
(const RawWorkGroupMemory *)
75+
76+
cdef RawWorkGroupMemory * unwrap_work_group_memory \
77+
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
78+
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===---- dpctl_sycl_extension_interface.h - C API for SYCL ext -*-C++-*- ===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2025 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This header declares a C interface to SYCL language extensions defined by
23+
/// DPC++.
24+
///
25+
//===----------------------------------------------------------------------===//
26+
27+
#pragma once
28+
29+
#include "Support/DllExport.h"
30+
#include "Support/ExternC.h"
31+
#include "Support/MemOwnershipAttrs.h"
32+
#include "dpctl_data_types.h"
33+
#include "dpctl_error_handler_type.h"
34+
#include "dpctl_sycl_enum_types.h"
35+
#include "dpctl_sycl_types.h"
36+
37+
DPCTL_C_EXTERN_C_BEGIN
38+
39+
typedef struct RawWorkGroupMemoryTy {
40+
size_t nbytes;
41+
} RawWorkGroupMemory;
42+
43+
typedef struct DPCTLOpaqueSyclWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;
44+
45+
DPCTL_API
46+
__dpctl_give DPCTLSyclWorkGroupMemoryRef
47+
DPCTLWorkGroupMemory_Create(size_t nbytes);
48+
49+
DPCTL_API
50+
void DPCTLWorkGroupMemory_Delete(__dpctl_take DPCTLSyclWorkGroupMemoryRef Ref);
51+
52+
DPCTL_API
53+
bool DPCTLWorkGroupMemory_Available();
54+
55+
DPCTL_C_EXTERN_C_END

libsyclinterface/include/syclinterface/dpctl_sycl_type_casters.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclPlatformRef>,
8080
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclEventRef>,
8181
DPCTLEventVectorRef)
8282

83+
#include "dpctl_sycl_extension_interface.h"
84+
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(
85+
RawWorkGroupMemory,
86+
DPCTLSyclWorkGroupMemoryRef)
87+
8388
#endif
8489

8590
} // namespace dpctl::syclinterface
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===---- dpctl_sycl_extension_interface.cpp - Implements C API for SYCL ext =//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2025 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file implements the data types and functions declared in
23+
/// dpctl_sycl_extension_interface.h.
24+
///
25+
//===----------------------------------------------------------------------===//
26+
27+
#include "dpctl_sycl_extension_interface.h"
28+
29+
#include "dpctl_error_handlers.h"
30+
#include "dpctl_sycl_type_casters.hpp"
31+
32+
#include <sycl/sycl.hpp>
33+
34+
using namespace dpctl::syclinterface;
35+
36+
DPCTL_API
37+
__dpctl_give DPCTLSyclWorkGroupMemoryRef
38+
DPCTLWorkGroupMemory_Create(size_t nbytes)
39+
{
40+
DPCTLSyclWorkGroupMemoryRef wgm = nullptr;
41+
try {
42+
auto WorkGroupMem = new RawWorkGroupMemory{nbytes};
43+
wgm = wrap<RawWorkGroupMemory>(WorkGroupMem);
44+
} catch (std::exception const &e) {
45+
error_handler(e, __FILE__, __func__, __LINE__);
46+
}
47+
return wgm;
48+
}
49+
50+
DPCTL_API
51+
void DPCTLWorkGroupMemory_Delete(__dpctl_take DPCTLSyclWorkGroupMemoryRef Ref)
52+
{
53+
delete unwrap<RawWorkGroupMemory>(Ref);
54+
}
55+
56+
DPCTL_API
57+
bool DPCTLWorkGroupMemory_Available()
58+
{
59+
#ifdef SYCL_EXT_ONEAPI_WORK_GROUP_MEMORY
60+
return true;
61+
#else
62+
return false;
63+
#endif
64+
}
65+

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
#include <sycl/sycl.hpp> /* SYCL headers */
4343
#include <utility>
4444

45+
#ifdef SYCL_EXT_ONEAPI_WORK_GROUP_MEMORY
46+
#include "dpctl_sycl_extension_interface.h"
47+
#endif
48+
4549
using namespace sycl;
4650

4751
#define SET_LOCAL_ACCESSOR_ARG(CGH, NDIM, ARGTY, R, IDX) \
@@ -216,14 +220,18 @@ bool set_kernel_arg(handler &cgh,
216220
case DPCTL_LOCAL_ACCESSOR:
217221
arg_set = set_local_accessor_arg(cgh, idx, (MDLocalAccessor *)Arg);
218222
break;
223+
#ifdef SYCL_EXT_ONEAPI_WORK_GROUP_MEMORY
219224
case DPCTL_WORK_GROUP_MEMORY:
220225
{
221-
size_t num_bytes = reinterpret_cast<std::uintptr_t>(Arg);
226+
auto ref = static_cast<DPCTLSyclWorkGroupMemoryRef>(Arg);
227+
RawWorkGroupMemory *raw_mem = unwrap<RawWorkGroupMemory>(ref);
228+
size_t num_bytes = raw_mem->nbytes;
222229
sycl::ext::oneapi::experimental::work_group_memory<char[]> mem{
223230
num_bytes, cgh};
224231
cgh.set_arg(idx, mem);
225232
break;
226233
}
234+
#endif
227235
default:
228236
arg_set = false;
229237
break;

0 commit comments

Comments
 (0)