
Commit e54c35e

feat: Support Moore Threads GPU (#8383)

authored by Xiaodong Ye

* Update doc for MUSA
* Add GGML_MUSA in Makefile
* Add GGML_MUSA in CMake
* CUDA => MUSA
* MUSA adds support for __vsubss4
* Fix CI build failure

Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 5e2727f commit e54c35e

File tree

9 files changed (+328, -29 lines)


Makefile

Lines changed: 46 additions & 9 deletions
```diff
@@ -528,10 +528,21 @@ ifndef GGML_NO_ACCELERATE
     endif
 endif # GGML_NO_ACCELERATE

+ifdef GGML_MUSA
+    CC  := clang
+    CXX := clang++
+    GGML_CUDA := 1
+    MK_CPPFLAGS += -DGGML_USE_MUSA
+endif
+
 ifndef GGML_NO_OPENMP
     MK_CPPFLAGS += -DGGML_USE_OPENMP
     MK_CFLAGS   += -fopenmp
     MK_CXXFLAGS += -fopenmp
+    ifdef GGML_MUSA
+        MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp
+        MK_LDFLAGS  += -L/usr/lib/llvm-10/lib
+    endif # GGML_MUSA
 endif # GGML_NO_OPENMP

 ifdef GGML_OPENBLAS
@@ -582,15 +593,27 @@ else
 endif # GGML_CUDA_FA_ALL_QUANTS

 ifdef GGML_CUDA
-    ifneq ('', '$(wildcard /opt/cuda)')
-        CUDA_PATH ?= /opt/cuda
+    ifdef GGML_MUSA
+        ifneq ('', '$(wildcard /opt/musa)')
+            CUDA_PATH ?= /opt/musa
+        else
+            CUDA_PATH ?= /usr/local/musa
+        endif
+
+        MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
+        MK_LDFLAGS   += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
+        MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
     else
-        CUDA_PATH ?= /usr/local/cuda
-    endif
+        ifneq ('', '$(wildcard /opt/cuda)')
+            CUDA_PATH ?= /opt/cuda
+        else
+            CUDA_PATH ?= /usr/local/cuda
+        endif

-    MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
-    MK_LDFLAGS   += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
-    MK_NVCCFLAGS += -use_fast_math
+        MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
+        MK_LDFLAGS   += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
+        MK_NVCCFLAGS += -use_fast_math
+    endif # GGML_MUSA

     OBJ_GGML += ggml/src/ggml-cuda.o
     OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -600,9 +623,11 @@ ifdef LLAMA_FATAL_WARNINGS
     MK_NVCCFLAGS += -Werror all-warnings
 endif # LLAMA_FATAL_WARNINGS

+ifndef GGML_MUSA
 ifndef JETSON_EOL_MODULE_DETECT
     MK_NVCCFLAGS += --forward-unknown-to-host-compiler
 endif # JETSON_EOL_MODULE_DETECT
+endif # GGML_MUSA

 ifdef LLAMA_DEBUG
     MK_NVCCFLAGS += -lineinfo
@@ -615,8 +640,12 @@ endif # GGML_CUDA_DEBUG
 ifdef GGML_CUDA_NVCC
     NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
 else
-    NVCC = $(CCACHE) nvcc
-endif #GGML_CUDA_NVCC
+    ifdef GGML_MUSA
+        NVCC = $(CCACHE) mcc
+    else
+        NVCC = $(CCACHE) nvcc
+    endif # GGML_MUSA
+endif # GGML_CUDA_NVCC

 ifdef CUDA_DOCKER_ARCH
     MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -687,9 +716,15 @@ define NVCC_COMPILE
 	$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 endef # NVCC_COMPILE
 else
+ifdef GGML_MUSA
+define NVCC_COMPILE
+	$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
+endef # NVCC_COMPILE
+else
 define NVCC_COMPILE
 	$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 endef # NVCC_COMPILE
+endif # GGML_MUSA
 endif # JETSON_EOL_MODULE_DETECT

 ggml/src/ggml-cuda/%.o: \
@@ -944,6 +979,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
 ifdef GGML_CUDA
 $(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
 CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
+ifndef GGML_MUSA
 ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)

 ifndef CUDA_DOCKER_ARCH
@@ -953,6 +989,7 @@ endif # CUDA_POWER_ARCH
 endif # CUDA_DOCKER_ARCH

 endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
+endif # GGML_MUSA
 endif # GGML_CUDA
 $(info )
```
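The `make` path reuses the existing CUDA plumbing: `GGML_MUSA=1` forces `GGML_CUDA := 1`, switches the host compilers to `clang`/`clang++`, and substitutes `mcc` for `nvcc`. A minimal build sketch under those assumptions (the SDK location matches what the Makefile probes; the parallelism flag is an illustrative choice, not part of the commit):

```bash
# Sketch of a MUSA build via make; assumes the MUSA SDK is installed in one
# of the locations the Makefile probes (/opt/musa or /usr/local/musa).
make clean
make GGML_MUSA=1 -j"$(nproc)"

# The build banner prints the compiler in use; with GGML_MUSA=1 the NVCC
# variable resolves to mcc (see the GGML_CUDA_NVCC hunk above).
mcc --version
```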

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -409,6 +409,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
 | [BLAS](./docs/build.md#blas-build) | All |
 | [BLIS](./docs/backend/BLIS.md) | All |
 | [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
+| [MUSA](./docs/build.md#musa) | Moore Threads GPU |
 | [CUDA](./docs/build.md#cuda) | Nvidia GPU |
 | [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
 | [Vulkan](./docs/build.md#vulkan) | GPU |
```

docs/build.md

Lines changed: 13 additions & 0 deletions
````diff
@@ -192,6 +192,19 @@ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/c
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
 | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |

+### MUSA
+
+- Using `make`:
+  ```bash
+  make GGML_MUSA=1
+  ```
+- Using `CMake`:
+
+  ```bash
+  cmake -B build -DGGML_MUSA=ON
+  cmake --build build --config Release
+  ```
+
 ### hipBLAS

 This provides BLAS acceleration on HIP-supported AMD GPUs.
````
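The new docs section covers only the build; as with the other GPU backends, offloading is enabled at run time with `-ngl`. A hypothetical invocation after a `make GGML_MUSA=1` build (the model path and layer count are placeholders, not from the commit):

```bash
# -m, -p, and -ngl behave exactly as in the CUDA backend that the MUSA
# build reuses; adjust the layer count to the model and available VRAM.
./llama-cli -m ./models/model.gguf -p "Hello, Moore Threads!" -ngl 32
```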

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -113,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
 option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)

 option(GGML_CUDA "ggml: use CUDA" OFF)
+option(GGML_MUSA "ggml: use MUSA" OFF)
 option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
 option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
 option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
```

ggml/include/ggml-cuda.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -6,6 +6,9 @@
 #ifdef GGML_USE_HIPBLAS
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
 #else
 #define GGML_CUDA_NAME "CUDA"
 #define GGML_CUBLAS_NAME "cuBLAS"
```

ggml/src/CMakeLists.txt

Lines changed: 55 additions & 7 deletions
```diff
@@ -139,6 +139,17 @@ if (GGML_METAL)
         )
 endif()

+if (GGML_MUSA)
+    set(CMAKE_C_COMPILER clang)
+    set(CMAKE_C_EXTENSIONS OFF)
+    set(CMAKE_CXX_COMPILER clang++)
+    set(CMAKE_CXX_EXTENSIONS OFF)
+
+    set(GGML_CUDA ON)
+
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
+endif()
+
 if (GGML_OPENMP)
     find_package(OpenMP)
     if (OpenMP_FOUND)
@@ -147,6 +158,11 @@ if (GGML_OPENMP)
         add_compile_definitions(GGML_USE_OPENMP)

         set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
+
+        if (GGML_MUSA)
+            set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp")
+            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so")
+        endif()
     else()
         message(WARNING "OpenMP not found")
     endif()
@@ -249,7 +265,13 @@ endif()
 if (GGML_CUDA)
     cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES

-    find_package(CUDAToolkit)
+    if (GGML_MUSA)
+        list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
+        find_package(MUSAToolkit)
+        set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
+    else()
+        find_package(CUDAToolkit)
+    endif()

     if (CUDAToolkit_FOUND)
         message(STATUS "CUDA found")
@@ -268,7 +290,11 @@ if (GGML_CUDA)
         endif()
         message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

-        enable_language(CUDA)
+        if (GGML_MUSA)
+            set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
+        else()
+            enable_language(CUDA)
+        endif()

         file(GLOB   GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
         list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
@@ -332,21 +358,40 @@ if (GGML_CUDA)
             add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
         endif()

+        if (GGML_MUSA)
+            set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
+            foreach(SOURCE ${GGML_SOURCES_CUDA})
+                set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+            endforeach()
+        endif()
+
         if (GGML_STATIC)
             if (WIN32)
                 # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
                 set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
             else ()
-                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+                if (GGML_MUSA)
+                    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
+                else()
+                    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+                endif()
             endif()
         else()
-            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+            if (GGML_MUSA)
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
+            else()
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+            endif()
         endif()

         if (GGML_CUDA_NO_VMM)
             # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
         else()
-            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+            if (GGML_MUSA)
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
+            else()
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+            endif()
         endif()
     else()
         message(WARNING "CUDA not found")
@@ -857,8 +902,10 @@ function(get_flags CCID CCVER)
         set(C_FLAGS   -Wdouble-promotion)
         set(CXX_FLAGS -Wno-array-bounds)

-        if (CCVER VERSION_GREATER_EQUAL 7.1.0)
-            list(APPEND CXX_FLAGS -Wno-format-truncation)
+        if (NOT GGML_MUSA)
+            if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+                list(APPEND CXX_FLAGS -Wno-format-truncation)
+            endif()
         endif()
         if (CCVER VERSION_GREATER_EQUAL 8.1.0)
             list(APPEND CXX_FLAGS -Wextra-semi)
@@ -1264,6 +1311,7 @@ endif()
 target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
 target_include_directories(ggml PUBLIC ../include)
 target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
+target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})
 target_compile_features (ggml PRIVATE c_std_11) # don't bump

 target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
```
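In the CMake route, `GGML_MUSA` flips `GGML_CUDA` on internally, discovers the toolkit through a find module expected under `/usr/local/musa/cmake/`, and compiles the `.cu` sources as C++ with `mcc`'s `-x musa` flags. A configure-time sanity check (a sketch; the status message comes from the `CUDAToolkit_FOUND` branch shown above):

```bash
# If MUSAToolkit is found, CUDAToolkit_FOUND is set from MUSAToolkit_FOUND,
# so the configure log prints the same "CUDA found" status message.
cmake -B build -DGGML_MUSA=ON 2>&1 | grep -E "CUDA found|MUSA"
```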

ggml/src/ggml-common.h

Lines changed: 5 additions & 1 deletion
```diff
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;

 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 typedef half ggml_half;
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
 #define GGML_TABLE_END() };

 #define GGML_COMMON_IMPL
-#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
 #include <cstdint>

 #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
```

ggml/src/ggml-cuda.cu

Lines changed: 13 additions & 9 deletions
```diff
@@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;

-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         info.devices[id].vmm = !!device_vmm;

         cudaDeviceProp prop;
@@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };

 // pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)

 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
-#endif
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
@@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
 static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
     void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {

-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
     // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
     cudaMemcpy3DPeerParms p = {};
     p.dstDevice = dstDevice;
@@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
     GGML_UNUSED(dstDevice);
     GGML_UNUSED(srcDevice);
     return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 }

 static void ggml_cuda_op_mul_mat(
@@ -1828,6 +1828,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
+#ifdef GGML_USE_MUSA
+    GGML_ASSERT(false);
+#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1870,6 +1873,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif // GGML_USE_MUSA
 #endif

     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -3027,7 +3031,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         return false;
     }

-#if CUDART_VERSION >= 11100
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
     cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
     if (err != cudaSuccess) {
         // clear the error
```
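Most of the runtime changes gate the VMM-pool and peer-copy paths out for MUSA, falling back to the legacy pool and plain `cudaMemcpy2DAsync`. One way to sanity-check a finished build is to inspect the linked libraries: a MUSA build should pull in the MUSA runtime rather than the CUDA one (a sketch; the binary path assumes an in-tree `make` build, and the library names come from the Makefile hunk above):

```bash
# Expect libmusart/libmublas/libmusa instead of libcudart/libcublas.
ldd ./llama-cli | grep -E "musa|mublas|musart"
```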
