Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion nemoguardrails/library/content_safety/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from nemoguardrails.actions.actions import action
from nemoguardrails.actions.llm.utils import llm_call
from nemoguardrails.context import llm_call_info_var
from nemoguardrails.llm.cache import CacheInterface
from nemoguardrails.llm.cache.utils import (
CacheEntry,
create_normalized_cache_key,
extract_llm_stats_for_cache,
get_from_cache_and_restore_stats,
)
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.logging.explain import LLMCallInfo

Expand All @@ -33,6 +40,7 @@ async def content_safety_check_input(
llm_task_manager: LLMTaskManager,
model_name: Optional[str] = None,
context: Optional[dict] = None,
model_caches: Optional[Dict[str, CacheInterface]] = None,
**kwargs,
) -> dict:
_MAX_TOKENS = 3
Expand Down Expand Up @@ -75,6 +83,15 @@ async def content_safety_check_input(

max_tokens = max_tokens or _MAX_TOKENS

cache = model_caches.get(model_name) if model_caches else None

if cache:
cache_key = create_normalized_cache_key(check_input_prompt)
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
if cached_result is not None:
log.debug(f"Content safety cache hit for model '{model_name}'")
return cached_result

result = await llm_call(
llm,
check_input_prompt,
Expand All @@ -86,7 +103,18 @@ async def content_safety_check_input(

is_safe, *violated_policies = result

return {"allowed": is_safe, "policy_violations": violated_policies}
final_result = {"allowed": is_safe, "policy_violations": violated_policies}

if cache:
cache_key = create_normalized_cache_key(check_input_prompt)
cache_entry: CacheEntry = {
"result": final_result,
"llm_stats": extract_llm_stats_for_cache(),
}
cache.put(cache_key, cache_entry)
log.debug(f"Content safety result cached for model '{model_name}'")

return final_result


def content_safety_check_output_mapping(result: dict) -> bool:
Expand Down
21 changes: 21 additions & 0 deletions nemoguardrails/llm/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General-purpose caching utilities for NeMo Guardrails."""

from nemoguardrails.llm.cache.interface import CacheInterface
from nemoguardrails.llm.cache.lfu import LFUCache

__all__ = ["CacheInterface", "LFUCache"]
207 changes: 207 additions & 0 deletions nemoguardrails/llm/cache/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Cache interface for NeMo Guardrails caching system.

This module defines the abstract base class for cache implementations
that can be used interchangeably throughout the guardrails system.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class CacheInterface(ABC):
    """
    Abstract base class defining the interface for cache implementations.

    All cache implementations must inherit from this class and implement
    the required methods to ensure compatibility with the caching system.
    """

    @abstractmethod
    def get(self, key: Any, default: Any = None) -> Any:
        """
        Retrieve an item from the cache.

        Args:
            key: The key to look up in the cache.
            default: Value to return if key is not found (default: None).

        Returns:
            The value associated with the key, or default if not found.
        """
        pass

    @abstractmethod
    def put(self, key: Any, value: Any) -> None:
        """
        Store an item in the cache.

        If the cache is at maxsize, this method should evict an item
        according to the cache's eviction policy (e.g., LFU, LRU, etc.).

        Args:
            key: The key to store.
            value: The value to associate with the key.
        """
        pass

    @abstractmethod
    def size(self) -> int:
        """
        Get the current number of items in the cache.

        Returns:
            The number of items currently stored in the cache.
        """
        pass

    @abstractmethod
    def is_empty(self) -> bool:
        """
        Check if the cache is empty.

        Returns:
            True if the cache contains no items, False otherwise.
        """
        pass

    @abstractmethod
    def clear(self) -> None:
        """
        Remove all items from the cache.

        After calling this method, the cache should be empty.
        """
        pass

    def contains(self, key: Any) -> bool:
        """
        Check if a key exists in the cache.

        This is an optional method that can be overridden for efficiency.
        The default implementation uses get() to check existence.

        Args:
            key: The key to check.

        Returns:
            True if the key exists in the cache, False otherwise.
        """
        # A unique sentinel distinguishes "key absent" from "key maps to None".
        sentinel = object()
        return self.get(key, sentinel) is not sentinel

    @property
    @abstractmethod
    def maxsize(self) -> int:
        """
        Get the maximum size of the cache.

        Returns:
            The maximum number of items the cache can hold.
        """
        pass

    def get_stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics. The format and contents
            may vary by implementation. Common fields include:
            - hits: Number of cache hits
            - misses: Number of cache misses
            - evictions: Number of items evicted
            - hit_rate: Percentage of requests that were hits
            - current_size: Current number of items in cache
            - maxsize: Maximum size of the cache

        The default implementation returns a message indicating that
        statistics tracking is not supported.
        """
        return {
            "message": "Statistics tracking is not supported by this cache implementation"
        }

    def reset_stats(self) -> None:
        """
        Reset cache statistics.

        This is an optional method that cache implementations can override
        if they support statistics tracking. The default implementation does nothing.
        """
        # Default no-op implementation
        pass

    def log_stats_now(self) -> None:
        """
        Force immediate logging of cache statistics.

        This is an optional method that cache implementations can override
        if they support statistics logging. The default implementation does nothing.

        Implementations that support statistics logging should output the
        current cache statistics to their configured logging backend.
        """
        # Default no-op implementation
        pass

    def supports_stats_logging(self) -> bool:
        """
        Check if this cache implementation supports statistics logging.

        Returns:
            True if the cache supports statistics logging, False otherwise.

        The default implementation returns False. Cache implementations
        that support statistics logging should override this to return True
        when logging is enabled.
        """
        return False

    async def get_or_compute(
        self, key: Any, compute_fn: Callable[[], Any], default: Any = None
    ) -> Any:
        """
        Get a value from the cache, computing and storing it on a miss.

        Args:
            key: The key to look up.
            compute_fn: Async function to compute the value if key is not found.
            default: Value to return if compute_fn raises an exception.

        Returns:
            The cached value or the computed value.

        This is an optional method with a default implementation. The default
        implementation is NOT safe against concurrent computation: two
        concurrent callers missing on the same key may each invoke
        compute_fn. Cache implementations should override this method to
        provide single-flight / thread-safety guarantees if they need them.
        """
        # Use a sentinel for the miss check so that a legitimately cached
        # None value counts as a hit instead of triggering a recomputation
        # (mirrors the idiom used by contains()).
        sentinel = object()
        value = self.get(key, sentinel)
        if value is not sentinel:
            return value

        try:
            computed_value = await compute_fn()
            self.put(key, computed_value)
            return computed_value
        except Exception:
            # Best-effort contract: fall back to `default` when the compute
            # function fails rather than propagating the error.
            return default
Loading