
Commit eb82354

feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch
This commit implements the `ContextManager` protocol in `_LlamaModel`, `_LlamaContext`, and `_LlamaBatch` by deriving them from `contextlib.AbstractContextManager`. Each wrapper can now be used in a `with` statement, so the underlying llama.cpp resources are released deterministically when the block exits, rather than relying solely on `__del__`.
1 parent: e38f5e2 · commit: eb82354
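
For illustration, a minimal usage sketch of the new protocol. The constructor arguments below are assumptions based on the wrappers' usual signatures in `llama_cpp/_internals.py`; they are not part of this commit:

import llama_cpp
from llama_cpp._internals import _LlamaModel, _LlamaContext

# Assumed constructor signatures; the real ones live in llama_cpp/_internals.py.
model_params = llama_cpp.llama_model_default_params()
ctx_params = llama_cpp.llama_context_default_params()

with _LlamaModel(path_model="model.gguf", params=model_params) as model:
    with _LlamaContext(model=model, params=ctx_params) as ctx:
        # Work with model/ctx here; llama_free_model/llama_free run on exit
        # from each block, even if an exception is raised inside it.
        pass

Because `__enter__` is inherited from `contextlib.AbstractContextManager` and returns `self`, the `as` target is the wrapper itself.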

File tree

1 file changed: +38 −5 lines


llama_cpp/_internals.py

Lines changed: 38 additions & 5 deletions
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import contextlib
 import os
 import ctypes
+from types import TracebackType
 
 from typing import (
     List,
     Optional,
     Sequence,
+    Type,
 )
 from dataclasses import dataclass, field
 
@@ -23,7 +26,7 @@
 # Python wrappers over llama.h structs
 
 
-class _LlamaModel:
+class _LlamaModel(contextlib.AbstractContextManager):
     """Intermediate Python wrapper for a llama.cpp llama_model.
     NOTE: For stability it's recommended you use the Llama class instead."""
 
@@ -59,6 +62,14 @@ def __init__(
     def __del__(self) -> None:
         self.close()
 
+    def __exit__(
+        self,
+        __exc_type: Optional[Type[BaseException]],
+        __exc_value: Optional[BaseException],
+        __traceback: Optional[TracebackType]
+    ) -> Optional[bool]:
+        return self.close()
+
     def close(self) -> None:
         if self.model is not None and self._llama_free_model is not None:
             self._llama_free_model(self.model)
@@ -248,7 +259,7 @@ def default_params():
         return llama_cpp.llama_model_default_params()
 
 
-class _LlamaContext:
+class _LlamaContext(contextlib.AbstractContextManager):
     """Intermediate Python wrapper for a llama.cpp llama_context.
     NOTE: For stability it's recommended you use the Llama class instead."""
 
@@ -277,7 +288,18 @@ def __init__(
         if self.ctx is None:
             raise ValueError("Failed to create llama_context")
 
-    def __del__(self):
+    def __del__(self) -> None:
+        self.close()
+
+    def __exit__(
+        self,
+        __exc_type: Optional[Type[BaseException]],
+        __exc_value: Optional[BaseException],
+        __traceback: Optional[TracebackType]
+    ) -> Optional[bool]:
+        return self.close()
+
+    def close(self) -> None:
         if self.ctx is not None and self._llama_free is not None:
             self._llama_free(self.ctx)
             self.ctx = None
@@ -495,7 +517,7 @@ def default_params():
         return llama_cpp.llama_context_default_params()
 
 
-class _LlamaBatch:
+class _LlamaBatch(contextlib.AbstractContextManager):
     _llama_batch_free = None
 
     def __init__(
@@ -513,7 +535,18 @@ def __init__(
             self._n_tokens, self.embd, self.n_seq_max
         )
 
-    def __del__(self):
+    def __del__(self) -> None:
+        self.close()
+
+    def __exit__(
+        self,
+        __exc_type: Optional[Type[BaseException]],
+        __exc_value: Optional[BaseException],
+        __traceback: Optional[TracebackType]
+    ) -> Optional[bool]:
+        return self.close()
+
+    def close(self) -> None:
         if self.batch is not None and self._llama_batch_free is not None:
             self._llama_batch_free(self.batch)
             self.batch = None
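
Two details of the pattern are worth noting. First, each `close()` guards on the handle being non-`None` and then clears it, so it is idempotent and safe to call from `__del__`, `__exit__`, or user code. Second, since `close()` returns `None`, `return self.close()` makes `__exit__` return a falsy value, so exceptions raised inside the `with` block propagate instead of being suppressed. A self-contained sketch of the same pattern, not code from this commit:

import contextlib
from types import TracebackType
from typing import Optional, Type

class _ManagedHandle(contextlib.AbstractContextManager):
    """Illustrative stand-in for the _Llama* wrappers."""

    def __init__(self) -> None:
        self.handle: Optional[object] = object()  # pretend native handle

    def close(self) -> None:
        # Idempotent: frees once, then becomes a no-op.
        if self.handle is not None:
            self.handle = None  # a real wrapper would call the C free() here

    def __del__(self) -> None:
        self.close()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        return self.close()  # returns None, so exceptions are not suppressed

# __enter__ comes from AbstractContextManager and returns self:
with _ManagedHandle() as h:
    assert h.handle is not None
assert h.handle is None  # released on exit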
