diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index 8956d51506..7483df59a9 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -6,7 +6,6 @@ from copy import deepcopy from typing import Any, Literal, Optional, Union -from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging @@ -14,6 +13,7 @@ from haystack.lazy_imports import LazyImport from haystack.utils import Jinja2TimeExtension from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part +from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments logger = logging.getLogger(__name__) @@ -179,13 +179,17 @@ def __init__( raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message)) if message.text and "templatize_part" in message.text: raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) - ast = self._env.parse(message.text) - template_variables = meta.find_undeclared_variables(ast) - extracted_variables += list(template_variables) + assigned_variables, template_variables = _extract_template_variables_and_assignments( + env=self._env, template=message.text + ) + extracted_variables += list(template_variables - assigned_variables) elif isinstance(template, str): - ast = self._env.parse(template) - extracted_variables = list(meta.find_undeclared_variables(ast)) + assigned_variables, template_variables = _extract_template_variables_and_assignments( + env=self._env, template=template + ) + extracted_variables = list(template_variables - assigned_variables) + extracted_variables = extracted_variables or [] self.variables = variables or extracted_variables self.required_variables = required_variables or [] diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index 3e1d3603ca..a7c3f1950c 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -4,11 +4,11 @@ from typing import Any, Literal, Optional, Union -from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_to_dict, logging from haystack.utils import Jinja2TimeExtension +from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments logger = logging.getLogger(__name__) @@ -174,11 +174,13 @@ def __init__( self._env = SandboxedEnvironment() self.template = self._env.from_string(template) + if not variables: - # infer variables from template - ast = self._env.parse(template) - template_variables = meta.find_undeclared_variables(ast) - variables = list(template_variables) + assigned_variables, template_variables = _extract_template_variables_and_assignments( + env=self._env, template=template + ) + variables = list(template_variables - assigned_variables) + variables = variables or [] self.variables = variables diff --git a/haystack/components/converters/output_adapter.py b/haystack/components/converters/output_adapter.py index 24cfdeaa42..bb5fc96ba3 100644 --- a/haystack/components/converters/output_adapter.py +++ b/haystack/components/converters/output_adapter.py @@ -7,13 +7,14 @@ from typing import Any, Callable, Optional import jinja2.runtime -from jinja2 import Environment, TemplateSyntaxError, meta +from jinja2 import TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment from jinja2.sandbox import SandboxedEnvironment from typing_extensions import TypeAlias from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type +from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments logger = logging.getLogger(__name__) @@ -46,7 +47,7 @@ def __init__( output_type: TypeAlias, custom_filters: Optional[dict[str, Callable]] = None, unsafe: bool = False, - ): + ) -> None: """ Create an OutputAdapter component. @@ -92,7 +93,10 @@ def __init__( self._env.filters[name] = filter_func # b) extract variables in the template - route_input_names = self._extract_variables(self._env) + assigned_variables, template_variables = _extract_template_variables_and_assignments( + env=self._env, template=self.template + ) + route_input_names = template_variables - assigned_variables input_types.update(route_input_names) # the env is not needed, discarded automatically @@ -173,13 +177,3 @@ def from_dict(cls, data: dict[str, Any]) -> "OutputAdapter": for name, filter_func in custom_filters.items() } return default_from_dict(cls, data) - - def _extract_variables(self, env: Environment) -> set[str]: - """ - Extracts all variables from a list of Jinja template strings. - - :param env: A Jinja environment. - :return: A set of variable names extracted from the template strings. - """ - ast = env.parse(self.template) - return meta.find_undeclared_variables(ast) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index c8125b869d..7b310d0419 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -6,12 +6,13 @@ import contextlib from typing import Any, Callable, Mapping, Optional, Sequence, TypedDict, Union, get_args, get_origin -from jinja2 import Environment, TemplateSyntaxError, meta +from jinja2 import Environment, TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type +from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments logger = logging.getLogger(__name__) @@ -403,7 +404,8 @@ def _validate_routes(self, routes: list[Route]): if not self._validate_template(self._env, output): raise ValueError(f"Invalid template for output: {output}") - def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]: + @staticmethod + def _extract_variables(env: Environment, templates: list[str]) -> set[str]: """ Extracts all variables from a list of Jinja template strings. @@ -413,7 +415,10 @@ def _extract_variables(self, env: Environment, templates: list[str]) -> set[str] """ variables = set() for template in templates: - variables.update(meta.find_undeclared_variables(env.parse(template))) + assigned_variables, template_variables = _extract_template_variables_and_assignments( + env=env, template=template + ) + variables.update(template_variables - assigned_variables) return variables def _validate_template(self, env: Environment, template_text: str): diff --git a/haystack/utils/jinja2_extensions.py b/haystack/utils/jinja2_extensions.py index 348b571d60..7d8cb32b47 100644 --- a/haystack/utils/jinja2_extensions.py +++ b/haystack/utils/jinja2_extensions.py @@ -4,7 +4,7 @@ from typing import Any, Optional, Union -from jinja2 import Environment, nodes +from jinja2 import Environment, meta, nodes from jinja2.ext import Extension from haystack.lazy_imports import LazyImport @@ -94,3 +94,42 @@ def parse(self, parser: Any) -> Union[nodes.Node, list[nodes.Node]]: ) return nodes.Output([call_method], lineno=lineno) + + +def _collect_assigned_variables(ast: nodes.Template) -> set[str]: + """ + Extract variables assigned within the Jinja2 template AST. + + :param ast: The Jinja2 Abstract Syntax Tree (AST) of the template. + + :returns: + A set of variable names that are assigned within the template. + """ + # Collect all variables assigned inside the template via {% set %} + assigned_variables = set() + + for node in ast.find_all(nodes.Assign): + if isinstance(node.target, nodes.Name): + assigned_variables.add(node.target.name) + elif isinstance(node.target, (nodes.List, nodes.Tuple)): + for name_node in node.target.items: + if isinstance(name_node, nodes.Name): + assigned_variables.add(name_node.name) + + return assigned_variables + + +def _extract_template_variables_and_assignments(env: Environment, template: str) -> tuple[set[str], set[str]]: + """ + Extract variables from a Jinja2 template and variables assigned within it. + + :param env: A Jinja2 environment. + :param template: A Jinja2 template string. + :returns: A tuple of (assigned_variables, template_variables) where: + - assigned_variables: Variables assigned within the template (e.g., via {% set %}) + - template_variables: All undeclared variables used in the template + """ + jinja2_ast = env.parse(template) + template_variables = meta.find_undeclared_variables(jinja2_ast) + assigned_variables = _collect_assigned_variables(jinja2_ast) + return assigned_variables, template_variables diff --git a/releasenotes/notes/fix-jinja2-variable-extraction-57f34c6bd249e214.yaml b/releasenotes/notes/fix-jinja2-variable-extraction-57f34c6bd249e214.yaml new file mode 100644 index 0000000000..5f3b645721 --- /dev/null +++ b/releasenotes/notes/fix-jinja2-variable-extraction-57f34c6bd249e214.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixes jinja2 variable detection in ``ConditionalRouter``, ``ChatPromptBuilder``, ``PromptBuilder`` and ``OutputAdapter`` by properly + skipping variables that are assigned within the template. + Previously under specific scenarios variables assigned within a template would falsely be picked up as input variables to the component. + For more information you can check out the parent issue in the Jinja2 library here: https://github.com/pallets/jinja/issues/2069 diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index fa15dcbc56..7d8ec0d987 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -957,3 +957,53 @@ def test_from_dict(self): assert builder.template == template assert builder.variables == ["name", "assistant_name"] assert builder.required_variables == ["name"] + + def test_variables_correct_with_assignment(self): + template = """{% message role="user" %} +{% if existing_documents is not none -%} +{% set x = existing_documents|length -%} +{% else -%} +{% set x = 0 -%} +{% endif -%} +The number is {{ x }}! +{% endmessage %} +""" + builder = ChatPromptBuilder(template=template, required_variables="*") + assert builder.variables == ["existing_documents"] + assert builder.required_variables == "*" + res = builder.run(existing_documents=None) + assert res["prompt"][0].text == "The number is 0!" + + def test_variables_correct_with_tuple_assignment(self): + template = """{% message role="user" %} +{% if name is not none -%} +{% set x, y = (0, 1) %} +{% else -%} +{% set x, y = (2, 3) %} +{% endif -%} +x={{ x }}, y={{ y }} +Hello, my name is {{name}}! +{% endmessage %} +""" + builder = ChatPromptBuilder(template=template, required_variables="*") + assert builder.variables == ["name"] + assert builder.required_variables == "*" + res = builder.run(name="John") + assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!" + + def test_variables_correct_with_list_assignment(self): + template = """{% message role="user" %} +{% if name is not none -%} +{% set x, y = [0, 1] %} +{% else -%} +{% set x, y = [2, 3] %} +{% endif -%} +x={{ x }}, y={{ y }} +Hello, my name is {{name}}! +{% endmessage %} +""" + builder = ChatPromptBuilder(template=template, required_variables="*") + assert builder.variables == ["name"] + assert builder.required_variables == "*" + res = builder.run(name="John") + assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!" diff --git a/test/components/builders/test_prompt_builder.py b/test/components/builders/test_prompt_builder.py index 8a5504c5fc..6e8627de3d 100644 --- a/test/components/builders/test_prompt_builder.py +++ b/test/components/builders/test_prompt_builder.py @@ -337,3 +337,47 @@ def test_warning_no_required_variables(self, caplog): with caplog.at_level(logging.WARNING): _ = PromptBuilder(template="This is a {{ variable }}") assert "but `required_variables` is not set." in caplog.text + + def test_variables_correct_with_assignment(self) -> None: + template = """{% if existing_documents is not none %} +{% set existing_doc_len = existing_documents|length %} +{% else %} +{% set existing_doc_len = 0 %} +{% endif %} +{% for doc in docs %} + +{{ doc.content }} + +{% endfor %} +""" + builder = PromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"docs", "existing_documents"} + assert builder.required_variables == "*" + + def test_variables_correct_with_tuple_assignment(self): + template = """{% if existing_documents is not none -%} +{% set x, y = (existing_documents|length, 1) -%} +{% else -%} +{% set x, y = (0, 1) -%} +{% endif -%} +x={{ x }}, y={{ y }} +""" + builder = PromptBuilder(template=template, required_variables="*") + assert builder.variables == ["existing_documents"] + assert builder.required_variables == "*" + res = builder.run(existing_documents=None) + assert res["prompt"] == "x=0, y=1" + + def test_variables_correct_with_list_assignment(self): + template = """{% if existing_documents is not none -%} +{% set x, y = [existing_documents|length, 1] -%} +{% else -%} +{% set x, y = [0, 1] -%} +{% endif -%} +x={{ x }}, y={{ y }} +""" + builder = PromptBuilder(template=template, required_variables="*") + assert builder.variables == ["existing_documents"] + assert builder.required_variables == "*" + res = builder.run(existing_documents=None) + assert res["prompt"] == "x=0, y=1" diff --git a/test/components/converters/test_output_adapter.py b/test/components/converters/test_output_adapter.py index 2ad2dbb0fd..42e9f04255 100644 --- a/test/components/converters/test_output_adapter.py +++ b/test/components/converters/test_output_adapter.py @@ -3,13 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import List +from typing import Any, List import pytest from haystack import Pipeline, component from haystack.components.converters import OutputAdapter from haystack.components.converters.output_adapter import OutputAdaptationException +from haystack.core.component.sockets import InputSocket from haystack.dataclasses import Document @@ -203,3 +204,16 @@ def test_unsafe(self): ] res = adapter.run(documents=documents) assert res["output"] == documents[0] + + def test_variables_correct_with_assignment(self) -> None: + template = """{% if control == 'something' %} + {% set output = 1 %} +{% else %} + {% set output = 3 %} +{% endif %} +{{ output }} +""" + adapter = OutputAdapter(template=template, output_type=int) + assert adapter.__haystack_input__._sockets_dict == {"control": InputSocket(name="control", type=Any)} + res = adapter.run(control="something") + assert res["output"] == 1 diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 245e451513..17a933d845 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -6,6 +6,7 @@ from unittest import mock import pytest +from jinja2.nativetypes import NativeEnvironment from haystack import Pipeline from haystack.components.routers import ConditionalRouter @@ -636,3 +637,15 @@ def test_sede_multiple_outputs(self): reloaded_router = ConditionalRouter.from_dict(router.to_dict()) assert reloaded_router.custom_filters == router.custom_filters assert reloaded_router.routes == router.routes + + def test_extract_variables_correct_with_assignment(self): + condition = """{%- if control == 'something' -%} +{% set streams = 1 %} +{%- else -%} +{% set streams = 2 %} +{%- endif -%} +{{streams == 1}} +""" + templates = [condition, "{{query}}"] + extracted_variables = ConditionalRouter._extract_variables(env=NativeEnvironment(), templates=templates) + assert extracted_variables == {"control", "query"}