diff --git a/unified-runtime/source/adapters/native_cpu/enqueue.cpp b/unified-runtime/source/adapters/native_cpu/enqueue.cpp index 4c780031f8cf7..148f1dca58a7c 100644 --- a/unified-runtime/source/adapters/native_cpu/enqueue.cpp +++ b/unified-runtime/source/adapters/native_cpu/enqueue.cpp @@ -50,8 +50,42 @@ struct NDRDescT { << GlobalOffset[2] << "\n"; } }; + +namespace { +class WaitInfo { + std::vector *const events; + static_assert(std::is_pointer_v); + +public: + WaitInfo(uint32_t numEvents, const ur_event_handle_t *WaitList) + : events(numEvents ? new std::vector( + WaitList, WaitList + numEvents) + : nullptr) {} + void wait() const { + if (events) + urEventWait(events->size(), events->data()); + } + std::unique_ptr> getUniquePtr() { + return std::unique_ptr>(events); + } +}; + +inline static WaitInfo getWaitInfo(uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList) { + return native_cpu::WaitInfo(numEventsInWaitList, phEventWaitList); +} + +} // namespace } // namespace native_cpu +static inline native_cpu::state getState(const native_cpu::NDRDescT &ndr) { + native_cpu::state resized_state( + ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], ndr.LocalSize[0], + ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0], + ndr.GlobalOffset[1], ndr.GlobalOffset[2]); + return resized_state; +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize, @@ -67,7 +101,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( } } - urEventWait(numEventsInWaitList, phEventWaitList); UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE); UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION); @@ -119,14 +152,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( auto &tp = hQueue->getDevice()->tp; const size_t numParallelThreads = tp.num_threads(); std::vector> futures; - std::vector> groups; auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0]; auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1]; auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2]; - native_cpu::state state(ndr.GlobalSize[0], ndr.GlobalSize[1], - ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1], - ndr.LocalSize[2], ndr.GlobalOffset[0], - ndr.GlobalOffset[1], ndr.GlobalOffset[2]); auto event = new ur_event_handle_t_(hQueue, UR_COMMAND_KERNEL_LAUNCH); event->tick_start(); @@ -134,6 +162,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( auto kernel = std::make_unique(*hKernel); kernel->updateMemPool(numParallelThreads); + auto InEvents = native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList); + const size_t numWG = numWG0 * numWG1 * numWG2; const size_t numWGPerThread = numWG / numParallelThreads; const size_t remainderWG = numWG - numWGPerThread * numParallelThreads; @@ -147,13 +177,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( rangeEnd[0] = rangeEnd[3] % numWG0; rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1; rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1); - futures.emplace_back( - tp.schedule_task([state, &kernel = *kernel, rangeStart, - rangeEnd = rangeEnd[3], numWG0, numWG1, -#ifndef NATIVECPU_USE_OCK - localSize = ndr.LocalSize, -#endif - numParallelThreads](size_t threadId) mutable { + futures.emplace_back(tp.schedule_task( + [ndr, InEvents, &kernel = *kernel, rangeStart, rangeEnd = rangeEnd[3], + numWG0, numWG1, numParallelThreads](size_t threadId) { + auto state = getState(ndr); + InEvents.wait(); for (size_t g0 = rangeStart[0], g1 = rangeStart[1], g2 = rangeStart[2], g3 = rangeStart[3]; g3 < rangeEnd; ++g3) { @@ -162,9 +190,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( kernel._subhandler( kernel.getArgs(numParallelThreads, threadId).data(), &state); #else - for (size_t local2 = 0; local2 < localSize[2]; ++local2) { - for (size_t local1 = 0; local1 < localSize[1]; ++local1) { - for (size_t local0 = 0; local0 < localSize[0]; ++local0) { + for (size_t local2 = 0; local2 < ndr.LocalSize[2]; ++local2) { + for (size_t local1 = 0; local1 < ndr.LocalSize[1]; ++local1) { + for (size_t local0 = 0; local0 < ndr.LocalSize[0]; ++local0) { state.update(g0, g1, g2, local0, local1, local2); kernel._subhandler( kernel.getArgs(numParallelThreads, threadId).data(), @@ -189,7 +217,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( if (phEvent) { *phEvent = event; } - event->set_callback([kernel = std::move(kernel), hKernel, event]() { + event->set_callback([kernel = std::move(kernel), hKernel, event, + InEvents = InEvents.getUniquePtr()]() { event->tick_end(); // TODO: avoid calling clear() here. hKernel->_localArgInfo.clear(); @@ -207,20 +236,32 @@ static inline ur_result_t withTimingEvent(ur_command_t command_type, ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, - ur_event_handle_t *phEvent, T &&f) { - urEventWait(numEventsInWaitList, phEventWaitList); - ur_event_handle_t event = nullptr; + ur_event_handle_t *phEvent, T &&f, bool blocking = true) { if (phEvent) { - event = new ur_event_handle_t_(hQueue, command_type); + ur_event_handle_t event = new ur_event_handle_t_(hQueue, command_type); + *phEvent = event; event->tick_start(); + if (blocking || hQueue->isInOrder()) { + urEventWait(numEventsInWaitList, phEventWaitList); + ur_result_t result = f(); + event->tick_end(); + return result; + } + auto &tp = hQueue->getDevice()->tp; + std::vector> futures; + auto InEvents = + native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList); + futures.emplace_back(tp.schedule_task([f, InEvents](size_t) { + InEvents.wait(); + f(); + })); + event->set_futures(futures); + event->set_callback( + [event, InEvents = InEvents.getUniquePtr()]() { event->tick_end(); }); + return UR_RESULT_SUCCESS; } - + urEventWait(numEventsInWaitList, phEventWaitList); ur_result_t result = f(); - - if (phEvent) { - event->tick_end(); - *phEvent = event; - } return result; } @@ -231,7 +272,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait( // TODO: the wait here should be async return withTimingEvent(UR_COMMAND_EVENTS_WAIT, hQueue, numEventsInWaitList, phEventWaitList, phEvent, - [&]() { return UR_RESULT_SUCCESS; }); + []() { return UR_RESULT_SUCCESS; }); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( @@ -239,7 +280,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { return withTimingEvent(UR_COMMAND_EVENTS_WAIT_WITH_BARRIER, hQueue, numEventsInWaitList, phEventWaitList, phEvent, - [&]() { return UR_RESULT_SUCCESS; }); + []() { return UR_RESULT_SUCCESS; }); } UR_APIEXPORT ur_result_t urEnqueueEventsWaitWithBarrierExt( @@ -250,9 +291,43 @@ UR_APIEXPORT ur_result_t urEnqueueEventsWaitWithBarrierExt( phEventWaitList, phEvent); } +template +static inline void MemBufferReadWriteRect_impl( + ur_mem_handle_t Buff, ur_rect_offset_t BufferOffset, + ur_rect_offset_t HostOffset, ur_rect_region_t region, size_t BufferRowPitch, + size_t BufferSlicePitch, size_t HostRowPitch, size_t HostSlicePitch, + typename std::conditional::type DstMem) { + // TODO: check other constraints, performance optimizations + // More sharing with level_zero where possible + + if (BufferRowPitch == 0) + BufferRowPitch = region.width; + if (BufferSlicePitch == 0) + BufferSlicePitch = BufferRowPitch * region.height; + if (HostRowPitch == 0) + HostRowPitch = region.width; + if (HostSlicePitch == 0) + HostSlicePitch = HostRowPitch * region.height; + for (size_t w = 0; w < region.width; w++) + for (size_t h = 0; h < region.height; h++) + for (size_t d = 0; d < region.depth; d++) { + size_t buff_orign = (d + BufferOffset.z) * BufferSlicePitch + + (h + BufferOffset.y) * BufferRowPitch + w + + BufferOffset.x; + size_t host_origin = (d + HostOffset.z) * HostSlicePitch + + (h + HostOffset.y) * HostRowPitch + w + + HostOffset.x; + int8_t &buff_mem = ur_cast(Buff->_mem)[buff_orign]; + if constexpr (IsRead) + ur_cast(DstMem)[host_origin] = buff_mem; + else + buff_mem = ur_cast(DstMem)[host_origin]; + } +} + template static inline ur_result_t enqueueMemBufferReadWriteRect_impl( - ur_queue_handle_t hQueue, ur_mem_handle_t Buff, bool, + ur_queue_handle_t hQueue, ur_mem_handle_t Buff, bool blocking, ur_rect_offset_t BufferOffset, ur_rect_offset_t HostOffset, ur_rect_region_t region, size_t BufferRowPitch, size_t BufferSlicePitch, size_t HostRowPitch, size_t HostSlicePitch, @@ -265,71 +340,63 @@ static inline ur_result_t enqueueMemBufferReadWriteRect_impl( else command_t = UR_COMMAND_MEM_BUFFER_WRITE_RECT; return withTimingEvent( - command_t, hQueue, NumEventsInWaitList, phEventWaitList, phEvent, [&]() { - // TODO: blocking, check other constraints, performance optimizations - // More sharing with level_zero where possible - - if (BufferRowPitch == 0) - BufferRowPitch = region.width; - if (BufferSlicePitch == 0) - BufferSlicePitch = BufferRowPitch * region.height; - if (HostRowPitch == 0) - HostRowPitch = region.width; - if (HostSlicePitch == 0) - HostSlicePitch = HostRowPitch * region.height; - for (size_t w = 0; w < region.width; w++) - for (size_t h = 0; h < region.height; h++) - for (size_t d = 0; d < region.depth; d++) { - size_t buff_orign = (d + BufferOffset.z) * BufferSlicePitch + - (h + BufferOffset.y) * BufferRowPitch + w + - BufferOffset.x; - size_t host_origin = (d + HostOffset.z) * HostSlicePitch + - (h + HostOffset.y) * HostRowPitch + w + - HostOffset.x; - int8_t &buff_mem = ur_cast(Buff->_mem)[buff_orign]; - if constexpr (IsRead) - ur_cast(DstMem)[host_origin] = buff_mem; - else - buff_mem = ur_cast(DstMem)[host_origin]; - } - + command_t, hQueue, NumEventsInWaitList, phEventWaitList, phEvent, + [BufferRowPitch, region, BufferSlicePitch, HostRowPitch, HostSlicePitch, + BufferOffset, HostOffset, Buff, DstMem]() { + MemBufferReadWriteRect_impl( + Buff, BufferOffset, HostOffset, region, BufferRowPitch, + BufferSlicePitch, HostRowPitch, HostSlicePitch, DstMem); return UR_RESULT_SUCCESS; - }); + }, + blocking); } -static inline ur_result_t doCopy_impl(ur_queue_handle_t hQueue, void *DstPtr, - const void *SrcPtr, size_t Size, - uint32_t numEventsInWaitList, - const ur_event_handle_t *phEventWaitList, - ur_event_handle_t *phEvent, - ur_command_t command_type) { - return withTimingEvent(command_type, hQueue, numEventsInWaitList, - phEventWaitList, phEvent, [&]() { - if (SrcPtr != DstPtr && Size) - memmove(DstPtr, SrcPtr, Size); - return UR_RESULT_SUCCESS; - }); +template +static inline ur_result_t doCopy_impl( + ur_queue_handle_t hQueue, void *DstPtr, const void *SrcPtr, size_t Size, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent, ur_command_t command_type, bool blocking) { + if (SrcPtr == DstPtr || Size == 0) { + bool hasInEvents = numEventsInWaitList && phEventWaitList; + return withTimingEvent( + command_type, hQueue, numEventsInWaitList, phEventWaitList, phEvent, + []() { return UR_RESULT_SUCCESS; }, blocking || !hasInEvents); + } + + return withTimingEvent( + command_type, hQueue, numEventsInWaitList, phEventWaitList, phEvent, + [DstPtr, SrcPtr, Size]() { + if constexpr (AllowPartialOverlap) { + memmove(DstPtr, SrcPtr, Size); + } else { + memcpy(DstPtr, SrcPtr, Size); + } + return UR_RESULT_SUCCESS; + }, + blocking); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( - ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool /*blockingRead*/, + ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead, size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { void *FromPtr = /*Src*/ hBuffer->_mem + offset; auto res = doCopy_impl(hQueue, pDst, FromPtr, size, numEventsInWaitList, - phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_READ); + phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_READ, + blockingRead); return res; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( - ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool /*blockingWrite*/, + ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite, size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { void *ToPtr = hBuffer->_mem + offset; auto res = doCopy_impl(hQueue, ToPtr, pSrc, size, numEventsInWaitList, - phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_WRITE); + phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_WRITE, + blockingWrite); return res; } @@ -368,7 +435,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( const void *SrcPtr = hBufferSrc->_mem + srcOffset; void *DstPtr = hBufferDst->_mem + dstOffset; return doCopy_impl(hQueue, DstPtr, SrcPtr, size, numEventsInWaitList, - phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_COPY); + phEventWaitList, phEvent, UR_COMMAND_MEM_BUFFER_COPY, + true /*TODO: check false for non-blocking*/); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( @@ -379,7 +447,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { return enqueueMemBufferReadWriteRect_impl( - hQueue, hBufferSrc, false /*todo: check blocking*/, srcOrigin, + hQueue, hBufferSrc, true /*todo: check false for non-blocking*/, + srcOrigin, /*HostOffset*/ dstOrigin, region, srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, hBufferDst->_mem, numEventsInWaitList, phEventWaitList, phEvent); @@ -390,12 +459,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill( size_t patternSize, size_t offset, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - + UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE); return withTimingEvent( UR_COMMAND_MEM_BUFFER_FILL, hQueue, numEventsInWaitList, phEventWaitList, - phEvent, [&]() { - UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE); - + phEvent, [hBuffer, offset, size, patternSize, pPattern]() { // TODO: error checking // TODO: handle async void *startingPtr = hBuffer->_mem + offset; @@ -449,7 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap( ur_event_handle_t *phEvent, void **ppRetMap) { return withTimingEvent(UR_COMMAND_MEM_BUFFER_MAP, hQueue, numEventsInWaitList, - phEventWaitList, phEvent, [&]() { + phEventWaitList, phEvent, + [ppRetMap, hBuffer, offset]() { *ppRetMap = hBuffer->_mem + offset; return UR_RESULT_SUCCESS; }); @@ -461,7 +529,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap( ur_event_handle_t *phEvent) { return withTimingEvent(UR_COMMAND_MEM_UNMAP, hQueue, numEventsInWaitList, phEventWaitList, phEvent, - [&]() { return UR_RESULT_SUCCESS; }); + []() { return UR_RESULT_SUCCESS; }); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( @@ -470,7 +538,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { return withTimingEvent( UR_COMMAND_USM_FILL, hQueue, numEventsInWaitList, phEventWaitList, - phEvent, [&]() { + phEvent, [ptr, pPattern, patternSize, size]() { UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER); UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER); UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE) @@ -520,20 +588,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( - ur_queue_handle_t hQueue, bool /*blocking*/, void *pDst, const void *pSrc, + ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - return withTimingEvent( - UR_COMMAND_USM_MEMCPY, hQueue, numEventsInWaitList, phEventWaitList, - phEvent, [&]() { - UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_QUEUE); - UR_ASSERT(pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER); - UR_ASSERT(pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER); - - memcpy(pDst, pSrc, size); + UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_QUEUE); + UR_ASSERT(pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER); + UR_ASSERT(pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER); - return UR_RESULT_SUCCESS; - }); + return doCopy_impl( + hQueue, pDst, pSrc, size, numEventsInWaitList, phEventWaitList, phEvent, + UR_COMMAND_USM_MEMCPY, blocking); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(