
Commit 1db4075

Modify to use GPU (if available) for embedding (#1064)
1 parent 19186fd commit 1db4075

File tree

5 files changed: +38 -11 lines changed
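The commit threads a single new knob, `n_gpu_layers`, from CodeGate's configuration (`Config.get_config().chat_model_n_gpu_layers`) into every embedding call, so llama.cpp can offload model layers to a GPU when one is available. A minimal sketch of what that parameter does in llama-cpp-python (the model path and values below are illustrative, not from this commit):

```python
# Sketch: the n_gpu_layers parameter in llama-cpp-python (illustrative values).
from llama_cpp import Llama

model = Llama(
    model_path="./models/embedding-model.gguf",  # hypothetical model file
    embedding=True,   # load the model in embedding mode
    n_ctx=512,
    n_gpu_layers=-1,  # -1 offloads all layers to the GPU; 0 keeps inference on the CPU
)
vectors = model.embed(["some text to embed"])  # returns a list of float vectors
```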

scripts/import_packages.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -7,6 +7,7 @@
 import numpy as np
 import sqlite_vec_sl_tmp
 
+from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.utils.utils import generate_vector_string
 
@@ -55,7 +56,9 @@ def setup_schema(self):
 
     async def process_package(self, package):
         vector_str = generate_vector_string(package)
-        vector = await self.inference_engine.embed(self.model_path, [vector_str])
+        vector = await self.inference_engine.embed(
+            self.model_path, [vector_str], n_gpu_layers=Config.get_config().chat_model_n_gpu_layers
+        )
         vector_array = np.array(vector[0], dtype=np.float32)
 
         cursor = self.conn.cursor()
```
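Every call site in this commit reads the same configuration attribute. The `Config` class itself is not touched by the diff, so the sketch below of how that value might be exposed is an assumption for illustration only; just the `chat_model_n_gpu_layers` attribute and the `get_config()` accessor appear in the commit.

```python
# Assumption: a simplified stand-in for codegate.config.Config, which this
# commit does not show. Only chat_model_n_gpu_layers and get_config() are real.
class Config:
    _config = None  # cached singleton instance

    def __init__(self, chat_model_n_gpu_layers: int = 0):
        # 0 keeps inference on the CPU; llama.cpp treats -1 as "offload all layers"
        self.chat_model_n_gpu_layers = chat_model_n_gpu_layers

    @classmethod
    def get_config(cls) -> "Config":
        if cls._config is None:
            cls._config = cls()
        return cls._config

n_gpu_layers = Config.get_config().chat_model_n_gpu_layers
```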

src/codegate/inference/inference_engine.py

Lines changed: 23 additions & 7 deletions

```diff
@@ -1,5 +1,13 @@
+from typing import Iterator, List, Union
+
 import structlog
-from llama_cpp import Llama
+from llama_cpp import (
+    CreateChatCompletionResponse,
+    CreateChatCompletionStreamResponse,
+    CreateCompletionResponse,
+    CreateCompletionStreamResponse,
+    Llama,
+)
 
 logger = structlog.get_logger("codegate")
 
@@ -35,7 +43,9 @@ def _close_models(self):
             model._sampler.close()
             model.close()
 
-    async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0) -> Llama:
+    async def __get_model(
+        self, model_path: str, embedding: bool = False, n_ctx: int = 512, n_gpu_layers: int = 0
+    ) -> Llama:
         """
         Returns Llama model object from __models if present. Otherwise, the model
         is loaded and added to __models and returned.
@@ -55,7 +65,9 @@ async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers
 
         return self.__models[model_path]
 
-    async def complete(self, model_path, n_ctx=512, n_gpu_layers=0, **completion_request):
+    async def complete(
+        self, model_path: str, n_ctx: int = 512, n_gpu_layers: int = 0, **completion_request
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """
         Generates a chat completion using the specified model and request parameters.
         """
@@ -64,7 +76,9 @@ async def complete(self, model_path, n_ctx=512, n_gpu_layers=0, **completion_req
         )
         return model.create_completion(**completion_request)
 
-    async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request):
+    async def chat(
+        self, model_path: str, n_ctx: int = 512, n_gpu_layers: int = 0, **chat_completion_request
+    ) -> Union[CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]]:
         """
         Generates a chat completion using the specified model and request parameters.
         """
@@ -73,18 +87,20 @@ async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_re
         )
         return model.create_chat_completion(**chat_completion_request)
 
-    async def embed(self, model_path, content):
+    async def embed(self, model_path: str, content: List[str], n_gpu_layers=0) -> List[List[float]]:
         """
         Generates an embedding for the given content using the specified model.
         """
         logger.debug(
             "Generating embedding",
             model=model_path.split("/")[-1],
-            content=content,
+            content=content[0][0 : min(100, len(content[0]))],
             content_length=len(content[0]) if content else 0,
         )
 
-        model = await self.__get_model(model_path=model_path, embedding=True)
+        model = await self.__get_model(
+            model_path=model_path, embedding=True, n_gpu_layers=n_gpu_layers
+        )
        embedding = model.embed(content)
 
         logger.debug(
```
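With the new signature, each caller chooses GPU offload per call; the diff also trims the debug log to the first 100 characters of the content instead of logging the whole payload. A hypothetical caller of the updated `embed()` (the model path is illustrative, and a no-argument engine constructor is assumed):

```python
# Hypothetical usage of the updated embed() signature.
import asyncio

from codegate.inference.inference_engine import LlamaCppInferenceEngine

async def main() -> None:
    engine = LlamaCppInferenceEngine()  # assumption: default construction
    vectors = await engine.embed(
        "./models/embedding-model.gguf",  # illustrative model path
        ["rm -rf /"],
        n_gpu_layers=-1,  # assumption: -1 offloads all layers, per llama.cpp convention
    )
    print(len(vectors), len(vectors[0]))  # number of vectors, embedding dimension

asyncio.run(main())
```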

src/codegate/pipeline/suspicious_commands/suspicious_commands.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -80,7 +80,9 @@ async def compute_embeddings(self, phrases):
         Returns:
             torch.Tensor: Tensor of embeddings.
         """
-        embeddings = await self.inference_engine.embed(self.model_path, phrases)
+        embeddings = await self.inference_engine.embed(
+            self.model_path, phrases, n_gpu_layers=Config.get_config().chat_model_n_gpu_layers
+        )
         return embeddings
 
     async def classify_phrase(self, phrase, embeddings=None):
```

src/codegate/pipeline/suspicious_commands/suspicious_commands_trainer.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -107,7 +107,9 @@ async def train(self, phrases, labels):
             phrases (list of str): List of phrases to train on.
             labels (list of int): Corresponding labels for the phrases.
         """
-        embeds = await self.inference_engine.embed(self.model_path, phrases)
+        embeds = await self.inference_engine.embed(
+            self.model_path, phrases, n_gpu_layers=Config.get_config().chat_model_n_gpu_layers
+        )
         if isinstance(embeds[0], list):
             embedding_dim = len(embeds[0])
         else:
```

src/codegate/storage/storage_engine.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -185,7 +185,11 @@ async def search(
 
         elif query:
             # Generate embedding for the query
-            query_vector = await self.inference_engine.embed(self.model_path, [query])
+            query_vector = await self.inference_engine.embed(
+                self.model_path,
+                [query],
+                n_gpu_layers=Config.get_config().chat_model_n_gpu_layers,
+            )
             query_embedding = np.array(query_vector[0], dtype=np.float32)
             query_embedding_bytes = query_embedding.tobytes()
```
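The surrounding context shows why `embed()` is typed to return `List[List[float]]`: the storage engine takes the first vector, converts it to float32, and serializes it to raw bytes for the vector search. That serialization step in isolation (the input values are stand-ins for a real embedding):

```python
# Illustration of the float32 serialization step shown in the diff context.
import numpy as np

query_vector = [[0.12, -0.05, 0.33]]  # stand-in for an inference_engine.embed() result
query_embedding = np.array(query_vector[0], dtype=np.float32)
query_embedding_bytes = query_embedding.tobytes()  # raw little-endian float32 buffer
print(len(query_embedding_bytes))  # 3 floats x 4 bytes = 12
```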
