Skip to content

Commit fa702b4

Browse files
committed
feat: add ExitStack for Llama's internal class closure
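This update registers Llama's internal `_LlamaModel`, `_LlamaContext`, and `_LlamaBatch` objects with a `contextlib.ExitStack`, so that `close()` and `__exit__` release all of them through the stack instead of closing only the model, making resource cleanup safer and more complete.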
1 parent eb82354 commit fa702b4

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

llama_cpp/llama.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,11 @@ def __init__(
349349
if not os.path.exists(model_path):
350350
raise ValueError(f"Model path does not exist: {model_path}")
351351

352-
self._model = _LlamaModel(
352+
self._stack = contextlib.ExitStack()
353+
354+
self._model = self._stack.enter_context(_LlamaModel(
353355
path_model=self.model_path, params=self.model_params, verbose=self.verbose
354-
)
356+
))
355357

356358
# Override tokenizer
357359
self.tokenizer_ = tokenizer or LlamaTokenizer(self)
@@ -363,18 +365,18 @@ def __init__(
363365
self.context_params.n_ctx = self._model.n_ctx_train()
364366
self.context_params.n_batch = self.n_batch
365367

366-
self._ctx = _LlamaContext(
368+
self._ctx = self._stack.enter_context(_LlamaContext(
367369
model=self._model,
368370
params=self.context_params,
369371
verbose=self.verbose,
370-
)
372+
))
371373

372-
self._batch = _LlamaBatch(
374+
self._batch = self._stack.enter_context(_LlamaBatch(
373375
n_tokens=self.n_batch,
374376
embd=0,
375377
n_seq_max=self.context_params.n_ctx,
376378
verbose=self.verbose,
377-
)
379+
))
378380

379381
if self.lora_path:
380382
if self._model.apply_lora_from_file(
@@ -1945,15 +1947,15 @@ def pooling_type(self) -> str:
19451947

19461948
def close(self) -> None:
19471949
"""Explicitly free the model from memory."""
1948-
self._model.close()
1950+
self._stack.close()
19491951

19501952
def __exit__(
19511953
self,
1952-
__exc_type: Optional[Type[BaseException]],
1953-
__exc_value: Optional[BaseException],
1954-
__traceback: Optional[TracebackType]
1954+
exc_type: Optional[Type[BaseException]],
1955+
exc_value: Optional[BaseException],
1956+
traceback: Optional[TracebackType]
19551957
) -> Optional[bool]:
1956-
return self.close()
1958+
return self._stack.__exit__(exc_type, exc_value, traceback)
19571959

19581960
@staticmethod
19591961
def logits_to_logprobs(

0 commit comments

Comments
 (0)