66#include < thread>
77#include < vector>
88#include < atomic>
9- #include " llama.h"
9+ #include < functional>
10+ #include < mutex>
11+ #include < vector>
1012#include " arg.h"
1113#include " common.h"
1214#include " log.h"
1315#include " sampling.h"
1416
// Scope guard for RAII-style cleanup: stores a callable and invokes it when the
// guard is destroyed, unless the callable has been moved into another guard.
class ScopeGuard {
public:
    // Builds a guard from anything a std::function<void()> can be constructed
    // from (lambda, function pointer, functor, nullptr).
    //
    // The defaulted template parameter removes this constructor from overload
    // resolution for every other argument type. Without it, the forwarding
    // reference out-competes the deleted copy constructor for a non-const
    // ScopeGuard lvalue and fails deep inside the std::function initialisation
    // instead of reporting "use of deleted function"; it also makes
    // std::is_constructible report true for non-callable arguments.
    template <class Callable,
              class = decltype(std::function<void()>{std::declval<Callable>()})>
    ScopeGuard(Callable&& func) : m_func(std::forward<Callable>(func)) {}

    // Runs the stored callable, if there still is one. An empty guard (built
    // from nullptr or moved-from) does nothing.
    ~ScopeGuard() {
        if (m_func) {
            m_func();
        }
    }

    // A guard is the sole owner of its callable: copying would run it twice.
    ScopeGuard(const ScopeGuard&) = delete;
    ScopeGuard& operator=(const ScopeGuard&) = delete;

    // Transfers the callable to the new guard. The source is explicitly reset
    // because a moved-from std::function is left in an unspecified state and
    // could otherwise still fire in the source's destructor.
    ScopeGuard(ScopeGuard&& other) noexcept : m_func(std::move(other.m_func)) {
        other.m_func = nullptr;
    }

    // Assignment would have to decide whether to run or drop the callable
    // currently held; it is disallowed. (It was already unusable because the
    // call resolved to the deleted copy assignment; this states it directly.)
    ScopeGuard& operator=(ScopeGuard&&) = delete;

private:
    std::function<void()> m_func;  // empty when there is nothing to run
};
31+
32+
1533int main (int argc, char ** argv) {
1634 common_params params;
1735
@@ -72,6 +90,10 @@ int main(int argc, char ** argv) {
7290 models.emplace_back (model);
7391 }
7492
93+
94+ std::vector<llama_context_ptr> kept_contexts; // Stores contexts after thread exit
95+ std::mutex kept_contexts_mutex; // Protects kept_contexts
96+
7597 for (int m = 0 ; m < num_models; ++m) {
7698 auto * model = models[m].get ();
7799 for (int c = 0 ; c < num_contexts; ++c) {
@@ -85,6 +107,12 @@ int main(int argc, char ** argv) {
85107 return ;
86108 }
87109
110+ // Scope guard moves ctx to kept_contexts when thread exits
111+ ScopeGuard guard ([&] {
112+ std::lock_guard<std::mutex> lock (kept_contexts_mutex);
113+ kept_contexts.push_back (std::move (ctx));
114+ });
115+
88116 std::unique_ptr<common_sampler, decltype (&common_sampler_free)> sampler { common_sampler_init (model, params.sampling ), common_sampler_free };
89117 if (sampler == NULL ) {
90118 LOG_ERR (" failed to create sampler\n " );
@@ -142,6 +170,8 @@ int main(int argc, char ** argv) {
142170 thread.join ();
143171 }
144172
173+ kept_contexts.clear ();
174+
145175 if (failed) {
146176 LOG_ERR (" One or more threads failed.\n " );
147177 return 1 ;
0 commit comments