Skip to content

Commit ca22c6d

Browse files
committed
Add GGML_MUSA in CMake
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 59db73d commit ca22c6d

File tree

2 files changed, with 61 additions (+61) and 8 deletions (-8).

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
107107
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
108108

109109
option(GGML_CUDA "ggml: use CUDA" OFF)
110+
option(GGML_MUSA "ggml: use MUSA" OFF)
110111
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
111112
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
112113
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)

ggml/src/CMakeLists.txt

Lines changed: 60 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,19 @@ if (GGML_METAL)
139139
)
140140
endif()
141141

142+
if (GGML_MUSA)
143+
set(CMAKE_C_COMPILER clang)
144+
set(CMAKE_C_EXTENSIONS OFF)
145+
set(CMAKE_CXX_COMPILER clang++)
146+
set(CMAKE_CXX_EXTENSIONS OFF)
147+
148+
set(GGML_CUDA ON)
149+
150+
list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
151+
152+
add_compile_definitions(GGML_USE_MUSA)
153+
endif()
154+
142155
if (GGML_OPENMP)
143156
find_package(OpenMP)
144157
if (OpenMP_FOUND)
@@ -147,6 +160,11 @@ if (GGML_OPENMP)
147160
add_compile_definitions(GGML_USE_OPENMP)
148161

149162
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
163+
164+
if (GGML_MUSA)
165+
set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp")
166+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so")
167+
endif()
150168
else()
151169
message(WARNING "OpenMP not found")
152170
endif()
@@ -249,7 +267,13 @@ endif()
249267
if (GGML_CUDA)
250268
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
251269

252-
find_package(CUDAToolkit)
270+
if (GGML_MUSA)
271+
list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
272+
find_package(MUSAToolkit)
273+
set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
274+
else()
275+
find_package(CUDAToolkit)
276+
endif()
253277

254278
if (CUDAToolkit_FOUND)
255279
message(STATUS "CUDA found")
@@ -268,7 +292,11 @@ if (GGML_CUDA)
268292
endif()
269293
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
270294

271-
enable_language(CUDA)
295+
if (GGML_MUSA)
296+
set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
297+
else()
298+
enable_language(CUDA)
299+
endif()
272300

273301
file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
274302
list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
@@ -332,21 +360,40 @@ if (GGML_CUDA)
332360
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
333361
endif()
334362

363+
if (GGML_MUSA)
364+
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
365+
foreach(SOURCE ${GGML_SOURCES_CUDA})
366+
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
367+
endforeach()
368+
endif()
369+
335370
if (GGML_STATIC)
336371
if (WIN32)
337372
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338373
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
339374
else ()
340-
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
375+
if (GGML_MUSA)
376+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
377+
else()
378+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
379+
endif()
341380
endif()
342381
else()
343-
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
382+
if (GGML_MUSA)
383+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
384+
else()
385+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
386+
endif()
344387
endif()
345388

346389
if (GGML_CUDA_NO_VMM)
347390
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348391
else()
349-
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
392+
if (GGML_MUSA)
393+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
394+
else()
395+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
396+
endif()
350397
endif()
351398
else()
352399
message(WARNING "CUDA not found")
@@ -757,8 +804,10 @@ function(get_flags CCID CCVER)
757804
set(C_FLAGS -Wdouble-promotion)
758805
set(CXX_FLAGS -Wno-array-bounds)
759806

760-
if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761-
list(APPEND CXX_FLAGS -Wno-format-truncation)
807+
if (NOT GGML_MUSA)
808+
if (CCVER VERSION_GREATER_EQUAL 7.1.0)
809+
list(APPEND CXX_FLAGS -Wno-format-truncation)
810+
endif()
762811
endif()
763812
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764813
list(APPEND CXX_FLAGS -Wextra-semi)
@@ -1059,7 +1108,9 @@ if (GGML_CUDA)
10591108
list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
10601109

10611110
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
1062-
list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1111+
# list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1112+
# XXX: Removed flags: -Xcompiler
1113+
list(APPEND CUDA_FLAGS ${CUDA_CXX_FLAGS_JOINED})
10631114
endif()
10641115

10651116
add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
@@ -1163,6 +1214,7 @@ endif()
11631214
target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
11641215
target_include_directories(ggml PUBLIC ../include)
11651216
target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
1217+
target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})
11661218
target_compile_features (ggml PRIVATE c_std_11) # don't bump
11671219

11681220
target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})

0 commit comments

Comments
 (0)