Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from haystack.lazy_imports import LazyImport
from haystack.utils import Jinja2TimeExtension
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part
from haystack.utils.jinja2_extensions import _collect_assigned_variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,13 +180,17 @@ def __init__(
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
if message.text and "templatize_part" in message.text:
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
extracted_variables += list(template_variables)
jinja2_ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
extracted_variables += list(template_variables - assigned_variables)
elif isinstance(template, str):
ast = self._env.parse(template)
extracted_variables = list(meta.find_undeclared_variables(ast))
jinja2_ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
extracted_variables = list(template_variables - assigned_variables)

extracted_variables = extracted_variables or []
self.variables = variables or extracted_variables
self.required_variables = required_variables or []

Expand Down
11 changes: 7 additions & 4 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from haystack import component, default_to_dict, logging
from haystack.utils import Jinja2TimeExtension
from haystack.utils.jinja2_extensions import _collect_assigned_variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,11 +175,13 @@ def __init__(
self._env = SandboxedEnvironment()

self.template = self._env.from_string(template)

if not variables:
# infer variables from template
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
jinja2_ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
variables = list(template_variables - assigned_variables)

variables = variables or []
self.variables = variables

Expand Down
11 changes: 7 additions & 4 deletions haystack/components/converters/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
from haystack.utils.jinja2_extensions import _collect_assigned_variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
output_type: TypeAlias,
custom_filters: Optional[dict[str, Callable]] = None,
unsafe: bool = False,
):
) -> None:
"""
Create an OutputAdapter component.

Expand Down Expand Up @@ -179,7 +180,9 @@ def _extract_variables(self, env: Environment) -> set[str]:
Extracts all variables from a list of Jinja template strings.

:param env: A Jinja environment.
:return: A set of variable names extracted from the template strings.
:returns: A set of variable names extracted from the template strings.
"""
ast = env.parse(self.template)
return meta.find_undeclared_variables(ast)
jinja2_ast = env.parse(self.template)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
return template_variables - assigned_variables
9 changes: 7 additions & 2 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
from haystack.utils.jinja2_extensions import _collect_assigned_variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,7 +404,8 @@ def _validate_routes(self, routes: list[Route]):
if not self._validate_template(self._env, output):
raise ValueError(f"Invalid template for output: {output}")

def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]:
@staticmethod
def _extract_variables(env: Environment, templates: list[str]) -> set[str]:
"""
Extracts all variables from a list of Jinja template strings.

Expand All @@ -413,7 +415,10 @@ def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]
"""
variables = set()
for template in templates:
variables.update(meta.find_undeclared_variables(env.parse(template)))
jinja2_ast = env.parse(template)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
variables.update(template_variables - assigned_variables)
return variables

def _validate_template(self, env: Environment, template_text: str):
Expand Down
23 changes: 23 additions & 0 deletions haystack/utils/jinja2_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,26 @@ def parse(self, parser: Any) -> Union[nodes.Node, list[nodes.Node]]:
)

return nodes.Output([call_method], lineno=lineno)


def _collect_assigned_variables(ast: nodes.Template) -> set[str]:
"""
Extract variables assigned within the Jinja2 template AST.

:param ast: The Jinja2 Abstract Syntax Tree (AST) of the template.

:returns:
A set of variable names that are assigned within the template.
"""
# Collect all variables assigned inside the template via {% set %}
assigned_variables = set()

for node in ast.find_all(nodes.Assign):
if isinstance(node.target, nodes.Name):
assigned_variables.add(node.target.name)
elif isinstance(node.target, (nodes.List, nodes.Tuple)):
for name_node in node.target.items:
if isinstance(name_node, nodes.Name):
assigned_variables.add(name_node.name)

return assigned_variables
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Fixes jinja2 variable detection in ``ConditionalRouter``, ``ChatPromptBuilder``, ``PromptBuilder`` and ``OutputAdapter`` by properly
skipping variables that are assigned within the template.
Previously under specific scenarios variables assigned within a template would falsely be picked up as input variables to the component.
For more information you can check out the parent issue in the Jinja2 library here: https://github.com/pallets/jinja/issues/2069
50 changes: 50 additions & 0 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,3 +957,53 @@ def test_from_dict(self):
assert builder.template == template
assert builder.variables == ["name", "assistant_name"]
assert builder.required_variables == ["name"]

def test_variables_correct_with_assignment(self):
template = """{% message role="user" %}
{% if existing_documents is not none -%}
{% set x = existing_documents|length -%}
{% else -%}
{% set x = 0 -%}
{% endif -%}
The number is {{ x }}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"][0].text == "The number is 0!"

def test_variables_correct_with_tuple_assignment(self):
template = """{% message role="user" %}
{% if name is not none -%}
{% set x, y = (0, 1) %}
{% else -%}
{% set x, y = (2, 3) %}
{% endif -%}
x={{ x }}, y={{ y }}
Hello, my name is {{name}}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["name"]
assert builder.required_variables == "*"
res = builder.run(name="John")
assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!"

def test_variables_correct_with_list_assignment(self):
template = """{% message role="user" %}
{% if name is not none -%}
{% set x, y = [0, 1] %}
{% else -%}
{% set x, y = [2, 3] %}
{% endif -%}
x={{ x }}, y={{ y }}
Hello, my name is {{name}}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["name"]
assert builder.required_variables == "*"
res = builder.run(name="John")
assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!"
44 changes: 44 additions & 0 deletions test/components/builders/test_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,47 @@ def test_warning_no_required_variables(self, caplog):
with caplog.at_level(logging.WARNING):
_ = PromptBuilder(template="This is a {{ variable }}")
assert "but `required_variables` is not set." in caplog.text

def test_variables_correct_with_assignment(self) -> None:
template = """{% if existing_documents is not none %}
{% set existing_doc_len = existing_documents|length %}
{% else %}
{% set existing_doc_len = 0 %}
{% endif %}
{% for doc in docs %}
<document reference="{{loop.index + existing_doc_len}}">
{{ doc.content }}
</document>
{% endfor %}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"docs", "existing_documents"}
assert builder.required_variables == "*"

def test_variables_correct_with_tuple_assignment(self):
template = """{% if existing_documents is not none -%}
{% set x, y = (existing_documents|length, 1) -%}
{% else -%}
{% set x, y = (0, 1) -%}
{% endif -%}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"] == "x=0, y=1"

def test_variables_correct_with_list_assignment(self):
template = """{% if existing_documents is not none -%}
{% set x, y = [existing_documents|length, 1] -%}
{% else -%}
{% set x, y = [0, 1] -%}
{% endif -%}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"] == "x=0, y=1"
16 changes: 15 additions & 1 deletion test/components/converters/test_output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import json
from typing import List
from typing import Any, List

import pytest

from haystack import Pipeline, component
from haystack.components.converters import OutputAdapter
from haystack.components.converters.output_adapter import OutputAdaptationException
from haystack.core.component.sockets import InputSocket
from haystack.dataclasses import Document


Expand Down Expand Up @@ -203,3 +204,16 @@ def test_unsafe(self):
]
res = adapter.run(documents=documents)
assert res["output"] == documents[0]

def test_variables_correct_with_assignment(self) -> None:
template = """{% if control == 'something' %}
{% set output = 1 %}
{% else %}
{% set output = 3 %}
{% endif %}
{{ output }}
"""
adapter = OutputAdapter(template=template, output_type=int)
assert adapter.__haystack_input__._sockets_dict == {"control": InputSocket(name="control", type=Any)}
res = adapter.run(control="something")
assert res["output"] == 1
13 changes: 13 additions & 0 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest import mock

import pytest
from jinja2.nativetypes import NativeEnvironment

from haystack import Pipeline
from haystack.components.routers import ConditionalRouter
Expand Down Expand Up @@ -636,3 +637,15 @@ def test_sede_multiple_outputs(self):
reloaded_router = ConditionalRouter.from_dict(router.to_dict())
assert reloaded_router.custom_filters == router.custom_filters
assert reloaded_router.routes == router.routes

def test_extract_variables_correct_with_assignment(self):
condition = """{%- if control == 'something' -%}
{% set streams = 1 %}
{%- else -%}
{% set streams = 2 %}
{%- endif -%}
{{streams == 1}}
"""
templates = [condition, "{{query}}"]
extracted_variables = ConditionalRouter._extract_variables(env=NativeEnvironment(), templates=templates)
assert extracted_variables == {"control", "query"}
Loading