diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index 9de7ef439..b116fb25c 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -18,9 +18,6 @@
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
-from langchain.prompts.base import StringPromptValue
-from langchain.prompts.chat import ChatPromptValue
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
 
 from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
 from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
@@ -30,6 +27,7 @@
     reasoning_trace_var,
     tool_calls_var,
 )
+from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -146,34 +144,7 @@ async def _invoke_with_message_list(
 
 def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
     """Convert message list to LangChain message format."""
-    messages = []
-    for msg in prompt:
-        msg_type = msg["type"] if "type" in msg else msg["role"]
-
-        if msg_type == "user":
-            messages.append(HumanMessage(content=msg["content"]))
-        elif msg_type in ["bot", "assistant"]:
-            tool_calls = msg.get("tool_calls")
-            if tool_calls:
-                messages.append(
-                    AIMessage(content=msg["content"], tool_calls=tool_calls)
-                )
-            else:
-                messages.append(AIMessage(content=msg["content"]))
-        elif msg_type == "system":
-            messages.append(SystemMessage(content=msg["content"]))
-        elif msg_type == "tool":
-            tool_message = ToolMessage(
-                content=msg["content"],
-                tool_call_id=msg.get("tool_call_id", ""),
-            )
-            if msg.get("name"):
-                tool_message.name = msg["name"]
-            messages.append(tool_message)
-        else:
-            raise ValueError(f"Unknown message type {msg_type}")
-
-    return messages
+    return dicts_to_messages(prompt)
 
 
 def _store_tool_calls(response) -> None:
diff --git a/nemoguardrails/integrations/langchain/message_utils.py b/nemoguardrails/integrations/langchain/message_utils.py
new file mode 100644
index 000000000..fa679d792
--- /dev/null
+++ b/nemoguardrails/integrations/langchain/message_utils.py
@@ -0,0 +1,292 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Utilities for converting between LangChain messages and dictionary format.""" + +from typing import Any, Dict, List, Optional, Type + +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + + +def get_message_role(msg: BaseMessage) -> str: + """Get the role string for a BaseMessage.""" + if isinstance(msg, AIMessage): + return "assistant" + elif isinstance(msg, HumanMessage): + return "user" + elif isinstance(msg, SystemMessage): + return "system" + elif isinstance(msg, ToolMessage): + return "tool" + else: + return getattr(msg, "type", "user") + + +def get_message_class(msg_type: str) -> Type[BaseMessage]: + """Get the appropriate message class for a given type/role.""" + if msg_type == "user": + return HumanMessage + elif msg_type in ["bot", "assistant"]: + return AIMessage + elif msg_type in ["system", "developer"]: + return SystemMessage + elif msg_type == "tool": + return ToolMessage + else: + raise ValueError(f"Unknown message type: {msg_type}") + + +def message_to_dict(msg: BaseMessage) -> Dict[str, Any]: + """ + Convert a BaseMessage to dictionary format, preserving all model fields. + + Args: + msg: The BaseMessage to convert + + Returns: + Dictionary representation with role, content, and all other fields + """ + result = {"role": get_message_role(msg), "content": msg.content} + + if isinstance(msg, ToolMessage): + result["tool_call_id"] = msg.tool_call_id + + exclude_fields = {"type", "content", "example"} + + if hasattr(msg, "model_fields"): + for field_name in msg.model_fields: + if field_name not in exclude_fields and field_name not in result: + value = getattr(msg, field_name, None) + if value is not None: + result[field_name] = value + + return result + + +def dict_to_message(msg_dict: Dict[str, Any]) -> BaseMessage: + """ + Convert a dictionary to the appropriate BaseMessage type. + + Args: + msg_dict: Dictionary with role/type, content, and optional fields + + Returns: + The appropriate BaseMessage instance + """ + msg_type = msg_dict.get("type") or msg_dict.get("role") + if not msg_type: + raise ValueError("Message dictionary must have 'type' or 'role' field") + + content = msg_dict.get("content", "") + message_class = get_message_class(msg_type) + + exclude_keys = {"role", "type", "content"} + + valid_fields = ( + set(message_class.model_fields.keys()) + if hasattr(message_class, "model_fields") + else set() + ) + + kwargs = { + k: v + for k, v in msg_dict.items() + if k not in exclude_keys and k in valid_fields and v is not None + } + + if message_class == ToolMessage: + kwargs["tool_call_id"] = msg_dict.get("tool_call_id", "") + + return message_class(content=content, **kwargs) + + +def messages_to_dicts(messages: List[BaseMessage]) -> List[Dict[str, Any]]: + """ + Convert a list of BaseMessage objects to dictionary format. + + Args: + messages: List of BaseMessage objects + + Returns: + List of dictionary representations + """ + return [message_to_dict(msg) for msg in messages] + + +def dicts_to_messages(msg_dicts: List[Dict[str, Any]]) -> List[BaseMessage]: + """ + Convert a list of dictionaries to BaseMessage objects. 
+
+    Args:
+        msg_dicts: List of message dictionaries
+
+    Returns:
+        List of appropriate BaseMessage instances
+    """
+    return [dict_to_message(msg_dict) for msg_dict in msg_dicts]
+
+
+def is_message_type(obj: Any, message_type: Type[BaseMessage]) -> bool:
+    """Check if an object is an instance of a specific message type."""
+    return isinstance(obj, message_type)
+
+
+def is_base_message(obj: Any) -> bool:
+    """Check if an object is any type of BaseMessage."""
+    return isinstance(obj, BaseMessage)
+
+
+def is_ai_message(obj: Any) -> bool:
+    """Check if an object is an AIMessage."""
+    return isinstance(obj, AIMessage)
+
+
+def is_human_message(obj: Any) -> bool:
+    """Check if an object is a HumanMessage."""
+    return isinstance(obj, HumanMessage)
+
+
+def is_system_message(obj: Any) -> bool:
+    """Check if an object is a SystemMessage."""
+    return isinstance(obj, SystemMessage)
+
+
+def is_tool_message(obj: Any) -> bool:
+    """Check if an object is a ToolMessage."""
+    return isinstance(obj, ToolMessage)
+
+
+def all_base_messages(items: List[Any]) -> bool:
+    """Check if all items in a list are BaseMessage instances."""
+    return all(isinstance(item, BaseMessage) for item in items)
+
+
+def create_ai_message(
+    content: str,
+    tool_calls: Optional[list] = None,
+    additional_kwargs: Optional[dict] = None,
+    response_metadata: Optional[dict] = None,
+    id: Optional[str] = None,
+    name: Optional[str] = None,
+    usage_metadata: Optional[dict] = None,
+    **extra_kwargs,
+) -> AIMessage:
+    """Create an AIMessage with optional fields."""
+    kwargs = {}
+    if tool_calls is not None:
+        kwargs["tool_calls"] = tool_calls
+    if additional_kwargs is not None:
+        kwargs["additional_kwargs"] = additional_kwargs
+    if response_metadata is not None:
+        kwargs["response_metadata"] = response_metadata
+    if id is not None:
+        kwargs["id"] = id
+    if name is not None:
+        kwargs["name"] = name
+    if usage_metadata is not None:
+        kwargs["usage_metadata"] = usage_metadata
+
+    valid_fields = (
+        set(AIMessage.model_fields.keys())
+        if hasattr(AIMessage, "model_fields")
+        else set()
+    )
+    for key, value in extra_kwargs.items():
+        if key in valid_fields and key not in kwargs:
+            kwargs[key] = value
+
+    return AIMessage(content=content, **kwargs)
+
+
+def create_ai_message_chunk(content: str, **metadata) -> AIMessageChunk:
+    """Create an AIMessageChunk with optional metadata."""
+    return AIMessageChunk(content=content, **metadata)
+
+
+def create_human_message(
+    content: str,
+    additional_kwargs: Optional[dict] = None,
+    response_metadata: Optional[dict] = None,
+    id: Optional[str] = None,
+    name: Optional[str] = None,
+) -> HumanMessage:
+    """Create a HumanMessage with optional fields."""
+    kwargs = {}
+    if additional_kwargs is not None:
+        kwargs["additional_kwargs"] = additional_kwargs
+    if response_metadata is not None:
+        kwargs["response_metadata"] = response_metadata
+    if id is not None:
+        kwargs["id"] = id
+    if name is not None:
+        kwargs["name"] = name
+
+    return HumanMessage(content=content, **kwargs)
+
+
+def create_system_message(
+    content: str,
+    additional_kwargs: Optional[dict] = None,
+    response_metadata: Optional[dict] = None,
+    id: Optional[str] = None,
+    name: Optional[str] = None,
+) -> SystemMessage:
+    """Create a SystemMessage with optional fields."""
+    kwargs = {}
+    if additional_kwargs is not None:
+        kwargs["additional_kwargs"] = additional_kwargs
+    if response_metadata is not None:
+        kwargs["response_metadata"] = response_metadata
+    if id is not None:
+        kwargs["id"] = id
+    if name is not None:
+        kwargs["name"] = name
+
+    return SystemMessage(content=content, **kwargs)
+
+
+def create_tool_message(
+    content: str,
+    tool_call_id: str,
+    name: Optional[str] = None,
+    additional_kwargs: Optional[dict] = None,
+    response_metadata: Optional[dict] = None,
+    id: Optional[str] = None,
+    artifact: Optional[Any] = None,
+    status: Optional[str] = None,
+) -> ToolMessage:
+    """Create a ToolMessage with optional fields."""
+    kwargs = {"tool_call_id": tool_call_id}
+    if name is not None:
+        kwargs["name"] = name
+    if additional_kwargs is not None:
+        kwargs["additional_kwargs"] = additional_kwargs
+    if response_metadata is not None:
+        kwargs["response_metadata"] = response_metadata
+    if id is not None:
+        kwargs["id"] = id
+    if artifact is not None:
+        kwargs["artifact"] = artifact
+    if status is not None:
+        kwargs["status"] = status
+
+    return ToolMessage(content=content, **kwargs)
diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py
index 836f259c9..2e8e0fbf2 100644
--- a/nemoguardrails/integrations/langchain/runnable_rails.py
+++ b/nemoguardrails/integrations/langchain/runnable_rails.py
@@ -15,25 +15,23 @@
 from __future__ import annotations
 
-import asyncio
 import logging
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, TypeVar, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.messages import (
-    AIMessage,
-    AIMessageChunk,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    ToolMessage,
-)
 from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
-from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
+from langchain_core.runnables import Runnable, RunnableConfig
 from langchain_core.runnables.utils import Input, Output, gather_with_concurrency
 from langchain_core.tools import Tool
 
 from nemoguardrails import LLMRails, RailsConfig
+from nemoguardrails.integrations.langchain.message_utils import (
+    all_base_messages,
+    create_ai_message,
+    create_ai_message_chunk,
+    is_base_message,
+    message_to_dict,
+)
 from nemoguardrails.integrations.langchain.utils import async_wrap
 from nemoguardrails.rails.llm.options import GenerationOptions
@@ -129,7 +127,7 @@ async def passthrough_fn(context: dict, events: List[dict]):
             # If the output is a string, we consider it to be the output text
             if isinstance(_output, str):
                 text = _output
-            elif isinstance(_output, BaseMessage):
+            elif is_base_message(_output):
                 text = _output.content
             else:
                 text = _output.get(self.passthrough_bot_output_key)
@@ -204,7 +202,7 @@ def _extract_text_from_input(self, _input) -> str:
         """Extract text content from various input types for passthrough mode."""
         if isinstance(_input, str):
             return _input
-        elif isinstance(_input, BaseMessage):
+        elif is_base_message(_input):
             return _input.content
         elif isinstance(_input, dict) and self.passthrough_user_input_key in _input:
             return _input.get(self.passthrough_user_input_key)
@@ -229,41 +227,11 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
             },
         ]
 
-    def _message_to_dict(self, msg: BaseMessage) -> Dict[str, Any]:
-        """Convert a BaseMessage to dictionary format."""
-        if isinstance(msg, AIMessage):
-            result = {"role": "assistant", "content": msg.content}
-            if hasattr(msg, "tool_calls") and msg.tool_calls:
-                result["tool_calls"] = msg.tool_calls
-            return result
-        elif isinstance(msg, HumanMessage):
-            return {"role": "user", "content": msg.content}
-        elif isinstance(msg, SystemMessage):
-            return {"role": "system", "content": msg.content}
-        elif isinstance(msg, ToolMessage):
-            result = {
-                "role": "tool",
-                "content": msg.content,
-                "tool_call_id": msg.tool_call_id,
-            }
-            if hasattr(msg, "name") and msg.name:
-                result["name"] = msg.name
-            return result
-        else:  # Handle other message types
-            role = getattr(msg, "type", "user")
-            return {"role": role, "content": msg.content}
-
     def _transform_chat_prompt_value(
         self, _input: ChatPromptValue
     ) -> List[Dict[str, Any]]:
         """Transform ChatPromptValue to messages list."""
-        return [self._message_to_dict(msg) for msg in _input.messages]
-
-    def _transform_base_message_list(
-        self, _input: List[BaseMessage]
-    ) -> List[Dict[str, Any]]:
-        """Transform list of BaseMessage objects to messages list."""
-        return [self._message_to_dict(msg) for msg in _input]
+        return [message_to_dict(msg) for msg in _input.messages]
 
     def _extract_user_input_from_dict(self, _input: dict):
         """Extract user input from dictionary, checking configured key first."""
@@ -281,9 +249,9 @@ def _extract_user_input_from_dict(self, _input: dict):
 
     def _transform_dict_message_list(self, user_input: list) -> List[Dict[str, Any]]:
         """Transform list from dictionary input to messages."""
-        if all(isinstance(msg, BaseMessage) for msg in user_input):
+        if all_base_messages(user_input):
             # Handle BaseMessage objects in the list
-            return [self._message_to_dict(msg) for msg in user_input]
+            return [message_to_dict(msg) for msg in user_input]
         elif all(isinstance(msg, dict) for msg in user_input):
             # Handle dict-style messages
             for msg in user_input:
@@ -301,8 +269,8 @@ def _transform_dict_user_input(self, user_input) -> List[Dict[str, Any]]:
         """Transform user input value from dictionary."""
         if isinstance(user_input, str):
             return [{"role": "user", "content": user_input}]
-        elif isinstance(user_input, BaseMessage):
-            return [self._message_to_dict(user_input)]
+        elif is_base_message(user_input):
+            return [message_to_dict(user_input)]
         elif isinstance(user_input, list):
             return self._transform_dict_message_list(user_input)
         else:
@@ -344,12 +312,10 @@ def _transform_input_to_rails_format(self, _input) -> List[Dict[str, Any]]:
             return self._transform_chat_prompt_value(_input)
         elif isinstance(_input, StringPromptValue):
             return [{"role": "user", "content": _input.text}]
-        elif isinstance(_input, BaseMessage):
-            return [self._message_to_dict(_input)]
-        elif isinstance(_input, list) and all(
-            isinstance(msg, BaseMessage) for msg in _input
-        ):
-            return self._transform_base_message_list(_input)
+        elif is_base_message(_input):
+            return [message_to_dict(_input)]
+        elif isinstance(_input, list) and all_base_messages(_input):
+            return [message_to_dict(msg) for msg in _input]
         elif isinstance(_input, dict):
             return self._transform_dict_input(_input)
         elif isinstance(_input, str):
@@ -420,10 +386,10 @@ def _format_chat_prompt_output(
             metadata_copy.pop("content", None)
             if tool_calls:
                 metadata_copy["tool_calls"] = tool_calls
-            return AIMessage(content=content, **metadata_copy)
+            return create_ai_message(content=content, **metadata_copy)
         elif tool_calls:
-            return AIMessage(content=content, tool_calls=tool_calls)
-        return AIMessage(content=content)
+            return create_ai_message(content=content, tool_calls=tool_calls)
+        return create_ai_message(content=content)
 
     def _format_string_prompt_output(self, result: Any) -> str:
         """Format output for StringPromptValue input."""
@@ -443,10 +409,10 @@ def _format_message_output(
             metadata_copy.pop("content", None)
             if tool_calls:
                 metadata_copy["tool_calls"] = tool_calls
-            return AIMessage(content=content, **metadata_copy)
+            return create_ai_message(content=content, **metadata_copy)
         elif tool_calls:
-            return AIMessage(content=content, tool_calls=tool_calls)
-        return AIMessage(content=content)
+            return create_ai_message(content=content, tool_calls=tool_calls)
+        return create_ai_message(content=content)
 
     def _format_dict_output_for_string_input(
         self, result: Any, output_key: str
@@ -482,10 +448,12 @@ def _format_dict_output_for_base_message_list(
             metadata_copy.pop("content", None)
             if tool_calls:
                 metadata_copy["tool_calls"] = tool_calls
-            return {output_key: AIMessage(content=content, **metadata_copy)}
+            return {output_key: create_ai_message(content=content, **metadata_copy)}
         elif tool_calls:
-            return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
-        return {output_key: AIMessage(content=content)}
+            return {
+                output_key: create_ai_message(content=content, tool_calls=tool_calls)
+            }
+        return {output_key: create_ai_message(content=content)}
 
     def _format_dict_output_for_base_message(
         self,
@@ -501,10 +469,12 @@ def _format_dict_output_for_base_message(
             metadata_copy = metadata.copy()
             if tool_calls:
                 metadata_copy["tool_calls"] = tool_calls
-            return {output_key: AIMessage(content=content, **metadata_copy)}
+            return {output_key: create_ai_message(content=content, **metadata_copy)}
         elif tool_calls:
-            return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
-        return {output_key: AIMessage(content=content)}
+            return {
+                output_key: create_ai_message(content=content, tool_calls=tool_calls)
+            }
+        return {output_key: create_ai_message(content=content)}
 
     def _format_dict_output(
         self,
@@ -528,13 +498,13 @@ def _format_dict_output(
                 return self._format_dict_output_for_dict_message_list(
                     result, output_key
                 )
-            elif all(isinstance(msg, BaseMessage) for msg in user_input):
+            elif all_base_messages(user_input):
                 return self._format_dict_output_for_base_message_list(
                     result, output_key, tool_calls, metadata
                 )
             else:
                 return {output_key: result}
-        elif isinstance(user_input, BaseMessage):
+        elif is_base_message(user_input):
             return self._format_dict_output_for_base_message(
                 result, output_key, tool_calls, metadata
             )
@@ -575,11 +545,9 @@ def _format_output(
             return self._format_chat_prompt_output(result, tool_calls, metadata)
         elif isinstance(input, StringPromptValue):
             return self._format_string_prompt_output(result)
-        elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
+        elif is_base_message(input):
            return self._format_message_output(result, tool_calls, metadata)
-        elif isinstance(input, list) and all(
-            isinstance(msg, BaseMessage) for msg in input
-        ):
+        elif isinstance(input, list) and all_base_messages(input):
             return self._format_message_output(result, tool_calls, metadata)
         elif isinstance(input, dict):
             return self._format_dict_output(input, result, tool_calls, metadata)
@@ -659,7 +627,7 @@ def _convert_messages_to_rails_format(self, messages) -> List[dict]:
         rails_messages = []
         for msg in messages:
             if hasattr(msg, "role") and hasattr(msg, "content"):
-                # LangChain BaseMessage format
+                # LangChain message format
                 rails_messages.append(
                     {
                         "role": (
@@ -893,15 +861,13 @@ def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         if generation_info:
             metadata = generation_info.copy()
         if isinstance(input, ChatPromptValue):
-            return AIMessageChunk(content=text_content, **metadata)
+            return create_ai_message_chunk(content=text_content, **metadata)
         elif isinstance(input, StringPromptValue):
             return text_content  # String outputs don't support metadata
-        elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return AIMessageChunk(content=text_content, **metadata)
-        elif isinstance(input, list) and all(
-            isinstance(msg, BaseMessage) for msg in input
-        ):
-            return AIMessageChunk(content=text_content, **metadata)
+        elif is_base_message(input):
+            return create_ai_message_chunk(content=text_content, **metadata)
+        elif isinstance(input, list) and all_base_messages(input):
+            return create_ai_message_chunk(content=text_content, **metadata)
         elif isinstance(input, dict):
             output_key = self.passthrough_bot_output_key
             if self.passthrough_user_input_key in input or "input" in input:
@@ -917,18 +883,22 @@
                     return {
                         output_key: {"role": "assistant", "content": text_content}
                     }
-                elif all(isinstance(msg, BaseMessage) for msg in user_input):
+                elif all_base_messages(user_input):
                     return {
-                        output_key: AIMessageChunk(content=text_content, **metadata)
+                        output_key: create_ai_message_chunk(
+                            content=text_content, **metadata
+                        )
                     }
                 return {output_key: text_content}
-            elif isinstance(user_input, BaseMessage):
+            elif is_base_message(user_input):
                 return {
-                    output_key: AIMessageChunk(content=text_content, **metadata)
+                    output_key: create_ai_message_chunk(
+                        content=text_content, **metadata
+                    )
                 }
             return {output_key: text_content}
         elif isinstance(input, str):
-            return AIMessageChunk(content=text_content, **metadata)
+            return create_ai_message_chunk(content=text_content, **metadata)
         else:
             raise ValueError(f"Unexpected input type: {type(input)}")
diff --git a/tests/runnable_rails/test_message_utils.py b/tests/runnable_rails/test_message_utils.py
new file mode 100644
index 000000000..d424433bb
--- /dev/null
+++ b/tests/runnable_rails/test_message_utils.py
@@ -0,0 +1,496 @@
+#!/usr/bin/env python3
+
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
+
+from nemoguardrails.integrations.langchain.message_utils import (
+    all_base_messages,
+    create_ai_message,
+    create_ai_message_chunk,
+    create_human_message,
+    create_system_message,
+    create_tool_message,
+    dict_to_message,
+    dicts_to_messages,
+    get_message_class,
+    get_message_role,
+    is_ai_message,
+    is_base_message,
+    is_human_message,
+    is_message_type,
+    is_system_message,
+    is_tool_message,
+    message_to_dict,
+    messages_to_dicts,
+)
+
+
+class TestMessageRoleAndClass:
+    def test_get_message_role_ai(self):
+        msg = AIMessage(content="test")
+        assert get_message_role(msg) == "assistant"
+
+    def test_get_message_role_human(self):
+        msg = HumanMessage(content="test")
+        assert get_message_role(msg) == "user"
+
+    def test_get_message_role_system(self):
+        msg = SystemMessage(content="test")
+        assert get_message_role(msg) == "system"
+
+    def test_get_message_role_tool(self):
+        msg = ToolMessage(content="test", tool_call_id="123")
+        assert get_message_role(msg) == "tool"
+
+    def test_get_message_class_user(self):
+        assert get_message_class("user") == HumanMessage
+
+    def test_get_message_class_assistant(self):
+        assert get_message_class("assistant") == AIMessage
+
+    def test_get_message_class_bot(self):
+        assert get_message_class("bot") == AIMessage
+
+    def test_get_message_class_system(self):
+        assert get_message_class("system") == SystemMessage
+
+    def test_get_message_class_developer(self):
+        assert get_message_class("developer") == SystemMessage
+
+    def test_get_message_class_tool(self):
+        assert get_message_class("tool") == ToolMessage
+
+    def test_get_message_class_unknown(self):
+        with pytest.raises(ValueError, match="Unknown message type"):
+            get_message_class("unknown")
+
+
+class TestMessageConversion:
+    def test_ai_message_with_tool_calls(self):
+        original = AIMessage(
+            content="",
+            tool_calls=[
+                {
+                    "name": "search",
+                    "args": {"query": "test"},
+                    "id": "call_123",
+                    "type": "tool_call",
+                }
+            ],
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "function": {
+                            "name": "search",
+                            "arguments": '{"query": "test"}',
+                        },
+                    }
+                ]
+            },
+            response_metadata={"model": "gpt-4"},
+            id="msg-001",
+        )
+
+        msg_dict = message_to_dict(original)
+        recreated = dict_to_message(msg_dict)
+
+        assert isinstance(recreated, AIMessage)
+        assert recreated.content == original.content
+        assert recreated.tool_calls == original.tool_calls
+        assert recreated.additional_kwargs == original.additional_kwargs
+        assert recreated.response_metadata == original.response_metadata
+        assert recreated.id == original.id
+
+    def test_ai_message_with_invalid_tool_calls(self):
+        original = AIMessage(
+            content="",
+            invalid_tool_calls=[
+                {
+                    "name": "malformed_tool",
+                    "args": "invalid json string",
+                    "id": "call_invalid",
+                    "error": "Invalid JSON in arguments",
+                }
+            ],
+            id="msg-invalid",
+        )
+
+        msg_dict = message_to_dict(original)
+        recreated = dict_to_message(msg_dict)
+
+        assert isinstance(recreated, AIMessage)
+        assert recreated.content == original.content
+        assert recreated.invalid_tool_calls == original.invalid_tool_calls
+        assert recreated.id == original.id
+
+    def test_tool_message(self):
+        original = ToolMessage(
+            content="Result data",
+            tool_call_id="call_123",
+            name="search",
+            additional_kwargs={"extra": "data"},
+            id="tool-msg-001",
+        )
+
+        msg_dict = message_to_dict(original)
+        recreated = dict_to_message(msg_dict)
+
+        assert isinstance(recreated, ToolMessage)
+        assert recreated.content == original.content
+        assert recreated.tool_call_id == original.tool_call_id
+        assert recreated.name == original.name
+        assert recreated.additional_kwargs == original.additional_kwargs
+        assert recreated.id == original.id
+
+    def test_human_message_basic(self):
+        original = HumanMessage(content="Hello", id="human-1")
+        msg_dict = message_to_dict(original)
+        recreated = dict_to_message(msg_dict)
+
+        assert isinstance(recreated, HumanMessage)
+        assert recreated.content == original.content
+        assert recreated.id == original.id
+
+    def test_system_message_basic(self):
+        original = SystemMessage(content="System prompt", id="sys-1")
+        msg_dict = message_to_dict(original)
+        recreated = dict_to_message(msg_dict)
+
+        assert isinstance(recreated, SystemMessage)
+        assert recreated.content == original.content
+        assert recreated.id == original.id
+
+    def test_developer_role_conversion(self):
+        msg_dict = {"role": "developer", "content": "Developer instructions"}
+        msg = dict_to_message(msg_dict)
+        assert isinstance(msg, SystemMessage)
+        assert msg.content == "Developer instructions"
+
+    def test_empty_collections_now_included(self):
+        msg = AIMessage(content="Test", additional_kwargs={}, tool_calls=[])
+        msg_dict = message_to_dict(msg)
+
+        assert "additional_kwargs" in msg_dict
+        assert "tool_calls" in msg_dict
+        assert msg_dict["additional_kwargs"] == {}
+        assert msg_dict["tool_calls"] == []
+
+    def test_message_to_dict_preserves_role(self):
+        human_msg = HumanMessage(content="test")
+        ai_msg = AIMessage(content="test")
+        system_msg = SystemMessage(content="test")
+
+        assert message_to_dict(human_msg)["role"] == "user"
+        assert message_to_dict(ai_msg)["role"] == "assistant"
+        assert message_to_dict(system_msg)["role"] == "system"
+
+
+class TestBatchConversion:
+    def test_messages_to_dicts(self):
+        originals = [
+            HumanMessage(content="Hello", id="human-1"),
+            AIMessage(
+                content="Hi there",
+                tool_calls=[
+                    {"name": "tool", "args": {}, "id": "c1", "type": "tool_call"}
+                ],
+                id="ai-1",
+            ),
+            ToolMessage(content="Tool result", tool_call_id="c1", name="tool"),
+            SystemMessage(content="System prompt", id="sys-1"),
+        ]
+
+        dicts = messages_to_dicts(originals)
+
+        assert len(dicts) == len(originals)
+        assert dicts[0]["role"] == "user"
+        assert dicts[1]["role"] == "assistant"
+        assert dicts[2]["role"] == "tool"
+        assert dicts[3]["role"] == "system"
+
+    def test_dicts_to_messages(self):
+        msg_dicts = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there"},
+            {"role": "tool", "content": "Result", "tool_call_id": "123"},
+            {"role": "system", "content": "System"},
+        ]
+
+        messages = dicts_to_messages(msg_dicts)
+
+        assert len(messages) == len(msg_dicts)
+        assert isinstance(messages[0], HumanMessage)
+        assert isinstance(messages[1], AIMessage)
+        assert isinstance(messages[2], ToolMessage)
+        assert isinstance(messages[3], SystemMessage)
+
+    def test_round_trip_conversion(self):
+        originals = [
+            HumanMessage(content="Test 1", id="h1", name="user1"),
+            AIMessage(
+                content="Test 2",
+                id="a1",
+                tool_calls=[
+                    {"name": "func", "args": {"x": 1}, "id": "tc1", "type": "tool_call"}
+                ],
+            ),
+            SystemMessage(content="Test 3", id="s1"),
+            ToolMessage(content="Test 4", tool_call_id="tc1", name="func", id="t1"),
+        ]
+
+        dicts = messages_to_dicts(originals)
+        recreated = dicts_to_messages(dicts)
+
+        for orig, recr in zip(originals, recreated):
+            assert type(orig) is type(recr)
+            assert orig.content == recr.content
+            if hasattr(orig, "id") and orig.id:
+                assert orig.id == recr.id
+            if hasattr(orig, "name") and orig.name:
+                assert orig.name == recr.name
+
+
+class TestTypeChecking:
+    def test_is_message_type(self):
+        ai_msg = AIMessage(content="test")
+        human_msg = HumanMessage(content="test")
+
+        assert is_message_type(ai_msg, AIMessage)
+        assert not is_message_type(ai_msg, HumanMessage)
+        assert is_message_type(human_msg, HumanMessage)
+        assert not is_message_type(human_msg, AIMessage)
+
+    def test_is_base_message(self):
+        assert is_base_message(AIMessage(content="test"))
+        assert is_base_message(HumanMessage(content="test"))
+        assert is_base_message(SystemMessage(content="test"))
+        assert is_base_message(ToolMessage(content="test", tool_call_id="123"))
+        assert not is_base_message("not a message")
+        assert not is_base_message({"role": "user", "content": "test"})
+
+    def test_is_ai_message(self):
+        ai_msg = AIMessage(content="test")
+        assert is_ai_message(ai_msg)
+
+        assert not is_ai_message(HumanMessage(content="test"))
+        assert not is_ai_message(SystemMessage(content="test"))
+        assert not is_ai_message(ToolMessage(content="test", tool_call_id="123"))
+        assert not is_ai_message("not a message")
+
+    def test_is_human_message(self):
+        human_msg = HumanMessage(content="test")
+        assert is_human_message(human_msg)
+
+        assert not is_human_message(AIMessage(content="test"))
+        assert not is_human_message(SystemMessage(content="test"))
+        assert not is_human_message(ToolMessage(content="test", tool_call_id="123"))
+        assert not is_human_message("not a message")
+
+    def test_is_system_message(self):
+        assert is_system_message(SystemMessage(content="test"))
+        assert not is_system_message(AIMessage(content="test"))
+
+    def test_is_tool_message(self):
+        assert is_tool_message(ToolMessage(content="test", tool_call_id="123"))
+        assert not is_tool_message(AIMessage(content="test"))
+
+    def test_all_base_messages(self):
+        messages = [
+            AIMessage(content="1"),
+            HumanMessage(content="2"),
+            SystemMessage(content="3"),
+        ]
+        assert all_base_messages(messages)
+
+        mixed = [AIMessage(content="1"), "not a message"]
+        assert not all_base_messages(mixed)
+
+        assert all_base_messages([])
+
+
+class TestMessageCreation:
+    def test_create_ai_message_basic(self):
+        msg = create_ai_message("test content")
+        assert msg.content == "test content"
+        assert isinstance(msg, AIMessage)
+
+    def test_create_ai_message_with_tool_calls(self):
+        tool_calls = [{"name": "func", "args": {}, "id": "123", "type": "tool_call"}]
+        usage_metadata = {
+            "input_tokens": 50,
+            "output_tokens": 50,
+            "total_tokens": 100,
+        }
+        msg = create_ai_message(
+            "content",
+            tool_calls=tool_calls,
+            additional_kwargs={"key": "value"},
+            response_metadata={"model": "gpt-4"},
+            id="msg-1",
+            usage_metadata=usage_metadata,
+        )
+
+        assert msg.content == "content"
+        assert msg.tool_calls == tool_calls
+        assert msg.additional_kwargs == {"key": "value"}
+        assert msg.response_metadata == {"model": "gpt-4"}
+        assert msg.id == "msg-1"
+        assert msg.usage_metadata == usage_metadata
+
+    def test_create_ai_message_chunk(self):
+        chunk = create_ai_message_chunk("chunk content", id="chunk-1")
+        assert chunk.content == "chunk content"
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.id == "chunk-1"
+
+    def test_create_human_message(self):
+        msg = create_human_message(
+            "user input",
+            additional_kwargs={"meta": "data"},
+            response_metadata={"source": "user"},
+            id="human-1",
+            name="user1",
+        )
+
+        assert msg.content == "user input"
+        assert msg.additional_kwargs == {"meta": "data"}
{"meta": "data"} + assert msg.response_metadata == {"source": "user"} + assert msg.id == "human-1" + assert msg.name == "user1" + + def test_create_system_message(self): + msg = create_system_message( + "system prompt", + additional_kwargs={"sys": "info"}, + response_metadata={"type": "system"}, + id="sys-1", + name="system", + ) + + assert msg.content == "system prompt" + assert msg.additional_kwargs == {"sys": "info"} + assert msg.response_metadata == {"type": "system"} + assert msg.id == "sys-1" + assert msg.name == "system" + + def test_create_tool_message(self): + msg = create_tool_message( + "tool result", + tool_call_id="call-123", + name="calculator", + additional_kwargs={"result": "success"}, + response_metadata={"tool": "calc"}, + id="tool-1", + status="success", + ) + + assert msg.content == "tool result" + assert msg.tool_call_id == "call-123" + assert msg.name == "calculator" + assert msg.additional_kwargs == {"result": "success"} + assert msg.response_metadata == {"tool": "calc"} + assert msg.id == "tool-1" + assert msg.status == "success" + + +class TestEdgeCases: + def test_falsey_values_preservation(self): + original = AIMessage( + content="Test", + additional_kwargs={}, + tool_calls=[], + name="", + response_metadata={}, + id="test-id", + ) + msg_dict = message_to_dict(original) + recreated = dict_to_message(msg_dict) + + assert recreated.additional_kwargs == {} + assert recreated.tool_calls == [] + assert recreated.name == "" + assert recreated.response_metadata == {} + assert recreated.id == "test-id" + + def test_human_message_with_empty_name(self): + original = HumanMessage(content="Hello", name="") + msg_dict = message_to_dict(original) + recreated = dict_to_message(msg_dict) + + assert isinstance(recreated, HumanMessage) + assert recreated.name == "" + + def test_system_message_with_empty_additional_kwargs(self): + original = SystemMessage(content="System prompt", additional_kwargs={}) + msg_dict = message_to_dict(original) + recreated = dict_to_message(msg_dict) + + assert isinstance(recreated, SystemMessage) + assert recreated.additional_kwargs == {} + + def test_dict_to_message_missing_role_and_type(self): + with pytest.raises(ValueError, match="must have 'type' or 'role'"): + dict_to_message({"content": "test"}) + + def test_dict_to_message_with_type_field(self): + msg = dict_to_message({"type": "user", "content": "test"}) + assert isinstance(msg, HumanMessage) + + def test_dict_to_message_with_role_field(self): + msg = dict_to_message({"role": "user", "content": "test"}) + assert isinstance(msg, HumanMessage) + + def test_tool_message_without_tool_call_id(self): + msg_dict = {"role": "tool", "content": "test"} + msg = dict_to_message(msg_dict) + assert isinstance(msg, ToolMessage) + assert msg.tool_call_id == "" + + def test_message_with_none_values(self): + original = AIMessage( + content="test", + additional_kwargs={"valid": "value"}, + ) + msg_dict = message_to_dict(original) + + assert msg_dict["content"] == "test" + assert msg_dict["role"] == "assistant" + assert "additional_kwargs" in msg_dict + assert msg_dict["additional_kwargs"] == {"valid": "value"} + + def test_preserves_unknown_fields_in_dict(self): + msg_dict = { + "role": "assistant", + "content": "test", + "id": "123", + "name": "bot", + } + msg = dict_to_message(msg_dict) + + assert isinstance(msg, AIMessage) + assert msg.content == "test" + assert msg.id == "123" + assert msg.name == "bot" diff --git a/tests/runnable_rails/test_transform_input.py b/tests/runnable_rails/test_transform_input.py 
index 4b8c026fc..b2bf12ec9 100644
--- a/tests/runnable_rails/test_transform_input.py
+++ b/tests/runnable_rails/test_transform_input.py
@@ -63,9 +63,26 @@ def test_transform_chat_prompt_value(rails):
     result = rails._transform_input_to_rails_format(chat_prompt)
 
     expected = [
-        {"role": "system", "content": "You are helpful"},
-        {"role": "user", "content": "Hello"},
-        {"role": "assistant", "content": "Hi there"},
+        {
+            "role": "system",
+            "content": "You are helpful",
+            "additional_kwargs": {},
+            "response_metadata": {},
+        },
+        {
+            "role": "user",
+            "content": "Hello",
+            "additional_kwargs": {},
+            "response_metadata": {},
+        },
+        {
+            "role": "assistant",
+            "content": "Hi there",
+            "additional_kwargs": {},
+            "response_metadata": {},
+            "tool_calls": [],
+            "invalid_tool_calls": [],
+        },
     ]
 
     assert result == expected
@@ -139,8 +156,20 @@ def test_transform_list_of_base_messages(rails):
     result = rails._transform_input_to_rails_format(messages)
 
     expected = [
-        {"role": "user", "content": "What is Python?"},
-        {"role": "assistant", "content": "Python is a programming language"},
+        {
+            "role": "user",
+            "content": "What is Python?",
+            "additional_kwargs": {},
+            "response_metadata": {},
+        },
+        {
+            "role": "assistant",
+            "content": "Python is a programming language",
+            "additional_kwargs": {},
+            "response_metadata": {},
+            "tool_calls": [],
+            "invalid_tool_calls": [],
+        },
     ]
 
     assert result == expected
@@ -150,7 +179,14 @@ def test_transform_single_human_message(rails):
     message = HumanMessage(content="Hello there")
     result = rails._transform_input_to_rails_format(message)
 
-    expected = [{"role": "user", "content": "Hello there"}]
+    expected = [
+        {
+            "role": "user",
+            "content": "Hello there",
+            "additional_kwargs": {},
+            "response_metadata": {},
+        }
+    ]
 
     assert result == expected
 
@@ -159,7 +195,16 @@ def test_transform_single_ai_message(rails):
     message = AIMessage(content="Hello back")
    result = rails._transform_input_to_rails_format(message)
 
-    expected = [{"role": "assistant", "content": "Hello back"}]
+    expected = [
+        {
+            "role": "assistant",
+            "content": "Hello back",
+            "additional_kwargs": {},
+            "response_metadata": {},
+            "tool_calls": [],
+            "invalid_tool_calls": [],
+        }
+    ]
 
     assert result == expected
 
diff --git a/tests/test_tool_calling_utils.py b/tests/test_tool_calling_utils.py
index 5312b0b6f..e18fa6c67 100644
--- a/tests/test_tool_calling_utils.py
+++ b/tests/test_tool_calling_utils.py
@@ -141,7 +141,7 @@ def test_convert_messages_to_langchain_format_unknown_type():
     """Test that unknown message types raise ValueError."""
     messages = [{"role": "unknown", "content": "Unknown message"}]
 
-    with pytest.raises(ValueError, match="Unknown message type unknown"):
+    with pytest.raises(ValueError, match="Unknown message type: unknown"):
         _convert_messages_to_langchain_format(messages)
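Usage sketch: the new `nemoguardrails.integrations.langchain.message_utils` module introduced above becomes the single conversion point between rails-style message dicts and LangChain message objects. A minimal, illustrative round trip follows (assumes a build with this patch applied; the assertions mirror behavior pinned down by the tests above, they are not part of the patch):

    from langchain_core.messages import AIMessage, HumanMessage

    from nemoguardrails.integrations.langchain.message_utils import (
        dicts_to_messages,
        messages_to_dicts,
    )

    # Rails-style dicts -> LangChain messages. "bot" and "assistant" both map
    # to AIMessage, "developer" maps to SystemMessage, and an unknown role
    # raises ValueError("Unknown message type: ...").
    messages = dicts_to_messages(
        [
            {"role": "user", "content": "What is Python?"},
            {"role": "assistant", "content": "A programming language."},
        ]
    )
    assert isinstance(messages[0], HumanMessage)
    assert isinstance(messages[1], AIMessage)

    # LangChain messages -> rails-style dicts. message_to_dict keeps every
    # non-None model field, so empty containers such as additional_kwargs={}
    # and tool_calls=[] survive the round trip; this is why the expected dicts
    # in test_transform_input.py now carry those keys.
    dicts = messages_to_dicts(messages)
    assert dicts[1]["role"] == "assistant"
    assert dicts[1]["tool_calls"] == []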