Skip to content

Commit 749222f

Browse files
committed
Cleaned integrations directory
1 parent d660d66 commit 749222f

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515

1616
from __future__ import annotations
1717

18+
from typing import Any, List, Optional, Union, cast
1819
import logging
1920
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
2021

22+
from langchain_core.language_models import BaseChatModel
23+
from langchain_core.language_models.llms import BaseLLM
24+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2125
from langchain_core.language_models import BaseLanguageModel
2226
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
2327
from langchain_core.runnables import Runnable, RunnableConfig
@@ -33,7 +37,7 @@
3337
message_to_dict,
3438
)
3539
from nemoguardrails.integrations.langchain.utils import async_wrap
36-
from nemoguardrails.rails.llm.options import GenerationOptions
40+
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
3741

3842
logger = logging.getLogger(__name__)
3943

@@ -62,7 +66,7 @@ class RunnableRails(Runnable[Input, Output]):
6266
def __init__(
6367
self,
6468
config: RailsConfig,
65-
llm: Optional[BaseLanguageModel] = None,
69+
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
6670
tools: Optional[List[Tool]] = None,
6771
passthrough: bool = True,
6872
runnable: Optional[Runnable] = None,
@@ -110,7 +114,7 @@ def __init__(
110114
if self.passthrough_runnable:
111115
self._init_passthrough_fn()
112116

113-
def _init_passthrough_fn(self):
117+
def _init_passthrough_fn(self) -> None:
114118
"""Initialize the passthrough function for the LLM rails instance."""
115119

116120
async def passthrough_fn(context: dict, events: List[dict]):
@@ -134,7 +138,8 @@ async def passthrough_fn(context: dict, events: List[dict]):
134138

135139
return text, _output
136140

137-
self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
141+
# Dynamically assign passthrough_fn to avoid type checker issues
142+
setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn)
138143

139144
def __or__(
140145
self, other: Union[BaseLanguageModel, Runnable[Any, Any]]
@@ -687,6 +692,9 @@ def _full_rails_invoke(
687692
res = self.rails.generate(
688693
messages=input_messages, options=GenerationOptions(output_vars=True)
689694
)
695+
# When using output_vars=True, rails.generate returns a GenerationResponse
696+
if not isinstance(res, GenerationResponse):
697+
raise Exception(f"Expected GenerationResponse, got {type(res)}")
690698
context = res.output_data
691699
result = res.response
692700

0 commit comments

Comments
 (0)