diff --git a/sycl/plugins/level_zero/pi_level_zero.cpp b/sycl/plugins/level_zero/pi_level_zero.cpp index 65d3a29fcba06..bb545d40dcd51 100644 --- a/sycl/plugins/level_zero/pi_level_zero.cpp +++ b/sycl/plugins/level_zero/pi_level_zero.cpp @@ -2954,9 +2954,22 @@ piEnqueueKernelLaunch(pi_queue Queue, pi_kernel Kernel, pi_uint32 WorkDim, return PI_INVALID_VALUE; } - assert(GlobalWorkSize[0] == (ZeThreadGroupDimensions.groupCountX * WG[0])); - assert(GlobalWorkSize[1] == (ZeThreadGroupDimensions.groupCountY * WG[1])); - assert(GlobalWorkSize[2] == (ZeThreadGroupDimensions.groupCountZ * WG[2])); + // Error handling for non-uniform group size case + if (GlobalWorkSize[0] != (ZeThreadGroupDimensions.groupCountX * WG[0])) { + zePrint("piEnqueueKernelLaunch: invalid work_dim. The range is not a " + "multiple of the group size in the 1st dimension\n"); + return PI_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize[1] != (ZeThreadGroupDimensions.groupCountY * WG[1])) { + zePrint("piEnqueueKernelLaunch: invalid work_dim. The range is not a " + "multiple of the group size in the 2nd dimension\n"); + return PI_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize[2] != (ZeThreadGroupDimensions.groupCountZ * WG[2])) { + zePrint("piEnqueueKernelLaunch: invalid work_dim. The range is not a " + "multiple of the group size in the 3rd dimension\n"); + return PI_INVALID_WORK_GROUP_SIZE; + } ZE_CALL(zeKernelSetGroupSize(Kernel->ZeKernel, WG[0], WG[1], WG[2]));