 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

 from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
 from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
-from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
+from nemoguardrails.context import (
+    llm_call_info_var,
+    reasoning_trace_var,
+    tool_calls_var,
+)
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -72,7 +76,22 @@ async def llm_call(
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
 ) -> str:
     """Calls the LLM with a prompt and returns the generated text."""
-    # We initialize a new LLM call if we don't have one already
+    _setup_llm_call_info(llm, model_name, model_provider)
+    all_callbacks = _prepare_callbacks(custom_callback_handlers)
+
+    if isinstance(prompt, str):
+        response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
+    else:
+        response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)
+
+    _store_tool_calls(response)
+    return _extract_content(response)
+
+
+def _setup_llm_call_info(
+    llm: BaseLanguageModel, model_name: Optional[str], model_provider: Optional[str]
+) -> None:
+    """Initialize or update LLM call info in context."""
     llm_call_info = llm_call_info_var.get()
     if llm_call_info is None:
         llm_call_info = LLMCallInfo()
@@ -81,52 +100,84 @@ async def llm_call(
     llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
     llm_call_info.llm_provider_name = model_provider

+
+def _prepare_callbacks(
+    custom_callback_handlers: Optional[List[AsyncCallbackHandler]],
+) -> BaseCallbackManager:
+    """Prepare callback manager with custom handlers if provided."""
     if custom_callback_handlers and custom_callback_handlers != [None]:
-        all_callbacks = BaseCallbackManager(
+        return BaseCallbackManager(
             handlers=logging_callbacks.handlers + custom_callback_handlers,
             inheritable_handlers=logging_callbacks.handlers + custom_callback_handlers,
         )
-    else:
-        all_callbacks = logging_callbacks
+    return logging_callbacks

-    if isinstance(prompt, str):
-        # stop sinks here
-        try:
-            result = await llm.agenerate_prompt(
-                [StringPromptValue(text=prompt)], callbacks=all_callbacks, stop=stop
+
+async def _invoke_with_string_prompt(
+    llm: BaseLanguageModel,
+    prompt: str,
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with string prompt."""
+    try:
+        return await llm.ainvoke(prompt, config={"callbacks": callbacks}, stop=stop)
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+async def _invoke_with_message_list(
+    llm: BaseLanguageModel,
+    prompt: List[dict],
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with message list after converting to LangChain format."""
+    messages = _convert_messages_to_langchain_format(prompt)
+    try:
+        return await llm.ainvoke(
+            messages, config={"callbacks": callbacks}, stop=stop
+        )
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
+    """Convert message list to LangChain message format."""
+    messages = []
+    for msg in prompt:
+        msg_type = msg["type"] if "type" in msg else msg["role"]
+
+        if msg_type == "user":
+            messages.append(HumanMessage(content=msg["content"]))
+        elif msg_type in ["bot", "assistant"]:
+            messages.append(AIMessage(content=msg["content"]))
+        elif msg_type == "system":
+            messages.append(SystemMessage(content=msg["content"]))
+        elif msg_type == "tool":
+            messages.append(
+                ToolMessage(
+                    content=msg["content"],
+                    tool_call_id=msg.get("tool_call_id", ""),
+                )
             )
-        except Exception as e:
-            raise LLMCallException(e)
-        llm_call_info.raw_response = result.llm_output
+        else:
+            raise ValueError(f"Unknown message type {msg_type}")

-        # TODO: error handling
-        return result.generations[0][0].text
-    else:
-        # We first need to translate the array of messages into LangChain message format
-        messages = []
-        for _msg in prompt:
-            msg_type = _msg["type"] if "type" in _msg else _msg["role"]
-            if msg_type == "user":
-                messages.append(HumanMessage(content=_msg["content"]))
-            elif msg_type in ["bot", "assistant"]:
-                messages.append(AIMessage(content=_msg["content"]))
-            elif msg_type == "system":
-                messages.append(SystemMessage(content=_msg["content"]))
-            else:
-                # TODO: add support for tool-related messages
-                raise ValueError(f"Unknown message type {msg_type}")
+    return messages

-        try:
-            result = await llm.agenerate_prompt(
-                [ChatPromptValue(messages=messages)], callbacks=all_callbacks, stop=stop
-            )

-        except Exception as e:
-            raise LLMCallException(e)
+def _store_tool_calls(response) -> None:
+    """Extract and store tool calls from response in context."""
+    tool_calls = getattr(response, "tool_calls", None)
+    tool_calls_var.set(tool_calls)

-        llm_call_info.raw_response = result.llm_output

-        return result.generations[0][0].text
+def _extract_content(response) -> str:
+    """Extract text content from response."""
+    if hasattr(response, "content"):
+        return response.content
+    return str(response)


 def get_colang_history(
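
Taken together, the new helpers split the old monolithic body of llm_call into setup, callback preparation, invocation, message conversion, and content extraction. A minimal usage sketch of the refactored path follows, assuming llm and the prompt are the first two positional arguments of llm_call (as in the existing signature); FakeListChatModel is a langchain-core test model used here as a stand-in and is not part of this change:

    # Illustrative only -- not part of this commit.
    import asyncio

    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    from nemoguardrails.actions.llm.utils import llm_call

    async def main():
        llm = FakeListChatModel(responses=["It is 18C and clear in Paris."])
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather in Paris?"},
            # The new "tool" branch converts this into a ToolMessage.
            {"role": "tool", "content": "18C, clear", "tool_call_id": "call_1"},
        ]
        text = await llm_call(llm, messages)
        print(text)

    asyncio.run(main())
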
@@ -175,15 +226,15 @@ def get_colang_history(
             history += f'user "{event["text"]}"\n'
         elif event["type"] == "UserIntent":
             if include_texts:
-                history += f'  {event["intent"]}\n'
+                history += f"  {event['intent']}\n"
             else:
-                history += f'user {event["intent"]}\n'
+                history += f"user {event['intent']}\n"
         elif event["type"] == "BotIntent":
             # If we have instructions, we add them before the bot message.
             # But we only do that for the last bot message.
             if "instructions" in event and idx == last_bot_intent_idx:
                 history += f"# {event['instructions']}\n"
-            history += f'bot {event["intent"]}\n'
+            history += f"bot {event['intent']}\n"
         elif event["type"] == "StartUtteranceBotAction" and include_texts:
             history += f'  "{event["script"]}"\n'
         # We skip system actions from this log
@@ -352,9 +403,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
         if "_type" not in element:
             raise Exception("bla")
         if element["_type"] == "UserIntent":
-            colang_flow += f'user {element["intent_name"]}\n'
+            colang_flow += f"user {element['intent_name']}\n"
         elif element["_type"] == "run_action" and element["action_name"] == "utter":
-            colang_flow += f'bot {element["action_params"]["value"]}\n'
+            colang_flow += f"bot {element['action_params']['value']}\n"

     return colang_flow

@@ -592,3 +643,15 @@ def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
         reasoning_trace_var.set(None)
         return reasoning_trace
     return None
+
+
+def get_and_clear_tool_calls_contextvar() -> Optional[list]:
+    """Get the current tool calls and clear them from the context.
+
+    Returns:
+        Optional[list]: The tool calls if they exist, None otherwise.
+    """
+    if tool_calls := tool_calls_var.get():
+        tool_calls_var.set(None)
+        return tool_calls
+    return None
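
Note that _store_tool_calls runs on every response and stores whatever getattr returns, so a response without tool calls resets the context variable to None instead of leaking the previous turn's calls; get_and_clear_tool_calls_contextvar is the matching consumer. A sketch of the intended drain-once pattern, where dispatch_tool is a hypothetical handler and not part of this change:

    # Illustrative only -- not part of this commit.
    from nemoguardrails.actions.llm.utils import get_and_clear_tool_calls_contextvar

    tool_calls = get_and_clear_tool_calls_contextvar()
    if tool_calls:
        for call in tool_calls:
            # LangChain tool calls are dicts with "name", "args", and "id" keys.
            dispatch_tool(call["name"], call["args"])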