
Models are failing to be properly unloaded and freeing up VRAM #1442


Open
Baquara opened this issue May 10, 2024 · 6 comments
Labels
bug Something isn't working

Comments

Baquara commented May 10, 2024

Expected Behavior

Based on issue #302, I expected the model to be unloaded with the following function:


import gc
from llama_cpp.llama_cpp import llama_free_model  # low-level ctypes binding

def unload_model():
    global llm
    llama_free_model(llm)  # fails with the traceback shown below
    # Delete the model object
    del llm
    llm = None  # Ensure no reference remains

    # Explicitly invoke the garbage collector
    gc.collect()

    return {"message": "Model unloaded successfully"}

However, there are two problems here:

1. Calling llama_free_model on the llm object (loaded in the conventional way) results in this:

Traceback (most recent call last):
  File "/run/media/myserver/5dcc41df-7194-4e57-a28f-833dc5ce81bb/llamacpp/app.py", line 48, in <module>
    llama_free_model(llm)
ctypes.ArgumentError: argument 1: TypeError: wrong type

llm is created with:

llm = Llama(
    model_path=model_path,
    chat_handler=chat_handler,
    n_gpu_layers=gpu_layers,
    n_ctx=n_ctx
)

2. Even after deleting the object, assigning None, and invoking garbage collection, the VRAM is still not freed. It only gets cleared after I kill the app along with all of its processes and threads.

Current Behavior

1. llama_free_model does not work.
2. Garbage collection does not free VRAM.

Environment and Context

I tried this on both an Arch Linux setup with an RTX 3090 and a Windows laptop with an eGPU. The problem was consistent across both operating systems and hardware setups.

  • Physical (or virtual) hardware you are using, e.g. for Linux:

AMD Ryzen 7 2700 Eight-Core Processor
NVIDIA GeForce RTX 3090

  • Operating System, e.g. for Linux:

Arch Linux 6.8.9-arch1-1
Windows 11

Python 3.12.3
GNU Make 4.4.1
g++ (GCC) 13.2.1 20240417

Failure Information (for bugs)

Traceback (most recent call last):
  File "/run/media/myserver/5dcc41df-7194-4e57-a28f-833dc5ce81bb/llamacpp/app.py", line 48, in <module>
    llama_free_model(llm)
ctypes.ArgumentError: argument 1: TypeError: wrong type

Steps to Reproduce


  1. Perform a fresh install of llama-cpp-python with CUDA support.
  2. Write a code snippet that loads the model as usual.
  3. Try to unload the model with llama_free_model, or delete the model object and invoke garbage collection.
  4. Keep the app running afterwards and check VRAM with nvidia-smi (a minimal sketch combining these steps is shown below).
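The following is a minimal sketch of such a repro; the model path and the parameter values are placeholders, not taken from the report:

# Hypothetical repro sketch; model_path and parameters are placeholders.
import gc

from llama_cpp import Llama

llm = Llama(
    model_path="./model.gguf",  # placeholder path
    n_gpu_layers=-1,            # offload all layers to the GPU
    n_ctx=2048,
)

del llm        # drop the only Python reference to the model
gc.collect()   # force a collection pass

# Per this report, VRAM is still allocated at this point.
input("Check VRAM with nvidia-smi now, then press Enter to exit...")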
@abetlen added the bug label May 10, 2024
jkawamoto (Contributor) commented

From what I can see, llama_free_model is expected to take a lower-level object instead of the Llama object. In Python, determining when the garbage collector actually deletes an object is not straightforward. Here is a workaround that forces the release of the loaded model:

from llama_cpp import Llama

llama_model = Llama(…)

# Explicitly delete the model's internal object
llama_model._model.__del__()

This approach has worked for me so far.
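Applied to the unload_model function from the original report, the same idea would look roughly like this (a sketch only; it relies on the private _model attribute and assumes llm was created as shown above):

import gc

def unload_model():
    global llm
    llm._model.__del__()  # free the underlying llama.cpp model directly
    llm = None            # drop the remaining Python reference
    gc.collect()
    return {"message": "Model unloaded successfully"}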


jndiogo commented Jun 4, 2024

In my experience, @jkawamoto's approach is a good one because it frees RAM/CUDA/other memory even when the Llama object is stuck.

I've tried calling del llama_model, but this is not guaranteed to actually invoke __del__ if other references to the object remain (which can happen in several cases, for example from uncaught exceptions in interactive environments like JupyterLab; see here).
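A minimal, self-contained illustration of why del alone may not run __del__ (the Holder class here is purely for demonstration):

import gc

class Holder:
    def __del__(self):
        print("__del__ called")

obj = Holder()
alias = obj      # a second reference, e.g. kept alive by a saved traceback
del obj          # removes one name; the object itself is still alive
gc.collect()     # collects nothing: alias still refers to the object
print("not freed yet")
del alias        # last reference dropped; __del__ runs here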

jkawamoto (Contributor) commented

Since calling a special method (__del__) on a private field is too ad hoc, I opened PR #1513, which adds a close method to explicitly free the model.

redshiva commented

I am running llama_model._model.__del__() per the above comment, and I am still seeing the process use CUDA RAM.

Has there been any movement on creating a proper close method?

jkawamoto (Contributor) commented

The Llama class now has a close method, and the following code should free up RAM:

from llama_cpp import Llama

llama_model = Llama(…)
...

llama_model.close()
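For example, swapping models inside a long-running process might look like this (the paths are hypothetical):

from llama_cpp import Llama

llm = Llama(model_path="./model-a.gguf", n_gpu_layers=-1)  # hypothetical path
# ... run inference ...
llm.close()  # frees the underlying model and context, releasing VRAM

llm = Llama(model_path="./model-b.gguf", n_gpu_layers=-1)  # load the next model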

redshiva commented

Thank you!!!
