Add methods to explicitly free model from memory #1513

Merged
merged 6 commits on Jun 13, 2024
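
This PR adds explicit `close()` methods so the llama.cpp model, context, and batch can be freed on demand rather than waiting for garbage collection. A minimal usage sketch of what that enables, assuming a local GGUF model at a hypothetical path:

```python
import contextlib

from llama_cpp import Llama

# Hypothetical model path; any local GGUF file should work.
llm = Llama(model_path="./models/example.gguf", verbose=False)
print(llm("Q: What is the capital of France? A:", max_tokens=16))

# New in this PR: explicitly release the underlying llama.cpp model,
# context, and batch instead of waiting for garbage collection.
llm.close()

# contextlib.closing gives the same cleanup with scope-based semantics.
with contextlib.closing(Llama(model_path="./models/example.gguf", verbose=False)) as llm:
    llm("Q: 1 + 1 = ", max_tokens=4)
# llm.close() has run here, so the model memory is already freed.
```
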
56 changes: 32 additions & 24 deletions llama_cpp/_internals.py
@@ -9,6 +9,7 @@
Sequence,
)
from dataclasses import dataclass, field
from contextlib import ExitStack

import numpy as np
import numpy.typing as npt
@@ -27,9 +28,6 @@ class _LlamaModel:
"""Intermediate Python wrapper for a llama.cpp llama_model.
NOTE: For stability it's recommended you use the Llama class instead."""

_llama_free_model = None
# NOTE: this must be "saved" here to avoid exceptions when calling __del__

def __init__(
self,
*,
@@ -40,8 +38,7 @@ def __init__(
self.path_model = path_model
self.params = params
self.verbose = verbose

self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore
self._exit_stack = ExitStack()

self.model = None

@@ -56,11 +53,17 @@ def __init__(
if self.model is None:
raise ValueError(f"Failed to load model from file: {path_model}")

def __del__(self):
if self.model is not None and self._llama_free_model is not None:
self._llama_free_model(self.model)
def free_model():
if self.model is None:
return
llama_cpp.llama_free_model(self.model)
self.model = None

self._exit_stack.callback(free_model)

def close(self):
self._exit_stack.close()

def vocab_type(self) -> int:
assert self.model is not None
return llama_cpp.llama_vocab_type(self.model)
@@ -257,8 +260,6 @@ class _LlamaContext:
"""Intermediate Python wrapper for a llama.cpp llama_context.
NOTE: For stability it's recommended you use the Llama class instead."""

_llama_free = None

def __init__(
self,
*,
@@ -269,24 +270,28 @@ def __init__(
self.model = model
self.params = params
self.verbose = verbose
self._exit_stack = ExitStack()

self._llama_free = llama_cpp._lib.llama_free # type: ignore
self.ctx = None

assert self.model.model is not None

self.ctx = llama_cpp.llama_new_context_with_model(
self.model.model, self.params
)
self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)

if self.ctx is None:
raise ValueError("Failed to create llama_context")

def __del__(self):
if self.ctx is not None and self._llama_free is not None:
self._llama_free(self.ctx)
def free_ctx():
if self.ctx is None:
return
llama_cpp.llama_free(self.ctx)
self.ctx = None

self._exit_stack.callback(free_ctx)

def close(self):
self._exit_stack.close()

def n_ctx(self) -> int:
assert self.ctx is not None
return llama_cpp.llama_n_ctx(self.ctx)
@@ -501,28 +506,31 @@ def default_params():


class _LlamaBatch:
_llama_batch_free = None

def __init__(
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
):
self._n_tokens = n_tokens
self.embd = embd
self.n_seq_max = n_seq_max
self.verbose = verbose

self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
self._exit_stack = ExitStack()

self.batch = None
self.batch = llama_cpp.llama_batch_init(
self._n_tokens, self.embd, self.n_seq_max
)

def __del__(self):
if self.batch is not None and self._llama_batch_free is not None:
self._llama_batch_free(self.batch)
def free_batch():
if self.batch is None:
return
llama_cpp.llama_batch_free(self.batch)
self.batch = None

self._exit_stack.callback(free_batch)

def close(self):
self._exit_stack.close()

def n_tokens(self) -> int:
assert self.batch is not None
return self.batch.n_tokens
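
The `_internals.py` change replaces the cached `_llama_free_*` function references (previously kept so `__del__` could still run safely during interpreter shutdown) with a per-instance `contextlib.ExitStack`: each wrapper registers its free function as a callback and exposes `close()`. A minimal sketch of that pattern, not the PR's exact code:

```python
from contextlib import ExitStack


class _ResourceWrapper:
    """Illustrative stand-in for the _Llama* wrappers above (not library code)."""

    def __init__(self):
        self._exit_stack = ExitStack()
        self.handle = object()  # stands in for a llama.cpp pointer

        def free_handle():
            if self.handle is None:
                return
            # llama_cpp.llama_free_model / llama_free / llama_batch_free would go here.
            self.handle = None

        # The callback runs when the stack is closed.
        self._exit_stack.callback(free_handle)

    def close(self):
        # Deterministic, idempotent cleanup: the stack runs each callback once,
        # and free_handle() is a no-op once self.handle is None.
        self._exit_stack.close()


w = _ResourceWrapper()
w.close()
w.close()  # safe: nothing left to free
```
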
21 changes: 15 additions & 6 deletions llama_cpp/llama.py
@@ -9,7 +9,9 @@
import typing
import fnmatch
import warnings
import contextlib
import multiprocessing
from types import TracebackType

from typing import (
List,
@@ -21,6 +23,7 @@
Deque,
Callable,
Dict,
Type,
)
from collections import deque
from pathlib import Path
@@ -350,9 +353,11 @@ def __init__(
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")

self._model = _LlamaModel(
self._stack = contextlib.ExitStack()

self._model = self._stack.enter_context(contextlib.closing(_LlamaModel(
path_model=self.model_path, params=self.model_params, verbose=self.verbose
)
)))

# Override tokenizer
self.tokenizer_ = tokenizer or LlamaTokenizer(self)
@@ -364,18 +369,18 @@ def __init__(
self.context_params.n_ctx = self._model.n_ctx_train()
self.context_params.n_batch = self.n_batch

self._ctx = _LlamaContext(
self._ctx = self._stack.enter_context(contextlib.closing(_LlamaContext(
model=self._model,
params=self.context_params,
verbose=self.verbose,
)
)))

self._batch = _LlamaBatch(
self._batch = self._stack.enter_context(contextlib.closing(_LlamaBatch(
n_tokens=self.n_batch,
embd=0,
n_seq_max=self.context_params.n_ctx,
verbose=self.verbose,
)
)))

if self.lora_path:
if self._model.apply_lora_from_file(
@@ -1959,6 +1964,10 @@ def pooling_type(self) -> str:
"""Return the pooling type."""
return self._ctx.pooling_type()

def close(self) -> None:
"""Explicitly free the model from memory."""
self._stack.close()

@staticmethod
def logits_to_logprobs(
logits: Union[npt.NDArray[np.single], List], axis: int = -1
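
In `llama.py`, the three wrappers are entered onto a single `contextlib.ExitStack` via `contextlib.closing`, so the new `Llama.close()` releases the batch, context, and model in reverse creation order. A toy sketch of that composition (class and attribute names are illustrative, not the library's):

```python
import contextlib


class _Closable:
    def __init__(self, name: str):
        self.name = name

    def close(self) -> None:
        print(f"closed {self.name}")


class Owner:
    def __init__(self):
        self._stack = contextlib.ExitStack()
        # contextlib.closing adapts any object with close() into a context
        # manager, and enter_context registers it on the stack.
        self.model = self._stack.enter_context(contextlib.closing(_Closable("model")))
        self.ctx = self._stack.enter_context(contextlib.closing(_Closable("context")))
        self.batch = self._stack.enter_context(contextlib.closing(_Closable("batch")))

    def close(self) -> None:
        # Unwinds in reverse order: batch, then context, then model.
        self._stack.close()


Owner().close()  # prints: closed batch, closed context, closed model
```
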
3 changes: 3 additions & 0 deletions llama_cpp/server/model.py
@@ -44,6 +44,8 @@ def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
if self._current_model is not None:
return self._current_model

if self._current_model:
self._current_model.close()
self._current_model = None

settings = self._model_settings_dict[model]
@@ -65,6 +67,7 @@ def __iter__(self):

def free(self):
if self._current_model:
self._current_model.close()
del self._current_model

@staticmethod
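
The server-side change makes the model proxy close the previously loaded model before switching to (or freeing) another one, so two sets of weights are not kept resident at once. A simplified sketch of that behavior, not the actual LlamaProxy code:

```python
from typing import Callable, Optional


class _ClosableModel:
    """Minimal stand-in for llama_cpp.Llama, which now exposes close()."""

    def close(self) -> None:
        print("freed model memory")


class ModelSwitcher:
    """Simplified sketch of the server's model-switching behavior (hypothetical names)."""

    def __init__(self) -> None:
        self._current_model: Optional[_ClosableModel] = None

    def switch(self, load_model: Callable[[], _ClosableModel]) -> _ClosableModel:
        if self._current_model is not None:
            # Free the previous model's memory before loading the next one.
            self._current_model.close()
        self._current_model = None
        self._current_model = load_model()
        return self._current_model

    def free(self) -> None:
        if self._current_model is not None:
            self._current_model.close()
            self._current_model = None


switcher = ModelSwitcher()
switcher.switch(_ClosableModel)
switcher.switch(_ClosableModel)  # frees the first model before loading the second
switcher.free()                  # frees the second model
```
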