Commit 5a3059d

Merge branch 'layla-build' into dry-sampler

2 parents 7e08885 + a69169f
28 files changed: +2672 −294 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ cmake-build-*
 out/
 tmp/
 
+loras/*
 models/*
 models-mnt
 

CMakeLists.txt

Lines changed: 22 additions & 78 deletions

@@ -295,6 +295,8 @@ if (LLAMA_METAL)
 endif()
 
 if (LLAMA_BLAS)
+    message(STATUS "Building with OpenBLAS")
+
     if (LLAMA_STATIC)
         set(BLA_STATIC ON)
     endif()
@@ -303,77 +305,14 @@ if (LLAMA_BLAS)
     endif()
 
     set(BLA_VENDOR ${LLAMA_BLAS_VENDOR})
-    find_package(BLAS)
-
-    if (BLAS_FOUND)
-        message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
-
-        if ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
-            # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
-            # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
-            find_package(PkgConfig REQUIRED)
-            if (${LLAMA_BLAS_VENDOR} MATCHES "Generic")
-                pkg_check_modules(DepBLAS REQUIRED blas)
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS")
-                # As of openblas v0.3.22, the 64-bit is named openblas64.pc
-                pkg_check_modules(DepBLAS openblas64)
-                if (NOT DepBLAS_FOUND)
-                    pkg_check_modules(DepBLAS REQUIRED openblas)
-                endif()
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME")
-                pkg_check_modules(DepBLAS REQUIRED blis)
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS")
-                pkg_check_modules(DepBLAS REQUIRED blas-atlas)
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS")
-                pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel")
-                # all Intel* libraries share the same include path
-                pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
-            elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC")
-                # this doesn't provide pkg-config
-                # suggest to assign BLAS_INCLUDE_DIRS on your own
-                if ("${NVHPC_VERSION}" STREQUAL "")
-                    message(WARNING "Better to set NVHPC_VERSION")
-                else()
-                    set(DepBLAS_FOUND ON)
-                    set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
-                endif()
-            endif()
-            if (DepBLAS_FOUND)
-                set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
-            else()
-                message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
-                " detected by pkgconfig, trying to find cblas.h from possible paths...")
-                find_path(BLAS_INCLUDE_DIRS
-                    NAMES cblas.h
-                    HINTS
-                        /usr/include
-                        /usr/local/include
-                        /usr/include/openblas
-                        /opt/homebrew/opt/openblas/include
-                        /usr/local/opt/openblas/include
-                        /usr/include/x86_64-linux-gnu/openblas/include
-                )
-            endif()
-        endif()
+    add_compile_options(${BLAS_LINKER_FLAGS})
 
-        message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
+    add_compile_definitions(GGML_USE_OPENBLAS)
 
-        add_compile_options(${BLAS_LINKER_FLAGS})
+    add_subdirectory(../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS)
 
-        add_compile_definitions(GGML_USE_OPENBLAS)
-
-        if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel"))
-            add_compile_definitions(GGML_BLAS_USE_MKL)
-        endif()
-
-        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
-        set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
-    else()
-        message(WARNING "BLAS not found, please refer to "
-        "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
-        " to set correct LLAMA_BLAS_VENDOR")
-    endif()
+    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas_shared)
+    set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS)
 endif()
 
 if (LLAMA_LLAMAFILE)
@@ -489,19 +428,24 @@ if (LLAMA_MPI)
 endif()
 
 if (LLAMA_CLBLAST)
-    find_package(CLBlast)
-    if (CLBlast_FOUND)
-        message(STATUS "CLBlast found")
+    message(STATUS "Building with CLBlast")
 
-        set(GGML_HEADERS_OPENCL ggml-opencl.h)
-        set(GGML_SOURCES_OPENCL ggml-opencl.cpp)
+    set(GGML_HEADERS_OPENCL ggml-opencl.h)
+    set(GGML_SOURCES_OPENCL ggml-opencl.cpp)
 
-        add_compile_definitions(GGML_USE_CLBLAST)
+    add_compile_definitions(GGML_USE_CLBLAST)
 
-        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
-    else()
-        message(WARNING "CLBlast not found")
-    endif()
+    # link our libOpenCL.so (this is only used during compile time)
+    add_library(OpenCL SHARED IMPORTED)
+    set_target_properties(OpenCL PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../OpenCL/lib/libOpenCL.so)
+
+    # add our prebuilt clblast library
+    add_library(clblast SHARED IMPORTED)
+    set_target_properties(clblast PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../../android/app/src/main/jniLibs/${ANDROID_ABI}/libclblast.so)
+
+    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast OpenCL)
+    set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../CLBlast/include)
+    set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenCL/include)
 endif()
 
 if (LLAMA_VULKAN)
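
Both CMake changes trade feature detection (find_package) for vendored trees expected to sit beside the checkout (../OpenBLAS, ../OpenCL, ../CLBlast, plus a prebuilt libclblast.so per Android ABI), which suits a fixed Android build layout but will not resolve on a generic desktop checkout. What the compile definitions buy is decided inside ggml at build time; the following is a minimal sketch of the pattern GGML_USE_OPENBLAS unlocks, assuming a standard CBLAS header — it is not ggml's actual source:

#ifdef GGML_USE_OPENBLAS
#include <cblas.h>

// C (m x n) = A (m x k) * B (n x k)^T, row-major f32 — the shape a
// BLAS-enabled mat-mul hot path would hand off for large tensors
static void sgemm_f32(const float * a, const float * b, float * c,
                      int m, int n, int k) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                m, n, k, 1.0f, a, k, b, k, 0.0f, c, n);
}
#endif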

common/common.cpp

Lines changed: 6 additions & 0 deletions

@@ -924,6 +924,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = true;
         return true;
     }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
     if (arg == "--color") {
         params.use_color = true;
         return true;
@@ -1864,6 +1868,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
    cparams.cb_eval           = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;
    cparams.offload_kqv       = !params.no_kv_offload;
+   cparams.flash_attn        = params.flash_attn;
 
    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2701,6 +2706,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
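
The three hunks wire the new option end to end: "-fa" / "--flash-attn" is parsed into gpt_params, copied into llama_context_params, and recorded in the YAML dump. A hedged sketch of the resulting path, using only helpers that already exist in common (error handling elided):

#include "common.h"

int main(int argc, char ** argv) {
    gpt_params params;
    // "-fa" / "--flash-attn" flips params.flash_attn (default: false)
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // flash_attn rides along with offload_kqv and the other context options
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    // cparams.flash_attn == params.flash_attn from here on
    return 0;
}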

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -148,6 +148,7 @@ struct gpt_params {
     bool multiline_input  = false; // reverse the usage of `\`
     bool simple_io        = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching    = true;  // insert new sequences for decoding on-the-fly
+    bool flash_attn       = false; // flash attention
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos       = false; // ignore generated EOS tokens

common/sampling.cpp

Lines changed: 18 additions & 4 deletions

@@ -44,10 +44,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
     delete ctx;
 }
 
-void llama_sampling_reset(llama_sampling_context * ctx) {
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+        ctx->grammar = nullptr;
     }
 
     if (!ctx->parsed_grammar.rules.empty()) {
@@ -57,6 +57,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
             grammar_rules.data(),
             grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
+}
+
+void llama_sampling_reset(llama_sampling_context * ctx) {
+    llama_sampling_reset_grammar(ctx);
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
@@ -310,13 +314,12 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // DRY penalties (multiplier > 0 means enabled)
     if(dry_multiplier > 0.0f) {
-        llama_sample_dry(&cur_p, 
+        llama_sample_dry(&cur_p,
            penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
            penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
            params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
     }
 
-
     if (!penalize_nl) {
         for (size_t idx = 0; idx < cur_p.size; idx++) {
             if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
@@ -366,3 +369,14 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num) {
+    if(rollback_num > ctx_sampling->prev.size()) {
+        rollback_num = ctx_sampling->prev.size();
+    }
+
+    ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end());
+}
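
llama_sampling_rollback trims the sampler's record of recently accepted tokens (ctx_sampling->prev), clamping the count to the history size, so penalties such as repeat, frequency/presence, and the new DRY penalty stop seeing tokens the caller has decided to discard. A hedged sketch of the intended pairing with a KV-cache rewind, as when part of a speculative draft is rejected — llama_kv_cache_seq_rm is the existing llama.h call, but the helper itself is hypothetical:

#include "sampling.h"

// drop the last n_rejected tokens from both the sampler history and the
// KV cache so the next decode continues as if they were never sampled
static void rewind_draft(llama_context * ctx,
                         llama_sampling_context * ctx_sampling,
                         llama_seq_id seq, llama_pos n_past, int n_rejected) {
    // erases the rejected tail from ctx_sampling->prev (clamped internally)
    llama_sampling_rollback(ctx_sampling, n_rejected);
    // keep the KV cache in step with the sampler's view of the sequence
    llama_kv_cache_seq_rm(ctx, seq, n_past - n_rejected, -1);
}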

common/sampling.h

Lines changed: 7 additions & 0 deletions

@@ -92,6 +92,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
+// Reset the sampler grammar without resetting the context
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);
+
 // Reset the sampler context
 // - clear prev tokens
 // - reset grammar
@@ -149,3 +152,7 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num);
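
Splitting llama_sampling_reset_grammar out of llama_sampling_reset lets a caller re-arm the grammar at its root rule between structured generations while keeping the accepted-token history that feeds the repetition and DRY penalties; llama_sampling_reset preserves its old behavior by calling the new function first. A hedged usage sketch — generate_json_answer is a hypothetical downstream helper, not part of this commit:

#include <string>
#include <vector>

#include "sampling.h"

// assumed to exist for this sketch: runs one grammar-constrained generation
void generate_json_answer(llama_sampling_context * ctx_sampling, const std::string & q);

void answer_all(llama_sampling_context * ctx_sampling,
                const std::vector<std::string> & questions) {
    for (const std::string & q : questions) {
        // restart the grammar for the next answer...
        llama_sampling_reset_grammar(ctx_sampling);
        // ...without clearing prev/cur as llama_sampling_reset() would,
        // so penalties still see the earlier output
        generate_json_answer(ctx_sampling, q);
    }
}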

examples/batched-bench/batched-bench.cpp

Lines changed: 17 additions & 11 deletions

@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
         printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
         printf("    example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
     int n_kv_max     = 2048;
     int n_batch      = 2048;
     int n_ubatch     = 512;
+    bool flash_attn  = false;
     int is_pp_shared = 0;
     int n_gpu_layers = 0;
 
@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 6) {
-        is_pp_shared = std::atoi(argv[5]);
+        flash_attn = std::atoi(argv[5]);
     }
 
     if (argc >= 7) {
-        n_gpu_layers = std::atoi(argv[6]);
+        is_pp_shared = std::atoi(argv[6]);
     }
 
     if (argc >= 8) {
-        n_pp = parse_list(argv[7]);
+        n_gpu_layers = std::atoi(argv[7]);
     }
 
     if (argc >= 9) {
-        n_tg = parse_list(argv[8]);
+        n_pp = parse_list(argv[8]);
     }
 
     if (argc >= 10) {
-        n_pl = parse_list(argv[9]);
+        n_tg = parse_list(argv[9]);
+    }
+
+    if (argc >= 11) {
+        n_pl = parse_list(argv[10]);
     }
 
     // init LLM
@@ -108,10 +113,11 @@ int main(int argc, char ** argv) {
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.seed    = 1234;
-    ctx_params.n_ctx   = n_kv_max;
-    ctx_params.n_batch = n_batch;
-    ctx_params.n_ubatch = n_ubatch;
+    ctx_params.seed       = 1234;
+    ctx_params.n_ctx      = n_kv_max;
+    ctx_params.n_batch    = n_batch;
+    ctx_params.n_ubatch   = n_ubatch;
+    ctx_params.flash_attn = flash_attn;
 
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
