diff --git a/sycl/include/CL/sycl/buffer.hpp b/sycl/include/CL/sycl/buffer.hpp index 9a43504d5de72..d3a43837a666f 100644 --- a/sycl/include/CL/sycl/buffer.hpp +++ b/sycl/include/CL/sycl/buffer.hpp @@ -18,7 +18,7 @@ namespace cl { namespace sycl { class handler; class queue; -template class range; +template class range; template @@ -28,6 +28,8 @@ class buffer { using reference = value_type &; using const_reference = const value_type &; using allocator_type = AllocatorT; + template + using EnableIfOneDimension = typename std::enable_if<1 == dims>::type; buffer(const range &bufferRange, const property_list &propList = {}) @@ -83,7 +85,8 @@ class buffer { hostData, get_count() * sizeof(T), propList); } - template + template > buffer(InputIterator first, InputIterator last, AllocatorT allocator, const property_list &propList = {}) : Range(range<1>(std::distance(first, last))) { @@ -92,7 +95,7 @@ class buffer { } template > + typename = EnableIfOneDimension> buffer(InputIterator first, InputIterator last, const property_list &propList = {}) : Range(range<1>(std::distance(first, last))) { @@ -105,11 +108,16 @@ class buffer { // impl = std::make_shared(b, baseIndex, subRange); // } - template > + template > buffer(cl_mem MemObject, const context &SyclContext, event AvailableEvent = {}) { + + size_t BufSize = 0; + CHECK_OCL_CODE(clGetMemObjectInfo(MemObject, CL_MEM_SIZE, sizeof(size_t), + &BufSize, nullptr)); + Range[0] = BufSize / sizeof(T); impl = std::make_shared>( - MemObject, SyclContext, AvailableEvent); + MemObject, SyclContext, BufSize, AvailableEvent); } buffer(const buffer &rhs) = default; diff --git a/sycl/include/CL/sycl/detail/buffer_impl.hpp b/sycl/include/CL/sycl/detail/buffer_impl.hpp index 321bacb2f5aae..4da8134a94424 100644 --- a/sycl/include/CL/sycl/detail/buffer_impl.hpp +++ b/sycl/include/CL/sycl/detail/buffer_impl.hpp @@ -119,8 +119,9 @@ template class buffer_impl { } buffer_impl(cl_mem MemObject, const context &SyclContext, - event AvailableEvent = {}) - : OpenCLInterop(true), AvailableEvent(AvailableEvent) { + const size_t sizeInBytes, event AvailableEvent = {}) + : OpenCLInterop(true), SizeInBytes(sizeInBytes), + AvailableEvent(AvailableEvent) { if (SyclContext.is_host()) throw cl::sycl::invalid_parameter_error( "Creation of interoperability buffer using host context is not " diff --git a/sycl/test/basic_tests/buffer/buffer_interop.cpp b/sycl/test/basic_tests/buffer/buffer_interop.cpp index 3d836c16812fc..f0bdaa2090711 100644 --- a/sycl/test/basic_tests/buffer/buffer_interop.cpp +++ b/sycl/test/basic_tests/buffer/buffer_interop.cpp @@ -16,10 +16,14 @@ using namespace cl::sycl; int main() { + bool Failed = false; { const size_t Size = 32; int Init[Size] = {5}; cl_int Error = CL_SUCCESS; + cl::sycl::range<1> InteropRange; + InteropRange[0] = Size; + size_t InteropSize = Size * sizeof(int); queue MyQueue; @@ -29,6 +33,19 @@ int main() { CHECK_OCL_CODE(Error); buffer Buffer(OpenCLBuffer, MyQueue.get_context()); + if (Buffer.get_range() != InteropRange) { + assert(false); + Failed = true; + } + if (Buffer.get_size() != InteropSize) { + assert(false); + Failed = true; + } + if (Buffer.get_count() != Size) { + assert(false); + Failed = true; + } + MyQueue.submit([&](handler &CGH) { auto B = Buffer.get_access(CGH); CGH.parallel_for( @@ -58,8 +75,9 @@ int main() { std::cout << " array[" << i << "] is " << Result[i] << " expected " << 20 << std::endl; assert(false); + Failed = true; } } } - return 0; + return Failed; }