93 changes: 55 additions & 38 deletions nemoguardrails/embeddings/basic.py
@@ -15,9 +15,9 @@

import asyncio
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

from annoy import AnnoyIndex
from annoy import AnnoyIndex # type: ignore

from nemoguardrails.embeddings.cache import cache_embeddings
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
max_batch_hold: The maximum time a batch is held before being processed
"""

embedding_model: str
embedding_engine: str
embedding_params: Dict[str, Any]
index: AnnoyIndex
embedding_size: int
cache_config: EmbeddingsCacheConfig
embeddings: List[List[float]]
search_threshold: float
use_batching: bool
max_batch_size: int
max_batch_hold: float
# Instance attributes are defined in __init__ and accessed via properties

def __init__(
self,
embedding_model=None,
embedding_engine=None,
embedding_params=None,
index=None,
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
search_threshold: float = None,
embedding_model: Optional[str] = None,
embedding_engine: Optional[str] = None,
embedding_params: Optional[Dict[str, Any]] = None,
index: Optional[AnnoyIndex] = None,
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
search_threshold: Optional[float] = None,
use_batching: bool = False,
max_batch_size: int = 10,
max_batch_hold: float = 0.01,
@@ -81,10 +71,16 @@ def __init__(
max_batch_hold: The maximum time a batch is held before being processed
"""
self._model: Optional[EmbeddingModel] = None
self._items = []
self._embeddings = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self._items: List[IndexItem] = []
self._embeddings: List[List[float]] = []
self.embedding_model: str = (
embedding_model
if embedding_model
else "sentence-transformers/all-MiniLM-L6-v2"
)
self.embedding_engine: str = (
embedding_engine if embedding_engine else "SentenceTransformers"
)
self.embedding_params = embedding_params or {}
self._embedding_size = 0
self.search_threshold = search_threshold or float("inf")
@@ -95,12 +91,12 @@ def __init__(
self._index = index

# Data structures for batching embedding requests
self._req_queue = {}
self._req_results = {}
self._req_idx = 0
self._current_batch_finished_event = None
self._current_batch_full_event = None
self._current_batch_submitted = asyncio.Event()
self._req_queue: Dict[int, str] = {}
self._req_results: Dict[int, List[float]] = {}
self._req_idx: int = 0
self._current_batch_finished_event: Optional[asyncio.Event] = None
self._current_batch_full_event: Optional[asyncio.Event] = None
self._current_batch_submitted: asyncio.Event = asyncio.Event()

# Initialize the batching configuration
self.use_batching = use_batching
@@ -112,6 +108,11 @@ def embeddings_index(self):
"""Get the current embedding index"""
return self._index

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

@property
def cache_config(self):
"""Get the cache configuration."""
@@ -127,19 +128,22 @@ def embeddings(self):
"""Get the computed embeddings."""
return self._embeddings

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

def _init_model(self):
"""Initialize the model used for computing the embeddings."""
model = self.embedding_model
engine = self.embedding_engine

self._model = init_embedding_model(
embedding_model=self.embedding_model,
embedding_engine=self.embedding_engine,
embedding_model=model,
embedding_engine=engine,
embedding_params=self.embedding_params,
)

if not self._model:
raise ValueError(
f"Couldn't create embedding model with model {model} and engine {engine}"
)

@cache_embeddings
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Compute embeddings for a list of texts.
@@ -153,7 +157,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
if self._model is None:
self._init_model()

embeddings = await self._model.encode_async(texts)
# self._model can't be None here, or self._init_model() would throw a ValueError
model: EmbeddingModel = cast(EmbeddingModel, self._model)
embeddings = await model.encode_async(texts)
return embeddings

async def add_item(self, item: IndexItem):
@@ -199,6 +205,10 @@ async def _run_batch(self):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if not self._current_batch_full_event:
raise Exception("self._current_batch_full_event not initialized")

assert self._current_batch_full_event is not None
done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -210,6 +220,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
task.cancel()

# Reset the batch event
if not self._current_batch_finished_event:
raise Exception("self._current_batch_finished_event not initialized")

batch_event: asyncio.Event = self._current_batch_finished_event
self._current_batch_finished_event = None

@@ -252,9 +265,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:

# We check if we reached the max batch size
if len(self._req_queue) >= self.max_batch_size:
if not self._current_batch_full_event:
[Review thread on this check]

Member: Again, this is redundant: self._current_batch_full_event cannot be None here, as per the earlier check and assertion.

Collaborator (Author): Looking at the code statically, self._current_batch_full_event can be None if self._current_batch_finished_event is not None. Is there something about the code that prevents this from happening?

Member: Agreed that static analysis cannot catch this. But on my side, L260 should be:

    if self._current_batch_finished_event is None or self._current_batch_full_event is None:

Any batch processing needs self._current_batch_full_event not to be None. These two events need to be created together.
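
To make the invariant concrete, here is a minimal sketch of what the reviewer is suggesting: create and check both events together. The holder class and method below are assumptions for illustration, not code from this PR:

```python
import asyncio
from typing import Optional

class BatchEvents:
    """Hypothetical holder showing the invariant: create both events together."""

    def __init__(self) -> None:
        self._finished: Optional[asyncio.Event] = None
        self._full: Optional[asyncio.Event] = None

    def ensure(self) -> None:
        # If either event is missing, (re)create both, so neither can be
        # None while the other exists -- the invariant the reviewer asks for.
        if self._finished is None or self._full is None:
            self._finished = asyncio.Event()
            self._full = asyncio.Event()
```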

raise Exception("self._current_batch_full_event not initialized")
self._current_batch_full_event.set()

# Wait for the batch to finish
if not self._current_batch_finished_event:
raise Exception("self._current_batch_finished_event not initialized")
await self._current_batch_finished_event.wait()

# Remove the result and return it
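Aside on the batching logic above: `_run_batch` races a hold timer against the batch-full event and cancels the loser. A standalone sketch of that pattern, assuming the elided `asyncio.wait` call uses `return_when=asyncio.FIRST_COMPLETED` (the hunk is cut off before that argument):

```python
import asyncio

async def wait_for_batch_window(full_event: asyncio.Event, max_batch_hold: float) -> None:
    """Return once the hold time elapses or the batch is marked full."""
    done, pending = await asyncio.wait(
        [
            asyncio.create_task(asyncio.sleep(max_batch_hold)),
            asyncio.create_task(full_event.wait()),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )
    # Cancel whichever side lost the race so no orphaned task lingers.
    for task in pending:
        task.cancel()
```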
59 changes: 39 additions & 20 deletions nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
from abc import ABC, abstractmethod
from functools import singledispatchmethod
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

try:
import redis # type: ignore
except ImportError:
redis = None # type: ignore

from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

@@ -30,18 +35,20 @@
class KeyGenerator(ABC):
"""Abstract class for key generators."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def generate_key(self, text: str) -> str:
pass

@classmethod
def from_name(cls, name):
for subclass in cls.__subclasses__():
if subclass.name == name:
if hasattr(subclass, "name") and subclass.name == name:
return subclass
raise ValueError(
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
". Make sure to import the derived class before using it."
)
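
The `hasattr(subclass, "name")` guard matters because `from_name` discovers subclasses at runtime, and the base class only declares `name` without assigning it. A hypothetical subclass illustrating the contract (the class and the "sha1" name are illustrative, not part of this PR):

```python
import hashlib

class Sha1KeyGenerator(KeyGenerator):
    """Hypothetical key generator; assigning `name` makes it discoverable."""

    name = "sha1"

    def generate_key(self, text: str) -> str:
        return hashlib.sha1(text.encode("utf-8")).hexdigest()

# Once this class is imported, KeyGenerator.from_name("sha1") returns it.
```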

@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
class CacheStore(ABC):
"""Abstract class for cache stores."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def get(self, key):
"""Get a value from the cache."""
Expand All @@ -94,11 +103,11 @@ def clear(self):
@classmethod
def from_name(cls, name):
for subclass in cls.__subclasses__():
if subclass.name == name:
if hasattr(subclass, "name") and subclass.name == name:
return subclass
raise ValueError(
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
". Make sure to import the derived class before using it."
)

@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):

name = "filesystem"

def __init__(self, cache_dir: str = None):
def __init__(self, cache_dir: Optional[str] = None):
self._cache_dir = Path(cache_dir or ".cache/embeddings")
self._cache_dir.mkdir(parents=True, exist_ok=True)

@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
name = "redis"

def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
import redis

if redis is None:
raise ImportError(
"Could not import redis, please install it with `pip install redis`."
)
self._redis = redis.Redis(host=host, port=port, db=db)

def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
class EmbeddingsCache:
def __init__(
self,
key_generator: KeyGenerator = None,
cache_store: CacheStore = None,
store_config: dict = None,
key_generator: Optional[KeyGenerator] = None,
cache_store: Optional[CacheStore] = None,
store_config: Optional[dict] = None,
):
self._key_generator = key_generator
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
@classmethod
def from_dict(cls, d: Dict[str, str]):
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
store_config = d.get("store_config")
store_config_raw = d.get("store_config")
store_config: dict = (
store_config_raw if isinstance(store_config_raw, dict) else {}
)
cache_store = CacheStore.from_name(d.get("store"))(**store_config)

return cls(key_generator=key_generator, cache_store=cache_store)
@@ -230,25 +244,27 @@ def from_config(cls, config: EmbeddingsCacheConfig):

def get_config(self):
return EmbeddingsCacheConfig(
key_generator=self._key_generator.name,
store=self._cache_store.name,
store_config=self._store_config,
key_generator=self._key_generator.name if self._key_generator else None,
store=self._cache_store.name if self._cache_store else None,
store_config=self._store_config if self._store_config else None,
)

@singledispatchmethod
def get(self, texts):
raise NotImplementedError

@get.register
@get.register(str)
def _(self, text: str):
if self._key_generator is None or self._cache_store is None:
return None
key = self._key_generator.generate_key(text)
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")

result = self._cache_store.get(key)

return result

@get.register
@get.register(list)
def _(self, texts: list):
cached = {}

@@ -266,19 +282,22 @@ def _(self, texts: list):
def set(self, texts):
raise NotImplementedError

@set.register
@set.register(str)
def _(self, text: str, value: List[float]):
if self._key_generator is None or self._cache_store is None:
return
key = self._key_generator.generate_key(text)
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
self._cache_store.set(key, value)

@set.register
@set.register(list)
def _(self, texts: list, values: List[List[float]]):
for text, value in zip(texts, values):
self.set(text, value)

def clear(self):
self._cache_store.clear()
if self._cache_store is not None:
self._cache_store.clear()


def cache_embeddings(func):
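On the change from bare `@get.register` to `@get.register(str)` / `@get.register(list)`: the bare form infers the dispatch type from the first non-self parameter's annotation, while the explicit form does not depend on annotations at all. A self-contained sketch of the explicit style (the `Dispatcher` class is illustrative):

```python
from functools import singledispatchmethod

class Dispatcher:
    @singledispatchmethod
    def get(self, texts):
        raise NotImplementedError

    @get.register(str)  # explicit type: no reliance on the annotation
    def _(self, text: str) -> str:
        return f"one: {text}"

    @get.register(list)
    def _(self, texts: list) -> list:
        return [self.get(t) for t in texts]

# Dispatcher().get(["a", "b"]) -> ["one: a", "one: b"]
```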
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/fastembed.py
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
engine_name = "FastEmbed"

def __init__(self, embedding_model: str, **kwargs):
from fastembed import TextEmbedding as Embedding
from fastembed import TextEmbedding as Embedding # type: ignore

# Enabling a short form model name for all-MiniLM-L6-v2.
if embedding_model == "all-MiniLM-L6-v2":
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/nim.py
@@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str, **kwargs):
try:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore

self.model = embedding_model
self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)
6 changes: 3 additions & 3 deletions nemoguardrails/embeddings/providers/openai.py
@@ -46,14 +46,14 @@ def __init__(
**kwargs,
):
try:
import openai
from openai import AsyncOpenAI, OpenAI
import openai # type: ignore
from openai import AsyncOpenAI, OpenAI # type: ignore
except ImportError:
raise ImportError(
"Could not import openai, please install it with "
"`pip install openai`."
)
if openai.__version__ < "1.0.0":
if openai.__version__ < "1.0.0": # type: ignore
raise RuntimeError(
"`openai<1.0.0` is no longer supported. "
"Please upgrade using `pip install openai>=1.0.0`."
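One note on the retained check above: `openai.__version__ < "1.0.0"` compares strings lexicographically, so "10.0.0" would sort before "9.0.0". If that line is ever revisited, a parsed comparison is the usual fix; a sketch only, not a change this PR makes:

```python
import openai
from packaging.version import Version  # `packaging` is a separate, widely used dependency

# Parse before comparing so multi-digit versions order correctly.
if Version(openai.__version__) < Version("1.0.0"):
    raise RuntimeError(
        "`openai<1.0.0` is no longer supported. "
        "Please upgrade using `pip install openai>=1.0.0`."
    )
```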
4 changes: 2 additions & 2 deletions nemoguardrails/embeddings/providers/sentence_transformers.py
@@ -43,15 +43,15 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str, **kwargs):
try:
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer # type: ignore
except ImportError:
raise ImportError(
"Could not import sentence-transformers, please install it with "
"`pip install sentence-transformers`."
)

try:
from torch import cuda
from torch import cuda # type: ignore
except ImportError:
raise ImportError(
"Could not import torch, please install it with `pip install torch`."