@@ -139,6 +139,17 @@ if (GGML_METAL)
139
139
)
140
140
endif ()
141
141
142
if (GGML_MUSA)
    # The MUSA backend is built through the existing GGML_CUDA code path, so
    # that option is switched on here and GGML_USE_MUSA is exported as a public
    # compile definition.
    set(GGML_CUDA ON)
    list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)

    # Build C and C++ with clang / clang++ and with compiler extensions disabled.
    # NOTE(review): CMake only documents setting CMAKE_<LANG>_COMPILER before the
    # language is enabled - confirm this runs early enough to take effect.
    set(CMAKE_C_COMPILER   clang)
    set(CMAKE_CXX_COMPILER clang++)
    set(CMAKE_C_EXTENSIONS   OFF)
    set(CMAKE_CXX_EXTENSIONS OFF)
endif()
152
+
142
153
if (GGML_OPENMP)
143
154
find_package (OpenMP)
144
155
if (OpenMP_FOUND)
@@ -147,6 +158,11 @@ if (GGML_OPENMP)
147
158
add_compile_definitions (GGML_USE_OPENMP)
148
159
149
160
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
161
+
162
if (GGML_MUSA)
    # For MUSA builds the LLVM OpenMP runtime (libomp) headers and library are
    # added explicitly, in addition to the OpenMP:: imported targets.
    #
    # The LLVM install prefix used to be hard-coded to /usr/lib/llvm-10. It is
    # now a cache entry with that same default, so existing setups are
    # unaffected while other prefixes can be selected with
    #   -DGGML_MUSA_LLVM_ROOT=<prefix>
    set(GGML_MUSA_LLVM_ROOT "/usr/lib/llvm-10" CACHE PATH
        "LLVM install prefix that provides include/openmp and lib/libomp.so for MUSA builds")

    # Surface a wrong prefix at configure time instead of at link time.
    if (NOT EXISTS "${GGML_MUSA_LLVM_ROOT}/lib/libomp.so")
        message(WARNING "libomp.so not found under ${GGML_MUSA_LLVM_ROOT}/lib; set GGML_MUSA_LLVM_ROOT to your LLVM install prefix")
    endif()

    list(APPEND GGML_EXTRA_INCLUDES "${GGML_MUSA_LLVM_ROOT}/include/openmp")
    list(APPEND GGML_EXTRA_LIBS     "${GGML_MUSA_LLVM_ROOT}/lib/libomp.so")
endif()
150
166
else ()
151
167
message (WARNING "OpenMP not found" )
152
168
endif ()
@@ -249,7 +265,13 @@ endif()
249
265
if (GGML_CUDA)
250
266
cmake_minimum_required (VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
251
267
252
- find_package (CUDAToolkit)
268
# Locate the GPU toolkit. The check that follows only looks at
# CUDAToolkit_FOUND, so for MUSA the result of find_package(MUSAToolkit) is
# mirrored into that variable.
if (NOT GGML_MUSA)
    find_package(CUDAToolkit)
else()
    # Extra module search path for the MUSAToolkit lookup below.
    list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
    find_package(MUSAToolkit)
    set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
endif()
253
275
254
276
if (CUDAToolkit_FOUND)
255
277
message (STATUS "CUDA found" )
@@ -268,7 +290,11 @@ if (GGML_CUDA)
268
290
endif ()
269
291
message (STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES} " )
270
292
271
- enable_language (CUDA)
293
# Regular CUDA builds enable CMake's CUDA language. MUSA builds skip that and
# only record the compiler in CMAKE_CUDA_COMPILER.
# NOTE(review): MUSAToolkit_MCC_EXECUTABLE is presumably provided by
# find_package(MUSAToolkit) - confirm.
if (NOT GGML_MUSA)
    enable_language(CUDA)
else()
    set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
endif()
272
298
273
299
file (GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh" )
274
300
list (APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h" )
@@ -332,21 +358,40 @@ if (GGML_CUDA)
332
358
add_compile_definitions (GGML_CUDA_NO_PEER_COPY)
333
359
endif ()
334
360
361
if (GGML_MUSA)
    # Treat every file in GGML_SOURCES_CUDA as C++ and attach the MUSA
    # compile flags to each of them in a single call.
    set_source_files_properties(${GGML_SOURCES_CUDA}
        PROPERTIES
            LANGUAGE      CXX
            COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
endif()
367
+
335
368
if (GGML_STATIC)
336
369
if (WIN32 )
337
370
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338
371
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
339
372
else ()
340
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
373
# Static runtime and BLAS libraries for the active backend.
if (NOT GGML_MUSA)
    list(APPEND GGML_EXTRA_LIBS CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
    list(APPEND GGML_EXTRA_LIBS MUSA::musart_static MUSA::mublas_static)
endif()
341
378
endif ()
342
379
else ()
343
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
380
# Shared runtime and BLAS libraries for the active backend.
if (GGML_MUSA)
    list(APPEND GGML_EXTRA_LIBS
        MUSA::musart
        MUSA::mublas)
else()
    list(APPEND GGML_EXTRA_LIBS
        CUDA::cudart
        CUDA::cublas
        CUDA::cublasLt)
endif()
344
385
endif ()
345
386
346
387
if (GGML_CUDA_NO_VMM)
347
388
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348
389
else ()
349
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
390
# Link directly against the backend's driver library.
if (NOT GGML_MUSA)
    # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
    list(APPEND GGML_EXTRA_LIBS CUDA::cuda_driver)
else()
    # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
    list(APPEND GGML_EXTRA_LIBS MUSA::musa_driver)
endif()
350
395
endif ()
351
396
else ()
352
397
message (WARNING "CUDA not found" )
@@ -757,8 +802,10 @@ function(get_flags CCID CCVER)
757
802
set (C_FLAGS -Wdouble-promotion)
758
803
set (CXX_FLAGS -Wno-array-bounds)
759
804
760
- if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761
- list (APPEND CXX_FLAGS -Wno-format-truncation)
805
# -Wno-format-truncation is added only for non-MUSA builds whose compiler
# version (CCVER) is at least 7.1.0.
if ((NOT GGML_MUSA) AND (CCVER VERSION_GREATER_EQUAL 7.1.0))
    list(APPEND CXX_FLAGS -Wno-format-truncation)
endif()
763
810
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764
811
list (APPEND CXX_FLAGS -Wextra-semi)
@@ -1163,6 +1210,7 @@ endif()
1163
1210
# Usage requirements and link settings for the ggml target.
target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})

# Public headers live in ../include; the current directory and any
# backend-specific include paths stay private to the library.
target_include_directories(ggml
    PUBLIC  ../include
    PRIVATE . ${GGML_EXTRA_INCLUDES})

# NOTE(review): GGML_EXTRA_LIBDIRS is not assigned anywhere in the part of the
# file visible here - confirm where it is populated.
target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})

target_compile_features(ggml PRIVATE c_std_11) # don't bump

target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
0 commit comments