1
1
from __future__ import annotations
2
2
3
+ import contextlib
3
4
import os
4
5
import ctypes
6
+ from types import TracebackType
5
7
6
8
from typing import (
7
9
List ,
8
10
Optional ,
9
11
Sequence ,
12
+ Type ,
10
13
)
11
14
from dataclasses import dataclass , field
12
15
23
26
# Python wrappers over llama.h structs
24
27
25
28
26
- class _LlamaModel :
29
+ class _LlamaModel ( contextlib . AbstractContextManager ) :
27
30
"""Intermediate Python wrapper for a llama.cpp llama_model.
28
31
NOTE: For stability it's recommended you use the Llama class instead."""
29
32
@@ -59,6 +62,14 @@ def __init__(
59
62
def __del__ (self ) -> None :
60
63
self .close ()
61
64
65
+ def __exit__ (
66
+ self ,
67
+ __exc_type : Optional [Type [BaseException ]],
68
+ __exc_value : Optional [BaseException ],
69
+ __traceback : Optional [TracebackType ]
70
+ ) -> Optional [bool ]:
71
+ return self .close ()
72
+
62
73
def close (self ) -> None :
63
74
if self .model is not None and self ._llama_free_model is not None :
64
75
self ._llama_free_model (self .model )
@@ -248,7 +259,7 @@ def default_params():
248
259
return llama_cpp .llama_model_default_params ()
249
260
250
261
251
- class _LlamaContext :
262
+ class _LlamaContext ( contextlib . AbstractContextManager ) :
252
263
"""Intermediate Python wrapper for a llama.cpp llama_context.
253
264
NOTE: For stability it's recommended you use the Llama class instead."""
254
265
@@ -277,7 +288,18 @@ def __init__(
277
288
if self .ctx is None :
278
289
raise ValueError ("Failed to create llama_context" )
279
290
280
- def __del__ (self ):
291
+ def __del__ (self ) -> None :
292
+ self .close ()
293
+
294
+ def __exit__ (
295
+ self ,
296
+ __exc_type : Optional [Type [BaseException ]],
297
+ __exc_value : Optional [BaseException ],
298
+ __traceback : Optional [TracebackType ]
299
+ ) -> Optional [bool ]:
300
+ return self .close ()
301
+
302
+ def close (self ) -> None :
281
303
if self .ctx is not None and self ._llama_free is not None :
282
304
self ._llama_free (self .ctx )
283
305
self .ctx = None
@@ -495,7 +517,7 @@ def default_params():
495
517
return llama_cpp .llama_context_default_params ()
496
518
497
519
498
- class _LlamaBatch :
520
+ class _LlamaBatch ( contextlib . AbstractContextManager ) :
499
521
_llama_batch_free = None
500
522
501
523
def __init__ (
@@ -513,7 +535,18 @@ def __init__(
513
535
self ._n_tokens , self .embd , self .n_seq_max
514
536
)
515
537
516
- def __del__ (self ):
538
+ def __del__ (self ) -> None :
539
+ self .close ()
540
+
541
+ def __exit__ (
542
+ self ,
543
+ __exc_type : Optional [Type [BaseException ]],
544
+ __exc_value : Optional [BaseException ],
545
+ __traceback : Optional [TracebackType ]
546
+ ) -> Optional [bool ]:
547
+ return self .close ()
548
+
549
+ def close (self ) -> None :
517
550
if self .batch is not None and self ._llama_batch_free is not None :
518
551
self ._llama_batch_free (self .batch )
519
552
self .batch = None
0 commit comments