Commit a69169f

Merge pull request #10 from l3utterfly/test-flash-attn
Test flash attn
2 parents f6e7e93 + 75c37ed commit a69169f

35 files changed: 4,828 additions, 364 deletions

CMakeLists.txt

Lines changed: 31 additions & 9 deletions
@@ -43,6 +43,18 @@ else()
    set(LLAMA_METAL_DEFAULT OFF)
endif()

+# TODO: fix this for Android CI
+#       https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
+#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
+#    set(LLAMA_LLAMAFILE_DEFAULT OFF)
+#else()
+#    set(LLAMA_LLAMAFILE_DEFAULT ON)
+#endif()
+
+# TODO: temporary disable until MoE is fixed
+#       https://github.com/ggerganov/llama.cpp/pull/6716
+set(LLAMA_LLAMAFILE_DEFAULT OFF)
+
# general
option(BUILD_SHARED_LIBS "build shared libraries" OFF)
option(LLAMA_STATIC "llama: static link libraries" OFF)

@@ -88,6 +100,7 @@ endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_BLAS "llama: use BLAS" OFF)
+option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUDA "llama: use CUDA" OFF)
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)

@@ -286,6 +299,7 @@ if (LLAMA_METAL)
        ${METALKIT_FRAMEWORK}
        )
endif()
+
if (LLAMA_BLAS)
    message(STATUS "Building with OpenBLAS")

@@ -307,6 +321,13 @@ if (LLAMA_BLAS)
    set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS)
endif()

+if (LLAMA_LLAMAFILE)
+    add_compile_definitions(GGML_USE_LLAMAFILE)
+
+    set(GGML_HEADERS_LLAMAFILE sgemm.h)
+    set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
+endif()
+
if (LLAMA_QKK_64)
    add_compile_definitions(GGML_QKK_64)
endif()

@@ -1095,15 +1116,16 @@ add_library(ggml OBJECT
            ggml-backend.h
            ggml-quants.c
            ggml-quants.h
-            ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
-            ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
-            ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
-            ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
-            ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
-            ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL}
-            ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
-            ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
-            ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
+            ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
+            ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
+            ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
+            ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL}
+            ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
+            ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
+            ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
            )

target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
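
For reference, the new LLAMA_LLAMAFILE option only toggles a preprocessor define and adds one source/header pair (sgemm.cpp / sgemm.h). The following is an illustrative sketch of the compile-time gate this produces — an assumption about how such a define is typically consumed, not code from this commit:

    // Illustrative only: GGML_USE_LLAMAFILE is defined by CMake when LLAMA_LLAMAFILE=ON.
    // With the define present, sgemm.h / sgemm.cpp are compiled in and the llamafile SGEMM
    // path becomes available; without it, only the regular ggml matrix-multiply kernels exist.
    #ifdef GGML_USE_LLAMAFILE
    #include "sgemm.h"                                    // header added by this commit
    static constexpr bool llamafile_sgemm_available = true;
    #else
    static constexpr bool llamafile_sgemm_available = false;
    #endif

Note that the default is forced OFF here until the referenced MoE issue is resolved, so the define is only emitted when LLAMA_LLAMAFILE is enabled explicitly at configure time.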

Makefile

Lines changed: 14 additions & 2 deletions
@@ -384,6 +384,15 @@ ifdef LLAMA_OPENBLAS
	MK_LDFLAGS += $(shell pkg-config --libs openblas)
endif # LLAMA_OPENBLAS

+# TODO: temporary disable until MoE is fixed
+#       https://github.com/ggerganov/llama.cpp/pull/6716
+LLAMA_NO_LLAMAFILE := 1
+
+ifndef LLAMA_NO_LLAMAFILE
+	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
+	OBJS += sgemm.o
+endif
+
ifdef LLAMA_BLIS
	MK_CPPFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis
	MK_LDFLAGS += -lblis -L/usr/local/lib

@@ -480,11 +489,9 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com

ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
	$(NVCC_COMPILE)
-
endif # LLAMA_CUDA

ifdef LLAMA_CLBLAST
-
	MK_CPPFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags-only-I clblast OpenCL)
	MK_CFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL)
	MK_CXXFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL)

@@ -603,6 +610,11 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
	$(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_MPI

+ifndef LLAMA_NO_LLAMAFILE
+sgemm.o: sgemm.cpp sgemm.h ggml.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+endif
+
GF_CC := $(CC)
include scripts/get-flags.mk

Package.swift

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ import PackageDescription

var sources = [
    "ggml.c",
+    "sgemm.cpp",
    "llama.cpp",
    "unicode.cpp",
    "unicode-data.cpp",

README.md

Lines changed: 2 additions & 0 deletions
@@ -189,6 +189,8 @@ Unless otherwise noted these projects are open-source with permissive licensing:
- [MindMac](https://mindmac.app) (proprietary)
- [KodiBot](https://github.com/firatkiral/kodibot) (GPL)
- [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT)
+- [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT)
+
*(to have a project listed here, it should clearly state that it depends on `llama.cpp`)*

---

build.zig

Lines changed: 8 additions & 7 deletions
@@ -112,6 +112,7 @@ pub fn build(b: *std.build.Builder) !void {
    make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;

    const ggml = make.obj("ggml", "ggml.c");
+    const sgemm = make.obj("sgemm", "sgemm.cpp");
    const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
    const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
    const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");

@@ -128,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void {
    const clip = make.obj("clip", "examples/llava/clip.cpp");
    const llava = make.obj("llava", "examples/llava/llava.cpp");

-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });

-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
    if (server.target.isWindows()) {
        server.linkSystemLibrary("ws2_32");
    }

common/common.cpp

Lines changed: 79 additions & 0 deletions
@@ -108,6 +108,79 @@ int32_t get_num_physical_cores() {
    return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}

+#if defined(__x86_64__) && defined(__linux__)
+#include <pthread.h>
+
+static void cpuid(unsigned leaf, unsigned subleaf,
+                  unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
+    __asm__("movq\t%%rbx,%%rsi\n\t"
+            "cpuid\n\t"
+            "xchgq\t%%rbx,%%rsi"
+            : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
+            : "0"(leaf), "2"(subleaf));
+}
+
+static int pin_cpu(int cpu) {
+    cpu_set_t mask;
+    CPU_ZERO(&mask);
+    CPU_SET(cpu, &mask);
+    return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+}
+
+static bool is_hybrid_cpu(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
+    return !!(edx & (1u << 15));
+}
+
+static bool is_running_on_efficiency_core(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
+    int intel_atom = 0x20;
+    int core_type = (eax & 0xff000000u) >> 24;
+    return core_type == intel_atom;
+}
+
+static int count_math_cpus(int cpu_count) {
+    int result = 0;
+    for (int cpu = 0; cpu < cpu_count; ++cpu) {
+        if (pin_cpu(cpu)) {
+            return -1;
+        }
+        if (is_running_on_efficiency_core()) {
+            continue; // efficiency cores harm lockstep threading
+        }
+        ++cpu; // hyperthreading isn't useful for linear algebra
+        ++result;
+    }
+    return result;
+}
+
+#endif // __x86_64__ && __linux__
+
+/**
+ * Returns number of CPUs on system that are useful for math.
+ */
+int get_math_cpu_count() {
+#if defined(__x86_64__) && defined(__linux__)
+    int cpu_count = sysconf(_SC_NPROCESSORS_ONLN);
+    if (cpu_count < 1) {
+        return get_num_physical_cores();
+    }
+    if (is_hybrid_cpu()) {
+        cpu_set_t affinity;
+        if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
+            int result = count_math_cpus(cpu_count);
+            pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
+            if (result > 0) {
+                return result;
+            }
+        }
+    }
+#endif
+    return get_num_physical_cores();
+}
+
void process_escapes(std::string & input) {
    std::size_t input_len = input.length();
    std::size_t output_idx = 0;

@@ -827,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        params.cont_batching = true;
        return true;
    }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
    if (arg == "--color") {
        params.use_color = true;
        return true;

@@ -1763,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
    cparams.cb_eval = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;
    cparams.offload_kqv = !params.no_kv_offload;
+    cparams.flash_attn = params.flash_attn;

    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

@@ -2600,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
    fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
    fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

    const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
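
Taken together, these hunks add a `-fa` / `--flash-attn` switch, store it in gpt_params, and copy it into the context parameters. A minimal sketch of that flow, assuming a program built against common/common.h as shown above (the include path and the context-creation call named in the comment are illustrative, not part of this diff):

    #include "common/common.h"

    int main() {
        gpt_params params;
        params.flash_attn = true; // what "-fa" / "--flash-attn" sets via gpt_params_find_arg()

        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        // cparams.flash_attn now mirrors params.flash_attn (see the hunk above) and is applied
        // when a context is created from it, e.g. llama_new_context_with_model(model, cparams).
        return 0;
    }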

common/common.h

Lines changed: 3 additions & 1 deletion
@@ -39,6 +39,7 @@ extern char const *LLAMA_BUILD_TARGET;

struct llama_control_vector_load_info;

+int get_math_cpu_count();
int32_t get_num_physical_cores();

//

@@ -48,7 +49,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
    uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

-    int32_t n_threads = get_num_physical_cores();
+    int32_t n_threads = get_math_cpu_count();
    int32_t n_threads_draft = -1;
    int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
    int32_t n_threads_batch_draft = -1;

@@ -147,6 +148,7 @@ struct gpt_params {
    bool multiline_input = false; // reverse the usage of `\`
    bool simple_io = false; // improves compatibility with subprocesses and limited consoles
    bool cont_batching = true; // insert new sequences for decoding on-the-fly
+    bool flash_attn = false; // flash attention

    bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
    bool ignore_eos = false; // ignore generated EOS tokens
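
With this change the default n_threads comes from the new helper rather than the raw physical-core count: on hybrid x86 Linux machines only performance cores are counted (one hardware thread each), with a fallback to get_num_physical_cores() everywhere else. A small sketch of the observable effect, assuming a build that links the common library (include path illustrative):

    #include "common/common.h"
    #include <cstdio>

    int main() {
        gpt_params params; // n_threads now defaults to get_math_cpu_count()
        std::printf("default n_threads: %d (physical cores: %d)\n",
                    (int) params.n_threads, (int) get_num_physical_cores());
        return 0;
    }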

common/sampling.cpp

Lines changed: 16 additions & 1 deletion
@@ -260,13 +260,18 @@ static llama_token_data_array llama_sampling_prepare_impl(

    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

+    // repetition penalties
    const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
    const float penalty_repeat = params.penalty_repeat;
    const float penalty_freq = params.penalty_freq;
    const float penalty_present = params.penalty_present;
-
    const bool penalize_nl = params.penalize_nl;

+    // DRY sampler parameters
+    const float dry_multiplier = params.dry_multiplier;
+    const float dry_base = params.dry_base;
+    const int dry_allowed_length = params.dry_allowed_length;
+
    auto & prev = ctx_sampling->prev;
    auto & cur = ctx_sampling->cur;

@@ -302,10 +307,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
    if (penalty_tokens_used_size) {
        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

+        // repetition penalties
        llama_sample_repetition_penalties(ctx_main, &cur_p,
                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

+        // DRY penalties (multiplier > 0 means enabled)
+        if(dry_multiplier > 0.0f) {
+            llama_sample_dry(ctx_main, &cur_p,
+                    penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                    params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
+        }
+
+
        if (!penalize_nl) {
            for (size_t idx = 0; idx < cur_p.size; idx++) {
                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
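
The hunk above only wires the new parameters into llama_sample_dry; the penalty computation itself lives in llama.cpp and is not part of this excerpt. As a hedged illustration of how a DRY-style penalty typically scales with the length of the repeated suffix — an assumed shape, penalty = multiplier * base^(length - allowed_length), not code from this commit:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float dry_multiplier     = 0.8f;  // recommended value per common/sampling.h below
        const float dry_base           = 1.75f; // default from common/sampling.h below
        const int   dry_allowed_length = 2;     // repeats up to this length go unpenalized

        for (int n = dry_allowed_length; n <= 8; ++n) {
            const float penalty = dry_multiplier * std::pow(dry_base, (float)(n - dry_allowed_length));
            std::printf("repeated suffix length %d -> logit penalty %.2f\n", n, penalty);
        }
        return 0;
    }

The `dry_multiplier > 0.0f` check above means the default of 0.0f leaves the DRY sampler disabled and skips the call entirely.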

common/sampling.h

Lines changed: 5 additions & 1 deletion
@@ -38,7 +38,10 @@ typedef struct llama_sampling_params {
    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    float mirostat_tau = 5.00f; // target entropy
    float mirostat_eta = 0.10f; // learning rate
-    bool penalize_nl = false; // consider newlines as a repeatable token
+    bool penalize_nl = false; // consider newlines as a repeatable token
+    float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
+    float dry_base = 1.75f;
+    int dry_allowed_length = 2;

    std::vector<llama_sampler_type> samplers_sequence = {
        llama_sampler_type::TOP_K,

@@ -59,6 +62,7 @@ typedef struct llama_sampling_params {
    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

    std::vector<llama_token> penalty_prompt_tokens;
+    std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
    bool use_penalty_prompt_tokens = false;
} llama_sampling_params;
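
A minimal configuration sketch for the new fields; the include path and the sequence-breaker token IDs are placeholders, since breakers depend on the model's vocabulary:

    #include "common/sampling.h"

    int main() {
        llama_sampling_params sparams;
        sparams.dry_multiplier        = 0.8f;  // 0.0f (the default) keeps DRY disabled
        sparams.dry_base              = 1.75f; // header default
        sparams.dry_allowed_length    = 2;     // short repeats are not penalized
        sparams.dry_sequence_breakers = { /* model-specific token ids, e.g. the newline token */ };
        return 0;
    }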