1515
1616from  __future__ import  annotations 
1717
18+ from  typing  import  Any , List , Optional , Union , cast 
1819import  logging 
1920from  typing  import  Any , AsyncIterator , Dict , Iterator , List , Optional , Union 
2021
22+ from  langchain_core .language_models  import  BaseChatModel 
23+ from  langchain_core .language_models .llms  import  BaseLLM 
24+ from  langchain_core .messages  import  AIMessage , HumanMessage , SystemMessage 
2125from  langchain_core .language_models  import  BaseLanguageModel 
2226from  langchain_core .prompt_values  import  ChatPromptValue , StringPromptValue 
2327from  langchain_core .runnables  import  Runnable , RunnableConfig 
3337    message_to_dict ,
3438)
3539from  nemoguardrails .integrations .langchain .utils  import  async_wrap 
36- from  nemoguardrails .rails .llm .options  import  GenerationOptions 
40+ from  nemoguardrails .rails .llm .options  import  GenerationOptions ,  GenerationResponse 
3741
3842logger  =  logging .getLogger (__name__ )
3943
@@ -62,7 +66,7 @@ class RunnableRails(Runnable[Input, Output]):
6266    def  __init__ (
6367        self ,
6468        config : RailsConfig ,
65-         llm : Optional [BaseLanguageModel ] =  None ,
69+         llm : Optional [Union [ BaseLLM ,  BaseChatModel ] ] =  None ,
6670        tools : Optional [List [Tool ]] =  None ,
6771        passthrough : bool  =  True ,
6872        runnable : Optional [Runnable ] =  None ,
@@ -110,7 +114,7 @@ def __init__(
110114        if  self .passthrough_runnable :
111115            self ._init_passthrough_fn ()
112116
113-     def  _init_passthrough_fn (self ):
117+     def  _init_passthrough_fn (self )  ->   None :
114118        """Initialize the passthrough function for the LLM rails instance.""" 
115119
116120        async  def  passthrough_fn (context : dict , events : List [dict ]):
@@ -134,7 +138,8 @@ async def passthrough_fn(context: dict, events: List[dict]):
134138
135139            return  text , _output 
136140
137-         self .rails .llm_generation_actions .passthrough_fn  =  passthrough_fn 
141+         # Dynamically assign passthrough_fn to avoid type checker issues 
142+         setattr (self .rails .llm_generation_actions , "passthrough_fn" , passthrough_fn )
138143
139144    def  __or__ (
140145        self , other : Union [BaseLanguageModel , Runnable [Any , Any ]]
@@ -687,6 +692,9 @@ def _full_rails_invoke(
687692        res  =  self .rails .generate (
688693            messages = input_messages , options = GenerationOptions (output_vars = True )
689694        )
695+         # When using output_vars=True, rails.generate returns a GenerationResponse 
696+         if  not  isinstance (res , GenerationResponse ):
697+             raise  Exception (f"Expected GenerationResponse, got { type (res )}  )
690698        context  =  res .output_data 
691699        result  =  res .response 
692700
0 commit comments