diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 40f7182d..30058ec5 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -2,6 +2,7 @@ GenerateAnswerNode Module """ from typing import List, Optional +from json.decoder import JSONDecodeError from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel @@ -121,9 +122,21 @@ def execute(self, state: dict) -> dict: partial_variables={"context": doc, "format_instructions": format_instructions} ) chain = prompt | self.llm_model + raw_response = str((prompt | self.llm_model).invoke({"question": user_prompt})) + if output_parser: - chain = chain | output_parser - answer = chain.invoke({"question": user_prompt}) + try: + answer = output_parser.parse(raw_response) + except JSONDecodeError: + lines = raw_response.split('\n') + if lines[0].strip().startswith('```'): + lines = lines[1:] + if lines[-1].strip().endswith('```'): + lines = lines[:-1] + cleaned_response = '\n'.join(lines) + answer = output_parser.parse(cleaned_response) + else: + answer = raw_response state.update({self.output[0]: answer}) return state