From 71d00f083fb59bda34c82b82eea85602c1710265 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 2 Sep 2025 11:17:40 -0500 Subject: [PATCH 1/2] Dummy commit to set up the chore/type-clean-guardrails PR and branch --- nemoguardrails/actions/llm/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..cd11e70a7 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -137,7 +137,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item From 85c4a12335d8c78a7fa6e5033fb4608f1f27b809 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Sun, 14 Sep 2025 23:38:53 -0500 Subject: [PATCH 2/2] Initial checkin --- nemoguardrails/server/api.py | 105 ++++++++++++------ .../server/datastore/redis_store.py | 10 +- 2 files changed, 79 insertions(+), 36 deletions(-) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index d07cb63df..4189e71d7 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -22,7 +22,7 @@ import time import warnings from contextlib import asynccontextmanager -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -42,14 +42,32 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) + +class GuardrailsApp(FastAPI): + """Custom FastAPI subclass with additional attributes for Guardrails server.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize custom attributes + self.default_config_id: Optional[str] = None + self.rails_config_path: str = "" + self.disable_chat_ui: bool = False + self.auto_reload: bool = False + self.stop_signal: bool = False + self.single_config_mode: bool = False + self.single_config_id: Optional[str] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.task: Optional[asyncio.Future] = None + + # The list of registered loggers. Can be used to send logs to various # backends and storage engines. -registered_loggers = [] +registered_loggers: List[Callable] = [] api_description = """Guardrails Sever API.""" # The headers for each request -api_request_headers = contextvars.ContextVar("headers") +api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers") # The datastore that the Server should use. # This is currently used only for storing threads. @@ -59,7 +77,7 @@ @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: GuardrailsApp): # Startup logic here """Register any additional challenges, if available at startup.""" challenges_files = os.path.join(app.rails_config_path, "challenges.json") @@ -82,8 +100,11 @@ async def lifespan(app: FastAPI): if os.path.exists(filepath): filename = os.path.basename(filepath) spec = importlib.util.spec_from_file_location(filename, filepath) - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) + if spec is not None and spec.loader is not None: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + else: + config_module = None # If there is an `init` function, we call it with the reference to the app. if config_module is not None and hasattr(config_module, "init"): @@ -110,6 +131,7 @@ async def root_handler(): if app.auto_reload: app.loop = asyncio.get_running_loop() + # Store the future directly as task app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring) yield @@ -117,14 +139,14 @@ async def root_handler(): # Shutdown logic here if app.auto_reload: app.stop_signal = True - if hasattr(app, "task"): + if hasattr(app, "task") and app.task is not None: app.task.cancel() log.info("Shutting down file observer") else: pass -app = FastAPI( +app = GuardrailsApp( title="Guardrails Server API", description=api_description, version="0.1.0", @@ -186,7 +208,7 @@ class RequestBody(BaseModel): max_length=255, description="The id of an existing thread to which the messages should be added.", ) - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The list of messages in the current conversation." ) context: Optional[dict] = Field( @@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values): class ResponseBody(BaseModel): - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The new messages in the conversation" ) llm_output: Optional[dict] = Field( @@ -282,8 +304,8 @@ async def get_rails_configs(): # One instance of LLMRails per config id -llm_rails_instances = {} -llm_rails_events_history_cache = {} +llm_rails_instances: dict[str, LLMRails] = {} +llm_rails_events_history_cache: dict[str, dict] = {} def _generate_cache_key(config_ids: List[str]) -> str: @@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails: # get the same thing. config_ids = [""] - full_llm_rails_config = None + full_llm_rails_config: Optional[RailsConfig] = None for config_id in config_ids: base_path = os.path.abspath(app.rails_config_path) @@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails: else: full_llm_rails_config += rails_config + if full_llm_rails_config is None: + raise ValueError("No valid rails configuration found.") + llm_rails = LLMRails(config=full_llm_rails_config, verbose=True) llm_rails_instances[configs_cache_key] = llm_rails @@ -368,22 +393,27 @@ async def chat_completion(body: RequestBody, request: Request): "No 'config_id' provided and no default configuration is set for the server. " "You must set a 'config_id' in your request or set use --default-config-id when . " ) + + # Ensure config_ids is not None before passing to _get_rails + if config_ids is None: + raise GuardrailsConfigurationError("No valid configuration IDs available.") + try: llm_rails = _get_rails(config_ids) except ValueError as ex: log.exception(ex) - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": f"Could not load the {config_ids} guardrails configuration. " f"An internal error has occurred.", } ] - } + ) try: - messages = body.messages + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -396,14 +426,14 @@ async def chat_completion(body: RequestBody, request: Request): # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": "The `thread_id` must have a minimum length of 16 characters.", } ] - } + ) # Fetch the existing thread messages. For easier management, we prepend # the string `thread-` to all thread keys. @@ -440,32 +470,37 @@ async def chat_completion(body: RequestBody, request: Request): ) if isinstance(res, GenerationResponse): - bot_message = res.response[0] + bot_message_content = res.response[0] + # Ensure bot_message is always a dict + if isinstance(bot_message_content, str): + bot_message = {"role": "assistant", "content": bot_message_content} + else: + bot_message = bot_message_content else: assert isinstance(res, dict) bot_message = res # If we're using threads, we also need to update the data before returning # the message. - if body.thread_id: + if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = {"messages": [bot_message]} + result = ResponseBody(messages=[bot_message]) # If we have additional GenerationResponse fields, we return as well if isinstance(res, GenerationResponse): - result["llm_output"] = res.llm_output - result["output_data"] = res.output_data - result["log"] = res.log - result["state"] = res.state + result.llm_output = res.llm_output + result.output_data = res.output_data + result.log = res.log + result.state = res.state return result except Exception as ex: log.exception(ex) - return { - "messages": [{"role": "assistant", "content": "Internal server error."}] - } + return ResponseBody( + messages=[{"role": "assistant", "content": "Internal server error."}] + ) # By default, there are no challenges @@ -498,7 +533,7 @@ def register_datastore(datastore_instance: DataStore): datastore = datastore_instance -def register_logger(logger: callable): +def register_logger(logger: Callable): """Register an additional logger""" registered_loggers.append(logger) @@ -510,8 +545,7 @@ def start_auto_reload_monitoring(): from watchdog.observers import Observer class Handler(FileSystemEventHandler): - @staticmethod - def on_any_event(event): + def on_any_event(self, event): if event.is_directory: return None @@ -521,7 +555,8 @@ def on_any_event(event): ) # Compute the relative path - rel_path = os.path.relpath(event.src_path, app.rails_config_path) + src_path_str = str(event.src_path) + rel_path = os.path.relpath(src_path_str, app.rails_config_path) # The config_id is the first component parts = rel_path.split(os.path.sep) @@ -530,7 +565,7 @@ def on_any_event(event): if ( not parts[-1].startswith(".") and ".ipynb_checkpoints" not in parts - and os.path.isfile(event.src_path) + and os.path.isfile(src_path_str) ): # We just remove the config from the cache so that a new one is used next time if config_id in llm_rails_instances: diff --git a/nemoguardrails/server/datastore/redis_store.py b/nemoguardrails/server/datastore/redis_store.py index f5f941ab2..f59d89e31 100644 --- a/nemoguardrails/server/datastore/redis_store.py +++ b/nemoguardrails/server/datastore/redis_store.py @@ -16,7 +16,10 @@ import asyncio from typing import Optional -import aioredis +try: + import aioredis # type: ignore[import] +except ImportError: + aioredis = None # type: ignore[assignment] from nemoguardrails.server.datastore.datastore import DataStore @@ -35,6 +38,11 @@ def __init__( username: [Optional] The username to use for authentication. password: [Optional] The password to use for authentication """ + if aioredis is None: + raise ImportError( + "aioredis is required for RedisStore. Install it with: pip install aioredis" + ) + self.url = url self.username = username self.password = password