11from __future__ import annotations
22
3+ import contextlib
34import os
45import ctypes
6+ from types import TracebackType
57
68from typing import (
79 List ,
810 Optional ,
911 Sequence ,
12+ Type ,
1013)
1114from dataclasses import dataclass , field
1215
2326# Python wrappers over llama.h structs
2427
2528
26- class _LlamaModel :
29+ class _LlamaModel ( contextlib . AbstractContextManager ) :
2730 """Intermediate Python wrapper for a llama.cpp llama_model.
2831 NOTE: For stability it's recommended you use the Llama class instead."""
2932
@@ -59,6 +62,14 @@ def __init__(
5962 def __del__ (self ) -> None :
6063 self .close ()
6164
65+ def __exit__ (
66+ self ,
67+ __exc_type : Optional [Type [BaseException ]],
68+ __exc_value : Optional [BaseException ],
69+ __traceback : Optional [TracebackType ]
70+ ) -> Optional [bool ]:
71+ return self .close ()
72+
6273 def close (self ) -> None :
6374 if self .model is not None and self ._llama_free_model is not None :
6475 self ._llama_free_model (self .model )
@@ -248,7 +259,7 @@ def default_params():
248259 return llama_cpp .llama_model_default_params ()
249260
250261
251- class _LlamaContext :
262+ class _LlamaContext ( contextlib . AbstractContextManager ) :
252263 """Intermediate Python wrapper for a llama.cpp llama_context.
253264 NOTE: For stability it's recommended you use the Llama class instead."""
254265
@@ -277,7 +288,18 @@ def __init__(
277288 if self .ctx is None :
278289 raise ValueError ("Failed to create llama_context" )
279290
280- def __del__ (self ):
291+ def __del__ (self ) -> None :
292+ self .close ()
293+
294+ def __exit__ (
295+ self ,
296+ __exc_type : Optional [Type [BaseException ]],
297+ __exc_value : Optional [BaseException ],
298+ __traceback : Optional [TracebackType ]
299+ ) -> Optional [bool ]:
300+ return self .close ()
301+
302+ def close (self ) -> None :
281303 if self .ctx is not None and self ._llama_free is not None :
282304 self ._llama_free (self .ctx )
283305 self .ctx = None
@@ -495,7 +517,7 @@ def default_params():
495517 return llama_cpp .llama_context_default_params ()
496518
497519
498- class _LlamaBatch :
520+ class _LlamaBatch ( contextlib . AbstractContextManager ) :
499521 _llama_batch_free = None
500522
501523 def __init__ (
@@ -513,7 +535,18 @@ def __init__(
513535 self ._n_tokens , self .embd , self .n_seq_max
514536 )
515537
516- def __del__ (self ):
538+ def __del__ (self ) -> None :
539+ self .close ()
540+
541+ def __exit__ (
542+ self ,
543+ __exc_type : Optional [Type [BaseException ]],
544+ __exc_value : Optional [BaseException ],
545+ __traceback : Optional [TracebackType ]
546+ ) -> Optional [bool ]:
547+ return self .close ()
548+
549+ def close (self ) -> None :
517550 if self .batch is not None and self ._llama_batch_free is not None :
518551 self ._llama_batch_free (self .batch )
519552 self .batch = None
0 commit comments