diff --git a/examples/anthropic/search_graph_schema_haiku.py b/examples/anthropic/search_graph_schema_haiku.py index 0ccafa79..1158d58a 100644 --- a/examples/anthropic/search_graph_schema_haiku.py +++ b/examples/anthropic/search_graph_schema_haiku.py @@ -5,7 +5,7 @@ import os from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SearchGraph load_dotenv() diff --git a/examples/anthropic/smart_scraper_schema_haiku.py b/examples/anthropic/smart_scraper_schema_haiku.py index 0e70aeb5..bd447a06 100644 --- a/examples/anthropic/smart_scraper_schema_haiku.py +++ b/examples/anthropic/smart_scraper_schema_haiku.py @@ -4,7 +4,7 @@ import os from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/azure/search_graph_schema_azure.py b/examples/azure/search_graph_schema_azure.py index ba80b373..629c92ab 100644 --- a/examples/azure/search_graph_schema_azure.py +++ b/examples/azure/search_graph_schema_azure.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/azure/smart_scraper_schema_azure.py b/examples/azure/smart_scraper_schema_azure.py index 5c882c46..d0816bf5 100644 --- a/examples/azure/smart_scraper_schema_azure.py +++ b/examples/azure/smart_scraper_schema_azure.py @@ -5,7 +5,7 @@ import os import json from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv 
from scrapegraphai.graphs import SmartScraperGraph diff --git a/examples/bedrock/search_graph_schema_bedrock.py b/examples/bedrock/search_graph_schema_bedrock.py index ad2cadab..a49ba730 100644 --- a/examples/bedrock/search_graph_schema_bedrock.py +++ b/examples/bedrock/search_graph_schema_bedrock.py @@ -4,7 +4,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/bedrock/smart_scraper_schema_bedrock.py b/examples/bedrock/smart_scraper_schema_bedrock.py index 02f83029..2829efec 100644 --- a/examples/bedrock/smart_scraper_schema_bedrock.py +++ b/examples/bedrock/smart_scraper_schema_bedrock.py @@ -2,7 +2,7 @@ Basic example of scraping pipeline using SmartScraper """ from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/deepseek/search_graph_schema_deepseek.py b/examples/deepseek/search_graph_schema_deepseek.py index 9966602d..1471ede1 100644 --- a/examples/deepseek/search_graph_schema_deepseek.py +++ b/examples/deepseek/search_graph_schema_deepseek.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/deepseek/smart_scraper_schema_deepseek.py b/examples/deepseek/smart_scraper_schema_deepseek.py index 6d924b71..722e02bf 100644 --- a/examples/deepseek/smart_scraper_schema_deepseek.py +++ 
b/examples/deepseek/smart_scraper_schema_deepseek.py @@ -4,7 +4,7 @@ import os from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/extras/serch_graph_scehma.py b/examples/extras/serch_graph_scehma.py index f4135d19..66c47a33 100644 --- a/examples/extras/serch_graph_scehma.py +++ b/examples/extras/serch_graph_scehma.py @@ -5,7 +5,7 @@ import os from dotenv import load_dotenv from scrapegraphai.graphs import SearchGraph -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List load_dotenv() diff --git a/examples/fireworks/pdf_scraper_multi_fireworks.py b/examples/fireworks/pdf_scraper_multi_fireworks.py index bbf3808a..c1077061 100644 --- a/examples/fireworks/pdf_scraper_multi_fireworks.py +++ b/examples/fireworks/pdf_scraper_multi_fireworks.py @@ -5,7 +5,7 @@ import json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import PdfScraperMultiGraph load_dotenv() diff --git a/examples/fireworks/script_generator_schema_fireworks.py b/examples/fireworks/script_generator_schema_fireworks.py index f6f90ddf..6355a4e8 100644 --- a/examples/fireworks/script_generator_schema_fireworks.py +++ b/examples/fireworks/script_generator_schema_fireworks.py @@ -5,7 +5,7 @@ import os from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import ScriptCreatorGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/fireworks/search_graph_schema_fireworks.py b/examples/fireworks/search_graph_schema_fireworks.py index a16d3ae2..d88d991e 100644 
--- a/examples/fireworks/search_graph_schema_fireworks.py +++ b/examples/fireworks/search_graph_schema_fireworks.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/fireworks/smart_scraper_schema_fireworks.py b/examples/fireworks/smart_scraper_schema_fireworks.py index cefb4d7d..d71593f3 100644 --- a/examples/fireworks/smart_scraper_schema_fireworks.py +++ b/examples/fireworks/smart_scraper_schema_fireworks.py @@ -5,7 +5,7 @@ import os, json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph load_dotenv() diff --git a/examples/google_genai/search_graph_schema_gemini.py b/examples/google_genai/search_graph_schema_gemini.py index 3c44dd2d..e4b7983d 100644 --- a/examples/google_genai/search_graph_schema_gemini.py +++ b/examples/google_genai/search_graph_schema_gemini.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/google_genai/smart_scraper_schema_gemini.py b/examples/google_genai/smart_scraper_schema_gemini.py index 8096c466..6c817e20 100644 --- a/examples/google_genai/smart_scraper_schema_gemini.py +++ b/examples/google_genai/smart_scraper_schema_gemini.py @@ -4,7 +4,7 @@ import os from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv from 
scrapegraphai.utils import prettify_exec_info from scrapegraphai.graphs import SmartScraperGraph diff --git a/examples/google_vertexai/search_graph_schema_gemini.py b/examples/google_vertexai/search_graph_schema_gemini.py index 7e73c584..54586c7e 100644 --- a/examples/google_vertexai/search_graph_schema_gemini.py +++ b/examples/google_vertexai/search_graph_schema_gemini.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/google_vertexai/smart_scraper_schema_gemini.py b/examples/google_vertexai/smart_scraper_schema_gemini.py index 18e0ef5b..541ce9aa 100644 --- a/examples/google_vertexai/smart_scraper_schema_gemini.py +++ b/examples/google_vertexai/smart_scraper_schema_gemini.py @@ -4,7 +4,7 @@ import os from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv from scrapegraphai.utils import prettify_exec_info from scrapegraphai.graphs import SmartScraperGraph diff --git a/examples/groq/search_graph_schema_groq.py b/examples/groq/search_graph_schema_groq.py index d5253f2c..4cc2209d 100644 --- a/examples/groq/search_graph_schema_groq.py +++ b/examples/groq/search_graph_schema_groq.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/groq/smart_scraper_schema_groq.py b/examples/groq/smart_scraper_schema_groq.py index b2b377b3..f9c1a40b 100644 --- a/examples/groq/smart_scraper_schema_groq.py +++ 
b/examples/groq/smart_scraper_schema_groq.py @@ -4,7 +4,7 @@ import os, json from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from dotenv import load_dotenv from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/integrations/indexify_node_example.py b/examples/integrations/indexify_node_example.py index fae2403a..61db52d2 100644 --- a/examples/integrations/indexify_node_example.py +++ b/examples/integrations/indexify_node_example.py @@ -8,7 +8,7 @@ from dotenv import load_dotenv load_dotenv() -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.integrations import IndexifyNode diff --git a/examples/local_models/search_graph_schema_ollama.py b/examples/local_models/search_graph_schema_ollama.py index 6720383c..fb87954f 100644 --- a/examples/local_models/search_graph_schema_ollama.py +++ b/examples/local_models/search_graph_schema_ollama.py @@ -4,7 +4,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/local_models/smart_scraper_schema_ollama.py b/examples/local_models/smart_scraper_schema_ollama.py index 5f15b080..35503bd7 100644 --- a/examples/local_models/smart_scraper_schema_ollama.py +++ b/examples/local_models/smart_scraper_schema_ollama.py @@ -3,7 +3,7 @@ """ import json from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git 
a/examples/mistral/pdf_scraper_multi_mistral.py b/examples/mistral/pdf_scraper_multi_mistral.py index 18b8a1f0..e9f1613f 100644 --- a/examples/mistral/pdf_scraper_multi_mistral.py +++ b/examples/mistral/pdf_scraper_multi_mistral.py @@ -5,7 +5,7 @@ import json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import PdfScraperMultiGraph load_dotenv() diff --git a/examples/mistral/script_generator_schema_mistral.py b/examples/mistral/script_generator_schema_mistral.py index beaca4c1..3ad46685 100644 --- a/examples/mistral/script_generator_schema_mistral.py +++ b/examples/mistral/script_generator_schema_mistral.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from scrapegraphai.graphs import ScriptCreatorGraph from scrapegraphai.utils import prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List load_dotenv() diff --git a/examples/mistral/search_graph_schema_mistral.py b/examples/mistral/search_graph_schema_mistral.py index d804d984..7c71c0b1 100644 --- a/examples/mistral/search_graph_schema_mistral.py +++ b/examples/mistral/search_graph_schema_mistral.py @@ -5,7 +5,7 @@ import os from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info diff --git a/examples/mistral/smart_scraper_schema_mistral.py b/examples/mistral/smart_scraper_schema_mistral.py index 517bd743..3e1e505a 100644 --- a/examples/mistral/smart_scraper_schema_mistral.py +++ b/examples/mistral/smart_scraper_schema_mistral.py @@ -5,7 +5,7 @@ import os, json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from 
pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph load_dotenv() diff --git a/examples/nemotron/script_generator_schema_nemotron.py b/examples/nemotron/script_generator_schema_nemotron.py index 09d0d682..3f0713a4 100644 --- a/examples/nemotron/script_generator_schema_nemotron.py +++ b/examples/nemotron/script_generator_schema_nemotron.py @@ -7,7 +7,7 @@ from scrapegraphai.graphs import ScriptCreatorGraph from scrapegraphai.utils import prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List load_dotenv() diff --git a/examples/nemotron/search_graph_schema_nemotron.py b/examples/nemotron/search_graph_schema_nemotron.py index 84150f53..eec72daf 100644 --- a/examples/nemotron/search_graph_schema_nemotron.py +++ b/examples/nemotron/search_graph_schema_nemotron.py @@ -9,7 +9,7 @@ from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/nemotron/smart_scraper_schema_nemotron.py b/examples/nemotron/smart_scraper_schema_nemotron.py index 11ce4de2..e1462e85 100644 --- a/examples/nemotron/smart_scraper_schema_nemotron.py +++ b/examples/nemotron/smart_scraper_schema_nemotron.py @@ -5,7 +5,7 @@ import os, json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph load_dotenv() diff --git a/examples/oneapi/search_graph_schema_oneapi.py b/examples/oneapi/search_graph_schema_oneapi.py index d99bca00..7fc44539 100644 --- a/examples/oneapi/search_graph_schema_oneapi.py +++ b/examples/oneapi/search_graph_schema_oneapi.py @@ -4,7 +4,7 @@ from scrapegraphai.graphs 
import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List # ************************************************ diff --git a/examples/oneapi/smart_scraper_schema_oneapi.py b/examples/oneapi/smart_scraper_schema_oneapi.py index b12bbf66..0c011bb6 100644 --- a/examples/oneapi/smart_scraper_schema_oneapi.py +++ b/examples/oneapi/smart_scraper_schema_oneapi.py @@ -2,7 +2,7 @@ Basic example of scraping pipeline using SmartScraper and OneAPI """ from typing import List -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/openai/pdf_scraper_multi_openai.py b/examples/openai/pdf_scraper_multi_openai.py index a405da43..91e219e3 100644 --- a/examples/openai/pdf_scraper_multi_openai.py +++ b/examples/openai/pdf_scraper_multi_openai.py @@ -5,7 +5,7 @@ import json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import PdfScraperMultiGraph load_dotenv() diff --git a/examples/openai/script_generator_schema_openai.py b/examples/openai/script_generator_schema_openai.py index e8e7d111..5e542c53 100644 --- a/examples/openai/script_generator_schema_openai.py +++ b/examples/openai/script_generator_schema_openai.py @@ -7,7 +7,7 @@ from scrapegraphai.graphs import ScriptCreatorGraph from scrapegraphai.utils import prettify_exec_info -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from typing import List load_dotenv() diff --git a/examples/openai/search_graph_schema_openai.py b/examples/openai/search_graph_schema_openai.py index c0440955..571f08b0 100644 --- 
a/examples/openai/search_graph_schema_openai.py +++ b/examples/openai/search_graph_schema_openai.py @@ -5,7 +5,7 @@ import os from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SearchGraph from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info diff --git a/examples/openai/smart_scraper_schema_openai.py b/examples/openai/smart_scraper_schema_openai.py index f1ce3d14..0c1618d6 100644 --- a/examples/openai/smart_scraper_schema_openai.py +++ b/examples/openai/smart_scraper_schema_openai.py @@ -5,7 +5,7 @@ import os, json from typing import List from dotenv import load_dotenv -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from scrapegraphai.graphs import SmartScraperGraph load_dotenv() diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index 160be9ce..85593cfa 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -3,17 +3,14 @@ """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm -from ..utils.logging import get_logger from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV class GenerateAnswerCSVNode(BaseNode): 
@@ -101,14 +98,10 @@ def execute(self, state): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 6e32d1b5..b0c102e1 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -2,17 +2,15 @@ GenerateAnswerNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_mistralai import ChatMistralAI from langchain_community.chat_models import ChatOllama from tqdm import tqdm from .base_node import BaseNode -from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts import (TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, @@ -95,15 +93,11 @@ def execute(self, state: dict) -> dict: if isinstance(self.llm_model, (ChatOpenAI, 
ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index 871a26e5..2824a573 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -2,17 +2,15 @@ GenerateAnswerNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI, TEMPLATE_CHUNKS_OMNI, TEMPLATE_MERGE_OMNI) @@ -90,14 +88,10 @@ def execute(self, state: dict) -> dict: self.llm_model = self.llm_model.with_structured_output( 
schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index 4832d45e..544184b4 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -2,18 +2,15 @@ Module for generating the answer node """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama -from ..utils.logging import get_logger from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF, TEMPLATE_NO_CHUNKS_PDF, TEMPLATE_MERGE_PDF) @@ -102,14 +99,10 @@ def execute(self, state): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema 
works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py index 8781cf2d..e38461f1 100644 --- a/scrapegraphai/nodes/graph_iterator_node.py +++ b/scrapegraphai/nodes/graph_iterator_node.py @@ -5,7 +5,7 @@ from typing import List, Optional from tqdm.asyncio import tqdm from .base_node import BaseNode -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel DEFAULT_BATCHSIZE = 16 diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index 64e1f149..9f9a356c 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -2,16 +2,13 @@ MergeAnswersNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI -from ..utils.logging import get_logger from .base_node import BaseNode from ..prompts import TEMPLATE_COMBINED -from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser class 
MergeAnswersNode(BaseNode): """ @@ -78,14 +75,10 @@ def execute(self, state: dict) -> dict: self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/utils/llm_output_parser.py b/scrapegraphai/utils/llm_output_parser.py deleted file mode 100644 index e6ac6e2d..00000000 --- a/scrapegraphai/utils/llm_output_parser.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Custom output parser for the LLM model. -""" -from pydantic import BaseModel as BaseModelV2 -from pydantic.v1 import BaseModel as BaseModelV1 - -def base_model_v1_output_parser(x: BaseModelV1) -> dict: - """ - Parse the output of an LLM when the schema is a BaseModelv1 and `with_structured_output` is used. - - Args: - x (BaseModelV2 | BaseModelV1): The output from the LLM model. - - Returns: - dict: The parsed output. - """ - work_dict = x.dict() - - # recursive dict parser - def recursive_dict_parser(work_dict: dict) -> dict: - dict_keys = work_dict.keys() - for key in dict_keys: - if isinstance(work_dict[key], BaseModelV1): - work_dict[key] = work_dict[key].dict() - recursive_dict_parser(work_dict[key]) - return work_dict - - return recursive_dict_parser(work_dict) - - -def base_model_v2_output_parser(x: BaseModelV2) -> dict: - """ - Parse the output of an LLM when the schema is a BaseModelv2 and `with_structured_output` is used. 
-
-    Args:
-        x (BaseModelV2): The output from the LLM model.
-
-    Returns:
-        dict: The parsed output.
-    """
-    return x.model_dump()
-
-def typed_dict_output_parser(x: dict) -> dict:
-    """
-    Parse the output of an LLM when the schema is a TypedDict and `with_structured_output` is used.
-
-    Args:
-        x (dict): The output from the LLM model.
-
-    Returns:
-        dict: The parsed output.
-    """
-    return x
diff --git a/scrapegraphai/utils/output_parser.py b/scrapegraphai/utils/output_parser.py
new file mode 100644
index 00000000..39ae092e
--- /dev/null
+++ b/scrapegraphai/utils/output_parser.py
@@ -0,0 +1,103 @@
+"""
+Functions to retrieve the correct output parser and format instructions for the LLM model.
+"""
+from typing import Any, Callable, Dict, Type, Union
+
+from langchain_core.output_parsers import JsonOutputParser
+from pydantic import BaseModel as BaseModelV2
+from pydantic.v1 import BaseModel as BaseModelV1
+
+
+def get_structured_output_parser(
+    schema: Union[Dict[str, Any], Type[Union[BaseModelV1, BaseModelV2]], Type],
+) -> Callable:
+    """
+    Get the correct output parser for the LLM model.
+
+    Args:
+        schema: The schema given to `with_structured_output`: a pydantic v1 or
+            v2 model class, a TypedDict class, or a JSON-schema dict.
+
+    Returns:
+        Callable: The output parser function.
+    """
+    # issubclass() raises TypeError for non-class arguments, so a JSON-schema
+    # dict must be filtered out before the pydantic checks.
+    if isinstance(schema, type):
+        if issubclass(schema, BaseModelV1):
+            return _base_model_v1_output_parser
+
+        if issubclass(schema, BaseModelV2):
+            return _base_model_v2_output_parser
+
+    return _dict_output_parser
+
+
+def get_pydantic_output_parser(
+    schema: Union[Dict[str, Any], Type[Union[BaseModelV1, BaseModelV2]], Type],
+) -> JsonOutputParser:
+    """
+    Get the correct output parser for the LLM model.
+
+    Args:
+        schema: The schema for the expected output; only pydantic v2 model
+            classes are accepted.
+
+    Returns:
+        JsonOutputParser: The output parser object.
+
+    Raises:
+        ValueError: If the schema is a pydantic v1 model class or is not a
+            pydantic model class at all (e.g. a dict or a TypedDict).
+    """
+    # issubclass() raises TypeError for non-class arguments; guard it so that
+    # such schemas reach the ValueError below instead.
+    if isinstance(schema, type):
+        if issubclass(schema, BaseModelV1):
+            raise ValueError("pydantic.v1 and langchain_core.pydantic_v1 are not supported with this LLM model. Please use pydantic v2 instead.")
+
+        if issubclass(schema, BaseModelV2):
+            return JsonOutputParser(pydantic_object=schema)
+
+    raise ValueError("The schema is not a pydantic subclass. With this LLM model you must use a pydantic schemas.")
+
+
+def _base_model_v1_output_parser(x: BaseModelV1) -> dict:
+    """
+    Parse the output of an LLM when the schema is BaseModelv1.
+
+    Args:
+        x (BaseModelV1): The output from the LLM model.
+
+    Returns:
+        dict: The parsed output.
+    """
+    # pydantic v1's dict() already converts nested sub-models recursively,
+    # so no additional traversal is needed.
+    return x.dict()
+
+
+def _base_model_v2_output_parser(x: BaseModelV2) -> dict:
+    """
+    Parse the output of an LLM when the schema is BaseModelv2.
+
+    Args:
+        x (BaseModelV2): The output from the LLM model.
+
+    Returns:
+        dict: The parsed output.
+    """
+    return x.model_dump()
+
+
+def _dict_output_parser(x: dict) -> dict:
+    """
+    Parse the output of an LLM when the schema is TypedDict or JsonSchema.
+
+    Args:
+        x (dict): The output from the LLM model.
+
+    Returns:
+        dict: The parsed output.
+    """
+    return x