Skip to content

Commit 0dea80e

Browse files
committed
Avoid contexts being destroyed in thread safety test until threads join
1 parent 600e3e9 commit 0dea80e

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

tests/test-thread-safety.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,30 @@
66
#include <thread>
77
#include <vector>
88
#include <atomic>
9-
#include "llama.h"
9+
#include <functional>
10+
#include <mutex>
11+
#include <vector>
1012
#include "arg.h"
1113
#include "common.h"
1214
#include "log.h"
1315
#include "sampling.h"
1416

17+
// Define a scope guard for RAII lifetime management
18+
class ScopeGuard {
19+
public:
20+
template<class Callable>
21+
ScopeGuard(Callable&& func) : m_func(std::forward<Callable>(func)) {}
22+
~ScopeGuard() { if (m_func) m_func(); }
23+
ScopeGuard(const ScopeGuard&) = delete;
24+
ScopeGuard& operator=(const ScopeGuard&) = delete;
25+
ScopeGuard(ScopeGuard&& other) noexcept : m_func(std::move(other.m_func)) {
26+
other.m_func = nullptr;
27+
}
28+
private:
29+
std::function<void()> m_func;
30+
};
31+
32+
1533
int main(int argc, char ** argv) {
1634
common_params params;
1735

@@ -72,6 +90,10 @@ int main(int argc, char ** argv) {
7290
models.emplace_back(model);
7391
}
7492

93+
94+
std::vector<llama_context_ptr> kept_contexts; // Stores contexts after thread exit
95+
std::mutex kept_contexts_mutex; // Protects kept_contexts
96+
7597
for (int m = 0; m < num_models; ++m) {
7698
auto * model = models[m].get();
7799
for (int c = 0; c < num_contexts; ++c) {
@@ -85,6 +107,12 @@ int main(int argc, char ** argv) {
85107
return;
86108
}
87109

110+
// Scope guard moves ctx to kept_contexts when thread exits
111+
ScopeGuard guard([&] {
112+
std::lock_guard<std::mutex> lock(kept_contexts_mutex);
113+
kept_contexts.push_back(std::move(ctx));
114+
});
115+
88116
std::unique_ptr<common_sampler, decltype(&common_sampler_free)> sampler { common_sampler_init(model, params.sampling), common_sampler_free };
89117
if (sampler == NULL) {
90118
LOG_ERR("failed to create sampler\n");
@@ -142,6 +170,8 @@ int main(int argc, char ** argv) {
142170
thread.join();
143171
}
144172

173+
kept_contexts.clear();
174+
145175
if (failed) {
146176
LOG_ERR("One or more threads failed.\n");
147177
return 1;

0 commit comments

Comments
 (0)