@@ -349,9 +349,11 @@ def __init__(
349349 if not os .path .exists (model_path ):
350350 raise ValueError (f"Model path does not exist: { model_path } " )
351351
352- self ._model = _LlamaModel (
352+ self ._stack = contextlib .ExitStack ()
353+
354+ self ._model = self ._stack .enter_context (_LlamaModel (
353355 path_model = self .model_path , params = self .model_params , verbose = self .verbose
354- )
356+ ))
355357
356358 # Override tokenizer
357359 self .tokenizer_ = tokenizer or LlamaTokenizer (self )
@@ -363,18 +365,18 @@ def __init__(
363365 self .context_params .n_ctx = self ._model .n_ctx_train ()
364366 self .context_params .n_batch = self .n_batch
365367
366- self ._ctx = _LlamaContext (
368+ self ._ctx = self . _stack . enter_context ( _LlamaContext (
367369 model = self ._model ,
368370 params = self .context_params ,
369371 verbose = self .verbose ,
370- )
372+ ))
371373
372- self ._batch = _LlamaBatch (
374+ self ._batch = self . _stack . enter_context ( _LlamaBatch (
373375 n_tokens = self .n_batch ,
374376 embd = 0 ,
375377 n_seq_max = self .context_params .n_ctx ,
376378 verbose = self .verbose ,
377- )
379+ ))
378380
379381 if self .lora_path :
380382 if self ._model .apply_lora_from_file (
@@ -1945,15 +1947,15 @@ def pooling_type(self) -> str:
19451947
19461948 def close (self ) -> None :
19471949 """Explicitly free the model from memory."""
1948- self ._model .close ()
1950+ self ._stack .close ()
19491951
19501952 def __exit__ (
19511953 self ,
1952- __exc_type : Optional [Type [BaseException ]],
1953- __exc_value : Optional [BaseException ],
1954- __traceback : Optional [TracebackType ]
1954+ exc_type : Optional [Type [BaseException ]],
1955+ exc_value : Optional [BaseException ],
1956+ traceback : Optional [TracebackType ]
19551957 ) -> Optional [bool ]:
1956- return self .close ( )
1958+ return self ._stack . __exit__ ( exc_type , exc_value , traceback )
19571959
19581960 @staticmethod
19591961 def logits_to_logprobs (