@@ -349,9 +349,11 @@ def __init__(
349
349
if not os .path .exists (model_path ):
350
350
raise ValueError (f"Model path does not exist: { model_path } " )
351
351
352
- self ._model = _LlamaModel (
352
+ self ._stack = contextlib .ExitStack ()
353
+
354
+ self ._model = self ._stack .enter_context (_LlamaModel (
353
355
path_model = self .model_path , params = self .model_params , verbose = self .verbose
354
- )
356
+ ))
355
357
356
358
# Override tokenizer
357
359
self .tokenizer_ = tokenizer or LlamaTokenizer (self )
@@ -363,18 +365,18 @@ def __init__(
363
365
self .context_params .n_ctx = self ._model .n_ctx_train ()
364
366
self .context_params .n_batch = self .n_batch
365
367
366
- self ._ctx = _LlamaContext (
368
+ self ._ctx = self . _stack . enter_context ( _LlamaContext (
367
369
model = self ._model ,
368
370
params = self .context_params ,
369
371
verbose = self .verbose ,
370
- )
372
+ ))
371
373
372
- self ._batch = _LlamaBatch (
374
+ self ._batch = self . _stack . enter_context ( _LlamaBatch (
373
375
n_tokens = self .n_batch ,
374
376
embd = 0 ,
375
377
n_seq_max = self .context_params .n_ctx ,
376
378
verbose = self .verbose ,
377
- )
379
+ ))
378
380
379
381
if self .lora_path :
380
382
if self ._model .apply_lora_from_file (
@@ -1945,15 +1947,15 @@ def pooling_type(self) -> str:
1945
1947
1946
1948
def close (self ) -> None :
1947
1949
"""Explicitly free the model from memory."""
1948
- self ._model .close ()
1950
+ self ._stack .close ()
1949
1951
1950
1952
def __exit__ (
1951
1953
self ,
1952
- __exc_type : Optional [Type [BaseException ]],
1953
- __exc_value : Optional [BaseException ],
1954
- __traceback : Optional [TracebackType ]
1954
+ exc_type : Optional [Type [BaseException ]],
1955
+ exc_value : Optional [BaseException ],
1956
+ traceback : Optional [TracebackType ]
1955
1957
) -> Optional [bool ]:
1956
- return self .close ( )
1958
+ return self ._stack . __exit__ ( exc_type , exc_value , traceback )
1957
1959
1958
1960
@staticmethod
1959
1961
def logits_to_logprobs (
0 commit comments