Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def llm_call(
generation_llm, prompt, all_callbacks
)

_store_reasoning_traces(response)
_store_tool_calls(response)
_store_response_metadata(response)
return _extract_content(response)
Expand Down Expand Up @@ -172,6 +173,18 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
return dicts_to_messages(prompt)


def _store_reasoning_traces(response) -> None:
if hasattr(response, "additional_kwargs"):
additional_kwargs = response.additional_kwargs
if (
isinstance(additional_kwargs, dict)
and "reasoning_content" in additional_kwargs
):
reasoning_content = additional_kwargs["reasoning_content"]
if reasoning_content:
reasoning_trace_var.set(reasoning_content)


def _store_tool_calls(response) -> None:
"""Extract and store tool calls from response in context."""
tool_calls = getattr(response, "tool_calls", None)
Expand All @@ -192,15 +205,6 @@ def _store_response_metadata(response) -> None:
metadata[field_name] = getattr(response, field_name)
llm_response_metadata_var.set(metadata)

if hasattr(response, "additional_kwargs"):
additional_kwargs = response.additional_kwargs
if (
isinstance(additional_kwargs, dict)
and "reasoning_content" in additional_kwargs
):
reasoning_content = additional_kwargs["reasoning_content"]
if reasoning_content:
reasoning_trace_var.set(reasoning_content)
else:
llm_response_metadata_var.set(None)

Expand Down Expand Up @@ -704,6 +708,12 @@ def extract_tool_calls_from_events(events: list) -> Optional[list]:
return None


def extract_bot_thinking_from_events(events: list) -> Optional[str]:
    """Return the content of the first ``BotThinking`` event in ``events``.

    Args:
        events: Event dictionaries to scan, in order.

    Returns:
        The ``content`` value of the first event whose ``type`` is
        ``"BotThinking"`` (``None`` if that event carries no ``content``
        key), or ``None`` when no such event is present.
    """
    for event in events:
        if event.get("type") == "BotThinking":
            return event.get("content")
    return None


def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
"""Get the current response metadata and clear it from the context.

Expand Down
6 changes: 3 additions & 3 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@

from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
extract_bot_thinking_from_events,
extract_tool_calls_from_events,
get_and_clear_reasoning_trace_contextvar,
get_and_clear_response_metadata_contextvar,
get_colang_history,
)
Expand Down Expand Up @@ -1037,7 +1037,7 @@ async def generate_async(
else:
res = GenerationResponse(response=[new_message])

if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
if reasoning_trace := extract_bot_thinking_from_events(events):
if prompt:
# For prompt mode, response should be a string
if isinstance(res.response, str):
Expand Down Expand Up @@ -1182,7 +1182,7 @@ async def generate_async(
else:
# If a prompt is used, we only return the content of the message.

if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
if reasoning_trace := extract_bot_thinking_from_events(events):
new_message["content"] = reasoning_trace + new_message["content"]

if prompt:
Expand Down
306 changes: 306 additions & 0 deletions tests/test_reasoning_trace_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import AsyncMock

import pytest
from langchain_core.messages import AIMessage

from nemoguardrails.actions.llm.utils import _store_reasoning_traces
from nemoguardrails.context import reasoning_trace_var


class TestStoreReasoningTracesUnit:
    """Unit tests for ``_store_reasoning_traces``.

    Each test builds a response object, passes it to
    ``_store_reasoning_traces`` and inspects ``reasoning_trace_var``.
    The context variable is reset around every test by the xunit-style
    hooks below, so a failing assertion cannot leak a trace into the
    next test.
    """

    def setup_method(self):
        """Start every test with no reasoning trace stored."""
        reasoning_trace_var.set(None)

    def teardown_method(self):
        """Clear the trace after every test, including failed ones."""
        reasoning_trace_var.set(None)

    def test_store_reasoning_traces_with_valid_reasoning_content(self):
        """A non-empty ``reasoning_content`` value is stored verbatim."""
        test_reasoning = "Step 1: Analyze the question\nStep 2: Formulate response"

        response = AIMessage(
            content="The answer is 42",
            additional_kwargs={"reasoning_content": test_reasoning},
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() == test_reasoning

    def test_store_reasoning_traces_with_empty_reasoning_content(self):
        """An empty-string ``reasoning_content`` is not stored."""
        response = AIMessage(
            content="Response", additional_kwargs={"reasoning_content": ""}
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_with_none_reasoning_content(self):
        """A ``None`` ``reasoning_content`` is not stored."""
        response = AIMessage(
            content="Response", additional_kwargs={"reasoning_content": None}
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_without_reasoning_content_key(self):
        """Unrelated ``additional_kwargs`` keys do not produce a trace."""
        response = AIMessage(
            content="Response", additional_kwargs={"other_key": "other_value"}
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_with_empty_additional_kwargs(self):
        """An empty ``additional_kwargs`` dict does not produce a trace."""
        response = AIMessage(content="Response", additional_kwargs={})

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_without_additional_kwargs_attribute(self):
        """A response lacking ``additional_kwargs`` is handled without error."""

        class SimpleResponse:
            def __init__(self, content):
                self.content = content

        response = SimpleResponse("Response")

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_with_non_dict_additional_kwargs(self):
        """A non-dict ``additional_kwargs`` value is ignored."""

        class ResponseWithInvalidKwargs:
            def __init__(self):
                self.content = "Response"
                self.additional_kwargs = "not_a_dict"

        response = ResponseWithInvalidKwargs()

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() is None

    def test_store_reasoning_traces_overwrites_previous_trace(self):
        """A new trace replaces whatever was stored before."""
        initial_trace = "Initial reasoning"
        new_trace = "New reasoning"

        reasoning_trace_var.set(initial_trace)

        response = AIMessage(
            content="Response", additional_kwargs={"reasoning_content": new_trace}
        )

        _store_reasoning_traces(response)

        stored_trace = reasoning_trace_var.get()
        assert stored_trace == new_trace
        assert stored_trace != initial_trace

    def test_store_reasoning_traces_with_multiline_content(self):
        """Multi-line reasoning text is stored unchanged."""
        multiline_reasoning = """Thought process:
1. First, understand the user's intent
2. Second, check available data
3. Third, formulate a response
4. Finally, validate the response"""

        response = AIMessage(
            content="Response",
            additional_kwargs={"reasoning_content": multiline_reasoning},
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() == multiline_reasoning

    def test_store_reasoning_traces_with_special_characters(self):
        """Quotes, angle brackets and ampersands survive storage unchanged."""
        special_reasoning = "Thinking: Let's analyze this <step> with \"quotes\" and 'apostrophes' & symbols!"

        response = AIMessage(
            content="Response",
            additional_kwargs={"reasoning_content": special_reasoning},
        )

        _store_reasoning_traces(response)

        assert reasoning_trace_var.get() == special_reasoning


class TestReasoningTraceIntegration:
    """Integration tests: ``llm_call`` stores reasoning traces from responses.

    Each test drives ``llm_call`` with a mocked LLM and checks both the
    returned content and the value left in ``reasoning_trace_var``.
    """

    @staticmethod
    def _llm_returning(*responses):
        """Build a mock LLM whose ``ainvoke`` yields ``responses`` in order.

        With a single response every call returns it; with several, each
        successive call returns the next one.
        """
        mock_llm = AsyncMock()
        if len(responses) == 1:
            mock_llm.ainvoke = AsyncMock(return_value=responses[0])
        else:
            mock_llm.ainvoke = AsyncMock(side_effect=list(responses))
        return mock_llm

    @staticmethod
    async def _call_and_capture(llm, prompt):
        """Run ``llm_call`` from a clean state; return ``(result, trace)``.

        The context variable is reset before the call and again in a
        ``finally`` clause, so a raising call cannot leak a trace. The reset
        happens inside the awaiting coroutine (not in a fixture) so that it
        runs in the same context as the test body.
        """
        from nemoguardrails.actions.llm.utils import llm_call

        reasoning_trace_var.set(None)
        try:
            result = await llm_call(llm, prompt)
            return result, reasoning_trace_var.get()
        finally:
            reasoning_trace_var.set(None)

    @pytest.mark.asyncio
    async def test_llm_call_extracts_reasoning_from_additional_kwargs(self):
        """Reasoning in ``additional_kwargs`` is stored; content is returned."""
        test_reasoning = "Let me think about this carefully..."
        mock_llm = self._llm_returning(
            AIMessage(
                content="The answer is 42",
                additional_kwargs={"reasoning_content": test_reasoning},
            )
        )

        result, stored_trace = await self._call_and_capture(
            mock_llm, "What is the answer?"
        )

        assert result == "The answer is 42"
        assert stored_trace == test_reasoning

    @pytest.mark.asyncio
    async def test_llm_call_handles_missing_reasoning_content(self):
        """A response without reasoning leaves the trace unset."""
        mock_llm = self._llm_returning(
            AIMessage(content="Regular response", additional_kwargs={})
        )

        result, stored_trace = await self._call_and_capture(mock_llm, "Hello")

        assert result == "Regular response"
        assert stored_trace is None

    @pytest.mark.asyncio
    async def test_llm_call_with_message_list_extracts_reasoning(self):
        """Reasoning is also stored when the prompt is a list of messages."""
        test_reasoning = "Analyzing the conversation context..."
        mock_llm = self._llm_returning(
            AIMessage(
                content="Here's my response",
                additional_kwargs={"reasoning_content": test_reasoning},
            )
        )
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]

        result, stored_trace = await self._call_and_capture(mock_llm, messages)

        assert result == "Here's my response"
        assert stored_trace == test_reasoning

    @pytest.mark.asyncio
    async def test_multiple_llm_calls_preserve_separate_reasoning_traces(self):
        """Consecutive calls each store their own response's reasoning."""
        first_reasoning = "First analysis"
        second_reasoning = "Second analysis"
        mock_llm = self._llm_returning(
            AIMessage(
                content="First response",
                additional_kwargs={"reasoning_content": first_reasoning},
            ),
            AIMessage(
                content="Second response",
                additional_kwargs={"reasoning_content": second_reasoning},
            ),
        )

        result1, trace1 = await self._call_and_capture(mock_llm, "First query")
        result2, trace2 = await self._call_and_capture(mock_llm, "Second query")

        assert result1 == "First response"
        assert result2 == "Second response"
        assert trace1 == first_reasoning
        assert trace2 == second_reasoning

    @pytest.mark.asyncio
    async def test_reasoning_content_with_other_additional_kwargs(self):
        """Extra ``additional_kwargs`` keys do not interfere with extraction."""
        test_reasoning = "Complex reasoning process"
        mock_llm = self._llm_returning(
            AIMessage(
                content="Response",
                additional_kwargs={
                    "reasoning_content": test_reasoning,
                    "model": "test-model",
                    "finish_reason": "stop",
                    "other_metadata": {"key": "value"},
                },
            )
        )

        result, stored_trace = await self._call_and_capture(mock_llm, "Query")

        assert result == "Response"
        assert stored_trace == test_reasoning