Commit 97563d3

llama : use a thread pool for quantize -> 1.3% faster
1 parent 2238915 commit 97563d3

File tree

3 files changed: +100 -13 lines changed

Makefile

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
 
 OBJS += ggml-alloc.o
 
-llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h
+llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h threadpool.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 common.o: common/common.cpp common/common.h build-info.h common/log.h

llama.cpp

Lines changed: 17 additions & 12 deletions
@@ -1,8 +1,7 @@
 #include "llama.h"
-
 #include "ggml.h"
-
 #include "ggml-alloc.h"
+#include "threadpool.h"
 
 #ifdef GGML_USE_CUBLAS
 #  include "ggml-cuda.h"
@@ -60,6 +59,7 @@
 #include <cstring>
 #include <ctime>
 #include <fstream>
+#include <functional>
 #include <initializer_list>
 #include <map>
 #include <memory>
@@ -4640,8 +4640,8 @@ struct no_init {
 };
 
 static void llama_convert_tensor_internal(
-    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
-    const size_t nelements, const int nthread
+    thread_pool<void> & pool, struct ggml_tensor * tensor, std::vector<no_init<float>> & output, const size_t nelements,
+    const int nthread
 ) {
     if (output.size() < nelements) {
         output.resize(nelements);
@@ -4677,6 +4677,8 @@ static void llama_convert_tensor_internal(
     auto blocks_per_thread = nblocks / nthread;
     auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
 
+    std::vector<std::future<void>> workers;
+    workers.reserve(nthread);
     for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) {
         auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
         auto thr_elems = thr_blocks * block_size; // number of elements for this thread
@@ -4689,12 +4691,14 @@ static void llama_convert_tensor_internal(
                 qtype.to_float(inbuf, outbuf, nels);
             }
         };
-        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
+        auto future = pool.push(std::bind(
+            compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems
+        ));
+        workers.push_back(std::move(future));
         in_buff_offs += thr_block_bytes;
         out_buff_offs += thr_elems;
     }
-    for (auto & w : workers) { w.join(); }
-    workers.clear();
+    for (auto & w : workers) { w.wait(); }
 }
 
 #ifdef GGML_USE_K_QUANTS
@@ -4892,8 +4896,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     size_t total_size_new = 0;
     std::vector<int64_t> hist_all(1 << 4, 0);
 
-    std::vector<std::thread> workers;
-    workers.reserve(nthread);
+    std::vector<std::future<void>> workers;
+    workers.reserve(nthread - 1);
+    thread_pool<void> pool(nthread);
     std::mutex mutex;
 
     int idx = 0;
@@ -4974,7 +4979,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
             throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
         } else {
-            llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread);
+            llama_convert_tensor_internal(pool, tensor, f32_conv_buf, nelements, nthread);
             f32_data = (float *) f32_conv_buf.data();
         }
 
@@ -5016,10 +5021,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                     }
                 };
                 for (int it = 0; it < nthread_use - 1; ++it) {
-                    workers.emplace_back(compute);
+                    workers.push_back(pool.push(compute));
                 }
                 compute();
-                for (auto & w : workers) { w.join(); }
+                for (auto & w : workers) { w.wait(); }
                 workers.clear();
             }
 
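The pattern above replaces per-call std::thread construction with futures returned by a persistent pool. As an illustrative sketch only (do_work is a hypothetical stand-in for the commit's compute lambdas), the two synchronization styles look like this:

// Sketch, not part of the commit: the old path spawns and joins a std::thread per
// task; the new path pushes the task into a reusable pool and waits on the
// std::future that push() returns.
#include "threadpool.h"

#include <future>
#include <thread>
#include <vector>

static void do_work(int /*slice*/) { /* convert or quantize one slice of a tensor */ }

static void run_with_threads(int nthread) {
    std::vector<std::thread> workers;
    for (int i = 0; i < nthread; ++i) { workers.emplace_back(do_work, i); }  // new thread per task
    for (auto & w : workers) { w.join(); }
}

static void run_with_pool(thread_pool<void> & pool, int nthread) {
    std::vector<std::future<void>> workers;
    for (int i = 0; i < nthread; ++i) {
        workers.push_back(pool.push([i]() { do_work(i); }));                // reuse pooled threads
    }
    for (auto & w : workers) { w.wait(); }                                  // futures replace join()
}
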
threadpool.h

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+#include <atomic>
+#include <condition_variable>
+#include <cassert>
+#include <cstddef>
+#include <functional>
+#include <future>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <utility>
+#include <vector>
+
+template <typename R>
+class thread_pool {
+    static constexpr size_t max_tasks = 1024;
+    using task = std::packaged_task<R ()>;
+    std::atomic<bool> m_abort = {};
+    std::vector<std::thread> m_threads;
+    std::queue<task> m_tasks;
+    std::condition_variable m_ready, m_not_full;
+    std::mutex m_mutex;
+
+public:
+    explicit thread_pool(size_t nthreads);
+    ~thread_pool();
+
+    std::future<R> push(std::function<R ()> f);
+
+    thread_pool(thread_pool &) = delete;
+    thread_pool(thread_pool &&) = delete;
+    void operator=(thread_pool &) = delete;
+    void operator=(thread_pool &&) = delete;
+
+protected:
+    void process_tasks();
+};
+
+template <typename R>
+thread_pool<R>::thread_pool(size_t nthreads) {
+    m_threads.reserve(nthreads);
+    for (size_t i = 0; i < nthreads; ++i) {
+        m_threads.emplace_back([this]() { process_tasks(); });
+    }
+}
+
+template <typename R>
+thread_pool<R>::~thread_pool() {
+    m_abort = true;
+    m_ready.notify_all();
+    for (auto & thread : m_threads) { thread.join(); }
+}
+
+template <typename R>
+std::future<R> thread_pool<R>::push(std::function<R ()> f) {
+    auto t = task(std::move(f));
+    auto r = t.get_future();
+    std::unique_lock<std::mutex> lock(m_mutex);
+    m_not_full.wait(lock, [this]() { return m_tasks.size() < max_tasks; });
+    m_tasks.emplace(std::move(t));
+    lock.unlock();
+    m_ready.notify_one();
+    return r;
+}
+
+template <typename R>
+void thread_pool<R>::process_tasks() {
+    task t;
+    for (;;) {
+        std::unique_lock<std::mutex> lock(m_mutex);
+        for (;;) {
+            if (m_abort) { return; }
+            if (!m_tasks.empty()) { break; }
+            m_ready.wait(lock);
+        }
+        t = std::move(m_tasks.front());
+        m_tasks.pop();
+        lock.unlock();
+        m_not_full.notify_one();
+
+        t();
+    }
+}
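
For reference, a minimal standalone usage sketch of the new header (not part of the commit). It mirrors how llama_model_quantize_internal drives the pool: construct it once, push callables, wait on the returned futures, and let the destructor set m_abort and join the workers. Note that push() blocks on m_not_full once more than max_tasks (1024) tasks are queued.

// Hypothetical example, assuming threadpool.h is on the include path.
#include "threadpool.h"

#include <cstdio>
#include <future>
#include <vector>

int main() {
    thread_pool<void> pool(4);                          // start 4 worker threads

    std::vector<std::future<void>> pending;
    for (int i = 0; i < 8; ++i) {
        // push() wraps the callable in a std::packaged_task and returns its future
        pending.push_back(pool.push([i]() { std::printf("task %d done\n", i); }));
    }
    for (auto & f : pending) { f.wait(); }              // block until every task has run

    return 0;                                           // ~thread_pool signals abort and joins the workers
}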

0 commit comments
