From bfae55d85e5eb473420704de382d74d1a0d95e96 Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Thu, 2 Jan 2025 18:31:07 +0000 Subject: [PATCH] [L0 v2] implement urKernelSuggestMaxCooperativeGroupCountExp --- source/adapters/level_zero/v2/api.cpp | 8 -------- source/adapters/level_zero/v2/kernel.cpp | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/source/adapters/level_zero/v2/api.cpp b/source/adapters/level_zero/v2/api.cpp index 33ff023bb8..04ed82c03c 100644 --- a/source/adapters/level_zero/v2/api.cpp +++ b/source/adapters/level_zero/v2/api.cpp @@ -474,14 +474,6 @@ ur_result_t urCommandBufferCommandGetInfoExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } -ur_result_t urKernelSuggestMaxCooperativeGroupCountExp( - ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim, - const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize, - uint32_t *pGroupCountRet) { - logger::error("{} function not implemented!", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - ur_result_t urUSMImportExp(ur_context_handle_t hContext, void *pMem, size_t size) { logger::error("{} function not implemented!", __FUNCTION__); diff --git a/source/adapters/level_zero/v2/kernel.cpp b/source/adapters/level_zero/v2/kernel.cpp index 9313c56395..dcf58d5b62 100644 --- a/source/adapters/level_zero/v2/kernel.cpp +++ b/source/adapters/level_zero/v2/kernel.cpp @@ -649,4 +649,24 @@ ur_result_t urKernelGetSuggestedLocalWorkSize( std::copy(localWorkSize, localWorkSize + workDim, pSuggestedLocalWorkSize); return UR_RESULT_SUCCESS; } + +ur_result_t urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim, + const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize, + uint32_t *pGroupCountRet) { + (void)dynamicSharedMemorySize; + + uint32_t wg[3]; + wg[0] = ur_cast(pLocalWorkSize[0]); + wg[1] = workDim >= 2 ? ur_cast(pLocalWorkSize[1]) : 1; + wg[2] = workDim == 3 ? ur_cast(pLocalWorkSize[2]) : 1; + ZE2UR_CALL(zeKernelSetGroupSize, + (hKernel->getZeHandle(hDevice), wg[0], wg[1], wg[2])); + + uint32_t totalGroupCount = 0; + ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount, + (hKernel->getZeHandle(hDevice), &totalGroupCount)); + *pGroupCountRet = totalGroupCount; + return UR_RESULT_SUCCESS; +} } // namespace ur::level_zero