@@ -526,10 +526,21 @@ ifndef GGML_NO_ACCELERATE
526
526
endif
527
527
endif # GGML_NO_ACCELERATE
528
528
529
+ ifdef GGML_MUSA
530
+ CC := clang
531
+ CXX := clang++
532
+ GGML_CUDA := 1
533
+ MK_CPPFLAGS += -DGGML_USE_MUSA
534
+ endif
535
+
529
536
ifndef GGML_NO_OPENMP
530
537
MK_CPPFLAGS += -DGGML_USE_OPENMP
531
538
MK_CFLAGS += -fopenmp
532
539
MK_CXXFLAGS += -fopenmp
540
+ ifdef GGML_MUSA
541
+ MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp
542
+ MK_LDFLAGS += -L/usr/lib/llvm-10/lib
543
+ endif # GGML_MUSA
533
544
endif # GGML_NO_OPENMP
534
545
535
546
ifdef GGML_OPENBLAS
@@ -574,15 +585,27 @@ else
574
585
endif # GGML_CUDA_FA_ALL_QUANTS
575
586
576
587
ifdef GGML_CUDA
577
- ifneq ('', '$(wildcard /opt/cuda)')
578
- CUDA_PATH ?= /opt/cuda
588
+ ifdef GGML_MUSA
589
+ ifneq ('', '$(wildcard /opt/musa)')
590
+ CUDA_PATH ?= /opt/musa
591
+ else
592
+ CUDA_PATH ?= /usr/local/musa
593
+ endif
594
+
595
+ MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
596
+ MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
597
+ MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
579
598
else
580
- CUDA_PATH ?= /usr/local/cuda
581
- endif
599
+ ifneq ('', '$(wildcard /opt/cuda)')
600
+ CUDA_PATH ?= /opt/cuda
601
+ else
602
+ CUDA_PATH ?= /usr/local/cuda
603
+ endif
582
604
583
- MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
584
- MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
585
- MK_NVCCFLAGS += -use_fast_math
605
+ MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
606
+ MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
607
+ MK_NVCCFLAGS += -use_fast_math
608
+ endif # GGML_MUSA
586
609
587
610
OBJ_GGML += ggml/src/ggml-cuda.o
588
611
OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -592,9 +615,11 @@ ifdef LLAMA_FATAL_WARNINGS
592
615
MK_NVCCFLAGS += -Werror all-warnings
593
616
endif # LLAMA_FATAL_WARNINGS
594
617
618
+ ifndef GGML_MUSA
595
619
ifndef JETSON_EOL_MODULE_DETECT
596
620
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
597
621
endif # JETSON_EOL_MODULE_DETECT
622
+ endif # GGML_MUSA
598
623
599
624
ifdef LLAMA_DEBUG
600
625
MK_NVCCFLAGS += -lineinfo
@@ -607,8 +632,12 @@ endif # GGML_CUDA_DEBUG
607
632
ifdef GGML_CUDA_NVCC
608
633
NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
609
634
else
610
- NVCC = $(CCACHE) nvcc
611
- endif # GGML_CUDA_NVCC
635
+ ifdef GGML_MUSA
636
+ NVCC = $(CCACHE) mcc
637
+ else
638
+ NVCC = $(CCACHE) nvcc
639
+ endif # GGML_MUSA
640
+ endif # GGML_CUDA_NVCC
612
641
613
642
ifdef CUDA_DOCKER_ARCH
614
643
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -679,9 +708,15 @@ define NVCC_COMPILE
679
708
$(NVCC ) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS ) $(CPPFLAGS ) -Xcompiler "$(CUDA_CXXFLAGS ) " -c $< -o $@
680
709
endef # NVCC_COMPILE
681
710
else
711
+ ifdef GGML_MUSA
712
+ define NVCC_COMPILE
713
+ $(NVCC ) $(NVCCFLAGS ) $(CPPFLAGS ) -c $< -o $@
714
+ endef # NVCC_COMPILE
715
+ else
682
716
define NVCC_COMPILE
683
717
$(NVCC ) $(NVCCFLAGS ) $(CPPFLAGS ) -Xcompiler "$(CUDA_CXXFLAGS ) " -c $< -o $@
684
718
endef # NVCC_COMPILE
719
+ endif # GGML_MUSA
685
720
endif # JETSON_EOL_MODULE_DETECT
686
721
687
722
ggml/src/ggml-cuda/% .o : \
@@ -907,6 +942,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
907
942
ifdef GGML_CUDA
908
943
$(info I NVCC : $(shell $(NVCC ) --version | tail -n 1) )
909
944
CUDA_VERSION := $(shell $(NVCC ) --version | grep -oP 'release (\K[0-9]+\.[0-9]) ')
945
+ ifndef GGML_MUSA
910
946
ifeq ($(shell awk -v "v=$(CUDA_VERSION ) " 'BEGIN { print (v < 11.7) }'),1)
911
947
912
948
ifndef CUDA_DOCKER_ARCH
@@ -916,6 +952,7 @@ endif # CUDA_POWER_ARCH
916
952
endif # CUDA_DOCKER_ARCH
917
953
918
954
endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
955
+ endif # GGML_MUSA
919
956
endif # GGML_CUDA
920
957
$(info )
921
958
0 commit comments