Skip to content

Commit 76e1b63

Browse files
author
lexasub
committed
train: add simple loading already tokenized data from parquet dataset
1 parent bee2842 commit 76e1b63

File tree

8 files changed

+141
-15
lines changed

8 files changed

+141
-15
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
8484
# 3rd party libs
8585
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
8686
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
87+
option(LLAMA_PARQUET "Enable Parquet dataset support via Arrow/Parquet C++" OFF)
8788

8889
# Required for relocatable CMake package
8990
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -173,6 +174,12 @@ if (MINGW)
173174
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
174175
endif()
175176

177+
if(LLAMA_PARQUET)
178+
find_package(Arrow REQUIRED)
179+
find_package(Parquet REQUIRED)
180+
add_definitions(-DLLAMA_PARQUET)
181+
endif()
182+
176183
#
177184
# build the library
178185
#

common/arg.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,14 +1470,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14701470
[](common_params & params) {
14711471
params.ctx_shift = false;
14721472
}
1473-
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
1473+
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
14741474
add_opt(common_arg(
14751475
{"--chunks"}, "N",
14761476
string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
14771477
[](common_params & params, int value) {
14781478
params.n_chunks = value;
14791479
}
1480-
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
1480+
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
14811481
add_opt(common_arg(
14821482
{"-fa", "--flash-attn"},
14831483
string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2115,70 +2115,70 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
21152115
[](common_params & params) {
21162116
params.hellaswag = true;
21172117
}
2118-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2118+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21192119
add_opt(common_arg(
21202120
{"--hellaswag-tasks"}, "N",
21212121
string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks),
21222122
[](common_params & params, int value) {
21232123
params.hellaswag_tasks = value;
21242124
}
2125-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2125+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21262126
add_opt(common_arg(
21272127
{"--winogrande"},
21282128
"compute Winogrande score over random tasks from datafile supplied with -f",
21292129
[](common_params & params) {
21302130
params.winogrande = true;
21312131
}
2132-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2132+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21332133
add_opt(common_arg(
21342134
{"--winogrande-tasks"}, "N",
21352135
string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks),
21362136
[](common_params & params, int value) {
21372137
params.winogrande_tasks = value;
21382138
}
2139-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2139+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21402140
add_opt(common_arg(
21412141
{"--multiple-choice"},
21422142
"compute multiple choice score over random tasks from datafile supplied with -f",
21432143
[](common_params & params) {
21442144
params.multiple_choice = true;
21452145
}
2146-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2146+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21472147
add_opt(common_arg(
21482148
{"--multiple-choice-tasks"}, "N",
21492149
string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks),
21502150
[](common_params & params, int value) {
21512151
params.multiple_choice_tasks = value;
21522152
}
2153-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2153+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21542154
add_opt(common_arg(
21552155
{"--kl-divergence"},
21562156
"computes KL-divergence to logits provided via --kl-divergence-base",
21572157
[](common_params & params) {
21582158
params.kl_divergence = true;
21592159
}
2160-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2160+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21612161
add_opt(common_arg(
21622162
{"--save-all-logits", "--kl-divergence-base"}, "FNAME",
21632163
"set logits file",
21642164
[](common_params & params, const std::string & value) {
21652165
params.logits_file = value;
21662166
}
2167-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2167+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21682168
add_opt(common_arg(
21692169
{"--ppl-stride"}, "N",
21702170
string_format("stride for perplexity calculation (default: %d)", params.ppl_stride),
21712171
[](common_params & params, int value) {
21722172
params.ppl_stride = value;
21732173
}
2174-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2174+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21752175
add_opt(common_arg(
21762176
{"--ppl-output-type"}, "<0|1>",
21772177
string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type),
21782178
[](common_params & params, int value) {
21792179
params.ppl_output_type = value;
21802180
}
2181-
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
2181+
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
21822182
add_opt(common_arg(
21832183
{"-dt", "--defrag-thold"}, "N",
21842184
string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
@@ -3415,6 +3415,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34153415
params.n_cache_reuse = 256;
34163416
}
34173417
).set_examples({LLAMA_EXAMPLE_SERVER}));
3418+
#ifdef LLAMA_PARQUET
3419+
add_opt(common_arg(
3420+
{"--dataset-format"}, "text",
3421+
string_format("Dataset format: text or parquet (requires LLAMA_PARQUET)"),
3422+
[](common_params & params, const std::string & format) {
3423+
params.dataset_format = format; // "text" or "parquet" — TODO: replace the string with an enum class
3424+
}
3425+
).set_examples({LLAMA_EXAMPLE_FINETUNE}));
3426+
3427+
add_opt(common_arg(
3428+
{"--parquet-path"}, "parquet.parquet",
3429+
string_format("Parquet path"),
3430+
[](common_params & params, const std::string & filepath) {//TODO -read dir
3431+
params.parquet_path = filepath;
3432+
}
3433+
).set_examples({LLAMA_EXAMPLE_FINETUNE}));
34183434

3435+
add_opt(common_arg(
3436+
{"--tokens-column"}, "tokens",
3437+
string_format("Name of tokens column (list<int32>) in Parquet file"),
3438+
[](common_params & params, const std::string & column) {
3439+
params.tokens_column = column;
3440+
}
3441+
).set_examples({LLAMA_EXAMPLE_FINETUNE}));
3442+
#endif
34193443
return ctx_arg;
34203444
}

common/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum llama_example {
8383
LLAMA_EXAMPLE_TTS,
8484

8585
LLAMA_EXAMPLE_FINETUNE,
86+
LLAMA_EXAMPLE_COUNT,
8687
};
8788

8889
enum common_sampler_type {
@@ -282,6 +283,9 @@ struct common_params {
282283
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
283284
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
284285
std::string logits_file = ""; // file for saving *all* logits // NOLINT
286+
std::string dataset_format = "text"; // "text" | "parquet"
287+
std::string parquet_path; // path to Parquet
288+
std::string tokens_column = "tokens"; // name column list<int32>
285289

286290
std::vector<std::string> in_files; // all input files
287291
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

examples/training/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,21 @@ Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
88

99
Proof of concept:
1010

11+
Loading data from a plain text file:
12+
1113
``` sh
1214
export model_name=llama_3.2-1b && export quantization=f32
1315
./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
1416
./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
1517
```
1618

19+
Loading pre-tokenized data from a Parquet file (without batching):
20+
21+
You need to install the Arrow/Parquet C++ libraries and configure the build with `-DLLAMA_PARQUET=ON`.
22+
23+
``` sh
24+
mkdir build; cd build; cmake -DLLAMA_PARQUET=ON ..; make
25+
export model_name=llama_3.2-1b && export quantization=f32
26+
./build/bin/llama-finetune -ngl 999 --dataset-format parquet --parquet-path parquet.parquet --tokens-column tokens --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
27+
```
1728
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

examples/training/finetune.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "common.h"
33
#include "log.h"
44
#include "llama.h"
5+
#include "../../src/parquet_dataset.h"
56

67
#include <cmath>
78
#include <cstdio>
@@ -18,7 +19,7 @@ int main(int argc, char ** argv) {
1819

1920
params.escape = false;
2021

21-
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
22+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
2223
return 1;
2324
}
2425

@@ -57,7 +58,23 @@ int main(int argc, char ** argv) {
5758

5859
constexpr float val_split = 0.05f;
5960

60-
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
61+
std::vector<llama_token> tokens;
62+
#ifdef LLAMA_PARQUET
63+
if (params.dataset_format == "text") {
64+
#endif
65+
tokens = common_tokenize(ctx.get(), params.prompt, true); //load from text file
66+
#ifdef LLAMA_PARQUET
67+
}
68+
else if (params.dataset_format == "parquet") {
69+
tokens = load_parquet_dataset(params.parquet_path, params.tokens_column);
70+
if (tokens.empty()) {
71+
LOG_ERR("No tokens in %s, or column %s not found/invalid", params.parquet_path.c_str(), params.tokens_column.c_str());
72+
return 1;
73+
}
74+
LOG_INF("Loaded %zu tokens from Parquet", tokens.size());
75+
}
76+
#endif
77+
6178
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
6279

6380
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);

src/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_library(llama
3232
llama-quant.cpp
3333
llama-sampling.cpp
3434
llama-vocab.cpp
35+
parquet_dataset.cpp
3536
unicode-data.cpp
3637
unicode.cpp
3738
unicode.h
@@ -41,7 +42,12 @@ target_include_directories(llama PRIVATE .)
4142
target_include_directories(llama PUBLIC ../include)
4243
target_compile_features (llama PRIVATE cxx_std_17) # don't bump
4344

44-
target_link_libraries(llama PUBLIC ggml)
45+
46+
if(LLAMA_PARQUET)
47+
target_link_libraries(llama PUBLIC ggml Arrow::arrow_shared Parquet::parquet_shared)
48+
else()
49+
target_link_libraries(llama PUBLIC ggml)
50+
endif()
4551

4652
if (BUILD_SHARED_LIBS)
4753
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)

src/parquet_dataset.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifdef LLAMA_PARQUET
#include "parquet_dataset.h"
#include <arrow/api.h>
#include <arrow/io/file.h>
#include <parquet/arrow/reader.h>
#include "llama-impl.h"

// Load a pre-tokenized dataset from a Parquet file.
//
// `path`   - path to the Parquet file on disk
// `column` - name of a list<int32> column holding one token sequence per row
//
// All rows are concatenated, in row order, into a single flat token stream.
// Returns an empty vector on any error (file unreadable, column missing or
// not list<int32>), so callers only need an `empty()` check.
std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column) {
    try {
        arrow::MemoryPool *pool = arrow::default_memory_pool();

        std::shared_ptr<arrow::io::RandomAccessFile> infile;
        PARQUET_ASSIGN_OR_THROW(infile, arrow::io::ReadableFile::Open(path));

        // PARQUET_ASSIGN_OR_THROW unwraps the arrow::Result directly into the
        // target, so no intermediate Result + ValueUnsafe() is needed.
        std::unique_ptr<parquet::arrow::FileReader> reader;
        PARQUET_ASSIGN_OR_THROW(reader, parquet::arrow::OpenFile(infile, pool));

        std::shared_ptr<arrow::Table> table;
        PARQUET_THROW_NOT_OK(reader->ReadTable(&table));

        auto field = table->schema()->GetFieldByName(column);
        if (!field || !field->type()->Equals(arrow::list(arrow::int32()))) {
            LLAMA_LOG_ERROR("Parquet column '%s' missing or not list<int32>\n", column.c_str());
            return {};
        }

        auto col = table->GetColumnByName(column);

        std::vector<llama_token> tokens;
        for (int chunk = 0; chunk < col->num_chunks(); ++chunk) {
            auto list_arr   = std::static_pointer_cast<arrow::ListArray>(col->chunk(chunk));
            auto values_arr = std::static_pointer_cast<arrow::Int32Array>(list_arr->values());

            for (int64_t i = 0; i < list_arr->length(); ++i) {
                // Null rows carry no data; skip them instead of reading
                // whatever offsets happen to be stored for them.
                if (list_arr->IsNull(i)) {
                    continue;
                }
                // value_offset()/value_length() account for the array's slice
                // offset, unlike indexing raw_value_offsets() directly.
                const int64_t start = list_arr->value_offset(i);
                const int64_t len   = list_arr->value_length(i);
                for (int64_t j = 0; j < len; ++j) {
                    tokens.push_back(static_cast<llama_token>(values_arr->Value(start + j)));
                }
            }
        }
        return tokens;
    } catch (const std::exception & e) {
        // PARQUET_ASSIGN_OR_THROW / PARQUET_THROW_NOT_OK throw
        // parquet::ParquetException (derived from std::exception) on I/O or
        // decode errors; report and return empty rather than crash the caller.
        LLAMA_LOG_ERROR("failed to read Parquet file '%s': %s\n", path.c_str(), e.what());
        return {};
    }
}
#endif // LLAMA_PARQUET

src/parquet_dataset.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef PARQUET_DATASET_H
#define PARQUET_DATASET_H

#include <string>
#include <vector>

#include "llama.h"

#ifdef LLAMA_PARQUET
// Read the list<int32> column `column` from the Parquet file at `path` and
// concatenate all rows into a single flat token stream.
// Returns an empty vector on error (unreadable file, missing/ill-typed column).
std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column);
#endif // LLAMA_PARQUET

#endif // PARQUET_DATASET_H

0 commit comments

Comments
 (0)