@@ -139,6 +139,19 @@ if (GGML_METAL)
139
139
)
140
140
endif ()
141
141
142
+ if (GGML_MUSA)
143
+ set (CMAKE_C_COMPILER clang)
144
+ set (CMAKE_C_EXTENSIONS OFF )
145
+ set (CMAKE_CXX_COMPILER clang++)
146
+ set (CMAKE_CXX_EXTENSIONS OFF )
147
+
148
+ set (GGML_CUDA ON )
149
+
150
+ list (APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
151
+
152
+ add_compile_definitions (GGML_USE_MUSA)
153
+ endif ()
154
+
142
155
if (GGML_OPENMP)
143
156
find_package (OpenMP)
144
157
if (OpenMP_FOUND)
@@ -147,6 +160,11 @@ if (GGML_OPENMP)
147
160
add_compile_definitions (GGML_USE_OPENMP)
148
161
149
162
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
163
+
164
+ if (GGML_MUSA)
165
+ set (GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp" )
166
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so" )
167
+ endif ()
150
168
else ()
151
169
message (WARNING "OpenMP not found" )
152
170
endif ()
@@ -249,7 +267,13 @@ endif()
249
267
if (GGML_CUDA)
250
268
cmake_minimum_required (VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
251
269
252
- find_package (CUDAToolkit)
270
+ if (GGML_MUSA)
271
+ list (APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/" )
272
+ find_package (MUSAToolkit)
273
+ set (CUDAToolkit_FOUND ${MUSAToolkit_FOUND} )
274
+ else ()
275
+ find_package (CUDAToolkit)
276
+ endif ()
253
277
254
278
if (CUDAToolkit_FOUND)
255
279
message (STATUS "CUDA found" )
@@ -268,7 +292,11 @@ if (GGML_CUDA)
268
292
endif ()
269
293
message (STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES} " )
270
294
271
- enable_language (CUDA)
295
+ if (GGML_MUSA)
296
+ set (CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE} )
297
+ else ()
298
+ enable_language (CUDA)
299
+ endif ()
272
300
273
301
file (GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh" )
274
302
list (APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h" )
@@ -332,21 +360,40 @@ if (GGML_CUDA)
332
360
add_compile_definitions (GGML_CUDA_NO_PEER_COPY)
333
361
endif ()
334
362
363
+ if (GGML_MUSA)
364
+ set_source_files_properties (${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
365
+ foreach (SOURCE ${GGML_SOURCES_CUDA} )
366
+ set_property (SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22" )
367
+ endforeach ()
368
+ endif ()
369
+
335
370
if (GGML_STATIC)
336
371
if (WIN32 )
337
372
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338
373
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
339
374
else ()
340
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
375
+ if (GGML_MUSA)
376
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
377
+ else ()
378
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
379
+ endif ()
341
380
endif ()
342
381
else ()
343
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
382
+ if (GGML_MUSA)
383
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
384
+ else ()
385
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
386
+ endif ()
344
387
endif ()
345
388
346
389
if (GGML_CUDA_NO_VMM)
347
390
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348
391
else ()
349
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
392
+ if (GGML_MUSA)
393
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
394
+ else ()
395
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
396
+ endif ()
350
397
endif ()
351
398
else ()
352
399
message (WARNING "CUDA not found" )
@@ -757,8 +804,10 @@ function(get_flags CCID CCVER)
757
804
set (C_FLAGS -Wdouble-promotion)
758
805
set (CXX_FLAGS -Wno-array-bounds)
759
806
760
- if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761
- list (APPEND CXX_FLAGS -Wno-format-truncation)
807
+ if (NOT GGML_MUSA)
808
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
809
+ list (APPEND CXX_FLAGS -Wno-format-truncation)
810
+ endif ()
762
811
endif ()
763
812
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764
813
list (APPEND CXX_FLAGS -Wextra-semi)
@@ -1059,7 +1108,9 @@ if (GGML_CUDA)
1059
1108
list (JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
1060
1109
1061
1110
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "" )
1062
- list (APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED} )
1111
+ # list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1112
+ # XXX: Removed flags: -Xcompiler
1113
+ list (APPEND CUDA_FLAGS ${CUDA_CXX_FLAGS_JOINED} )
1063
1114
endif ()
1064
1115
1065
1116
add_compile_options ("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS} >" )
@@ -1163,6 +1214,7 @@ endif()
1163
1214
target_compile_definitions (ggml PUBLIC ${GGML_CDEF_PUBLIC} )
1164
1215
target_include_directories (ggml PUBLIC ../include )
1165
1216
target_include_directories (ggml PRIVATE . ${GGML_EXTRA_INCLUDES} )
1217
+ target_link_directories (ggml PRIVATE ${GGML_EXTRA_LIBDIRS} )
1166
1218
target_compile_features (ggml PRIVATE c_std_11) # don't bump
1167
1219
1168
1220
target_link_libraries (ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS} )
0 commit comments