diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py
index 0b9f5517..6d160e37 100644
--- a/scrapegraphai/graphs/base_graph.py
+++ b/scrapegraphai/graphs/base_graph.py
@@ -5,7 +5,7 @@
 import warnings
 from typing import Tuple
 from ..telemetry import log_graph_execution
-from ..utils import CustomOpenAiCallbackManager
+from ..utils import CustomLLMCallbackManager
 
 class BaseGraph:
     """
@@ -52,7 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
         self.entry_point = entry_point.node_name
         self.graph_name = graph_name
         self.initial_state = {}
-        self.callback_manager = CustomOpenAiCallbackManager()
+        self.callback_manager = CustomLLMCallbackManager()
 
         if nodes[0].node_name != entry_point.node_name:
             # raise a warning if the entry point is not the first node in the list
@@ -108,6 +108,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
         error_node = None
         source_type = None
         llm_model = None
+        llm_model_name = None
         embedder_model = None
         source = []
         prompt = None
@@ -135,9 +136,11 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             if hasattr(current_node, "llm_model") and llm_model is None:
                 llm_model = current_node.llm_model
                 if hasattr(llm_model, "model_name"):
-                    llm_model = llm_model.model_name
+                    llm_model_name = llm_model.model_name
                 elif hasattr(llm_model, "model"):
-                    llm_model = llm_model.model
+                    llm_model_name = llm_model.model
+                elif hasattr(llm_model, "model_id"):
+                    llm_model_name = llm_model.model_id
 
             if hasattr(current_node, "embedder_model") and embedder_model is None:
                 embedder_model = current_node.embedder_model
@@ -155,7 +158,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                         except Exception as e:
                             schema = None
 
-            with self.callback_manager.exclusive_get_openai_callback() as cb:
+            with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
                 try:
                     result = current_node.execute(state)
                 except Exception as e:
@@ -166,7 +169,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                         source=source,
                         prompt=prompt,
                         schema=schema,
-                        llm_model=llm_model,
+                        llm_model=llm_model_name,
                         embedder_model=embedder_model,
                         source_type=source_type,
                         execution_time=graph_execution_time,
@@ -222,7 +225,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             source=source,
             prompt=prompt,
             schema=schema,
-            llm_model=llm_model,
+            llm_model=llm_model_name,
             embedder_model=embedder_model,
             source_type=source_type,
             content=content,
diff --git a/scrapegraphai/utils/__init__.py b/scrapegraphai/utils/__init__.py
index 0132c775..ecfa690f 100644
--- a/scrapegraphai/utils/__init__.py
+++ b/scrapegraphai/utils/__init__.py
@@ -17,4 +17,4 @@
 from .screenshot_scraping.text_detection import detect_text
 from .tokenizer import num_tokens_calculus
 from .split_text_into_chunks import split_text_into_chunks
-from .custom_openai_callback import CustomOpenAiCallbackManager
+from .llm_callback_manager import CustomLLMCallbackManager
diff --git a/scrapegraphai/utils/custom_callback.py b/scrapegraphai/utils/custom_callback.py
new file mode 100644
index 00000000..a3992a5b
--- /dev/null
+++ b/scrapegraphai/utils/custom_callback.py
@@ -0,0 +1,157 @@
+"""
+Custom callback for LLM token usage statistics.
+
+This module has been taken and modified from the OpenAI callback manager in langchain-community.
+https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py
+"""
+from contextlib import contextmanager
+import threading
+from typing import Any, Dict, List, Optional
+from contextvars import ContextVar
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.tracers.context import register_configure_hook
+
+from .model_costs import MODEL_COST_PER_1K_TOKENS_INPUT, MODEL_COST_PER_1K_TOKENS_OUTPUT
+
+
+def get_token_cost_for_model(
+    model_name: str, num_tokens: int, is_completion: bool = False
+) -> float:
+    """
+    Get the cost in USD for a given model and number of tokens.
+
+    Args:
+        model_name: Name of the model
+        num_tokens: Number of tokens.
+        is_completion: Whether the model is used for completion or not.
+            Defaults to False.
+
+    Returns:
+        Cost in USD.
+    """
+    if model_name not in MODEL_COST_PER_1K_TOKENS_INPUT:
+        return 0.0
+    if is_completion:
+        return MODEL_COST_PER_1K_TOKENS_OUTPUT[model_name] * (num_tokens / 1000)
+
+    return MODEL_COST_PER_1K_TOKENS_INPUT[model_name] * (num_tokens / 1000)
+
+
+class CustomCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks LLMs info."""
+
+    total_tokens: int = 0
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    successful_requests: int = 0
+    total_cost: float = 0.0
+
+    def __init__(self, llm_model_name: str) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+        self.model_name = llm_model_name if llm_model_name else "unknown"
+
+    def __repr__(self) -> str:
+        return (
+            f"Tokens Used: {self.total_tokens}\n"
+            f"\tPrompt Tokens: {self.prompt_tokens}\n"
+            f"\tCompletion Tokens: {self.completion_tokens}\n"
+            f"Successful Requests: {self.successful_requests}\n"
+            f"Total Cost (USD): ${self.total_cost}"
+        )
+
+    @property
+    def always_verbose(self) -> bool:
+        """Whether to call verbose callbacks even if verbose is False."""
+        return True
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Print out the token."""
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Collect token usage."""
+        # Check for usage_metadata (langchain-core >= 0.2.2)
+        try:
+            generation = response.generations[0][0]
+        except IndexError:
+            generation = None
+        if isinstance(generation, ChatGeneration):
+            try:
+                message = generation.message
+                if isinstance(message, AIMessage):
+                    usage_metadata = message.usage_metadata
+                else:
+                    usage_metadata = None
+            except AttributeError:
+                usage_metadata = None
+        else:
+            usage_metadata = None
+        if usage_metadata:
+            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
+            completion_tokens = usage_metadata["output_tokens"]
+            prompt_tokens = usage_metadata["input_tokens"]
+
+
+        else:
+            if response.llm_output is None:
+                return None
+
+            if "token_usage" not in response.llm_output:
+                with self._lock:
+                    self.successful_requests += 1
+                return None
+
+            # compute tokens and cost for this request
+            token_usage = response.llm_output["token_usage"]
+            completion_tokens = token_usage.get("completion_tokens", 0)
+            prompt_tokens = token_usage.get("prompt_tokens", 0)
+        if self.model_name in MODEL_COST_PER_1K_TOKENS_INPUT:
+            completion_cost = get_token_cost_for_model(
+                self.model_name, completion_tokens, is_completion=True
+            )
+            prompt_cost = get_token_cost_for_model(self.model_name, prompt_tokens)
+        else:
+            completion_cost = 0
+            prompt_cost = 0
+
+        # update shared state behind lock
+        with self._lock:
+            self.total_cost += prompt_cost + completion_cost
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
+
+    def __copy__(self) -> "CustomCallbackHandler":
+        """Return a copy of the callback handler."""
+        return self
+
+    def __deepcopy__(self, memo: Any) -> "CustomCallbackHandler":
+        """Return a deep copy of the callback handler."""
+        return self
+
+
+custom_callback: ContextVar[Optional[CustomCallbackHandler]] = ContextVar(
+    "custom_callback", default=None
+)
+register_configure_hook(custom_callback, True)
+
+@contextmanager
+def get_custom_callback(llm_model_name: str):
+    """
+    Function to get custom callback for LLM token usage statistics.
+    """
+    cb = CustomCallbackHandler(llm_model_name)
+    custom_callback.set(cb)
+    yield cb
+    custom_callback.set(None)
\ No newline at end of file
diff --git a/scrapegraphai/utils/custom_openai_callback.py b/scrapegraphai/utils/custom_openai_callback.py
deleted file mode 100644
index e0efa723..00000000
--- a/scrapegraphai/utils/custom_openai_callback.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import threading
-from contextlib import contextmanager
-from langchain_community.callbacks import get_openai_callback
-
-class CustomOpenAiCallbackManager:
-    _lock = threading.Lock()
-
-    @contextmanager
-    def exclusive_get_openai_callback(self):
-        if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
-            try:
-                with get_openai_callback() as cb:
-                    yield cb
-            finally:
-                CustomOpenAiCallbackManager._lock.release()
-        else:
-            yield None
\ No newline at end of file
diff --git a/scrapegraphai/utils/llm_callback_manager.py b/scrapegraphai/utils/llm_callback_manager.py
new file mode 100644
index 00000000..a6b9c893
--- /dev/null
+++ b/scrapegraphai/utils/llm_callback_manager.py
@@ -0,0 +1,38 @@
+"""
+This module provides a custom callback manager for the LLM models.
+"""
+import threading
+from contextlib import contextmanager
+from .custom_callback import get_custom_callback
+
+from langchain_community.callbacks import get_openai_callback
+from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_aws import ChatBedrock
+
+class CustomLLMCallbackManager:
+    _lock = threading.Lock()
+
+    @contextmanager
+    def exclusive_get_callback(self, llm_model, llm_model_name):
+        if CustomLLMCallbackManager._lock.acquire(blocking=False):
+            if isinstance(llm_model, ChatOpenAI) or isinstance(llm_model, AzureChatOpenAI):
+                try:
+                    with get_openai_callback() as cb:
+                        yield cb
+                finally:
+                    CustomLLMCallbackManager._lock.release()
+            elif isinstance(llm_model, ChatBedrock) and llm_model_name is not None and "claude" in llm_model_name:
+                try:
+                    with get_bedrock_anthropic_callback() as cb:
+                        yield cb
+                finally:
+                    CustomLLMCallbackManager._lock.release()
+            else:
+                try:
+                    with get_custom_callback(llm_model_name) as cb:
+                        yield cb
+                finally:
+                    CustomLLMCallbackManager._lock.release()
+        else:
+            yield None
\ No newline at end of file
diff --git a/scrapegraphai/utils/model_costs.py b/scrapegraphai/utils/model_costs.py
new file mode 100644
index 00000000..a34ee9cd
--- /dev/null
+++ b/scrapegraphai/utils/model_costs.py
@@ -0,0 +1,105 @@
+"""
+This file contains the cost of models per 1k tokens for input and output.
+The file is on a best effort basis and may not be up to date. Any contributions are welcome.
+"""
+MODEL_COST_PER_1K_TOKENS_INPUT = {
+    ### MistralAI
+    # General Purpose
+    "open-mistral-nemo": 0.00015,
+    "open-mistral-nemo-2407": 0.00015,
+    "mistral-large": 0.002,
+    "mistral-large-2407": 0.002,
+    "mistral-small": 0.0002,
+    "mistral-small-2409": 0.0002,
+    # Specialist Models
+    "codestral": 0.0002,
+    "codestral-2405": 0.0002,
+    "pixtral-12b": 0.00015,
+    "pixtral-12b-2409": 0.00015,
+    # Legacy Models
+    "open-mistral-7b": 0.00025,
+    "open-mixtral-8x7b": 0.0007,
+    "open-mixtral-8x22b": 0.002,
+    "mistral-small-latest": 0.001,
+    "mistral-medium-latest": 0.00275,
+
+    ### Bedrock - not Claude
+    #AI21 Labs
+    "a121.ju-ultra-v1": 0.0188,
+    "a121.ju-mid-v1": 0.0125,
+    "ai21.jamba-instruct-v1:0": 0.0005,
+    # Meta - LLama
+    "meta.llama2-13b-chat-v1": 0.00075,
+    "meta.llama2-70b-chat-v1": 0.00195,
+    "meta.llama3-8b-instruct-v1:0": 0.0003,
+    "meta.llama3-70b-instruct-v1:0": 0.00265,
+    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
+    "meta.llama3-1-70b-instruct-v1:0": 0.00099,
+    "meta.llama3-1-405b-instruct-v1:0": 0.00532,
+    # Cohere - Command
+    "cohere.command-text-v14": 0.0015,
+    "cohere.command-light-text-v14": 0.0003,
+    "cohere.command-r-v1:0": 0.0005,
+    "cohere.command-r-plus-v1:0": 0.003,
+    # Mistral
+    "mistral.mistral-7b-instruct-v0:2": 0.00015,
+    "mistral.mistral-large-2402-v1:0": 0.004,
+    "mistral.mistral-large-2407-v1:0": 0.002,
+    "mistral.mistral-small-2402-v1:0": 0.001,
+    "mistral.mixtral-7x8b-instruct-v0:1": 0.00045,
+    # Amazon - Titan
+    "amazon.titan-text-express-v1": 0.0002,
+    "amazon.titan-text-lite-v1": 0.00015,
+    "amazon.titan-text-premier-v1:0": 0.0005,
+}
+
+MODEL_COST_PER_1K_TOKENS_OUTPUT = {
+    ### MistralAI
+    # General Purpose
+    "open-mistral-nemo": 0.00015,
+    "open-mistral-nemo-2407": 0.00015,
+    "mistral-large": 0.002,
+    "mistral-large-2407": 0.006,
+    "mistral-small": 0.0002,
+    "mistral-small-2409": 0.0006,
+    # Specialist Models
+    "codestral": 0.0006,
+    "codestral-2405": 0.0006,
+    "pixtral-12b": 0.00015,
+    "pixtral-12b-2409": 0.0006,
+    # Legacy Models
+    "open-mistral-7b": 0.00025,
+    "open-mixtral-8x7b": 0.0007,
+    "open-mixtral-8x22b": 0.006,
+    "mistral-small-latest": 0.003,
+    "mistral-medium-latest": 0.0081,
+
+    ### Bedrock - not Claude
+    # AI21 Labs
+    "a121.ju-ultra-v1": 0.0188,
+    "a121.ju-mid-v1": 0.0125,
+    "ai21.jamba-instruct-v1:0": 0.0007,
+    # Meta - LLama
+    "meta.llama2-13b-chat-v1": 0.001,
+    "meta.llama2-70b-chat-v1": 0.00256,
+    "meta.llama3-8b-instruct-v1:0": 0.0006,
+    "meta.llama3-70b-instruct-v1:0": 0.0035,
+    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
+    "meta.llama3-1-70b-instruct-v1:0": 0.00099,
+    "meta.llama3-1-405b-instruct-v1:0": 0.016,
+    # Cohere - Command
+    "cohere.command-text-v14": 0.002,
+    "cohere.command-light-text-v14": 0.0006,
+    "cohere.command-r-v1:0": 0.0015,
+    "cohere.command-r-plus-v1:0": 0.015,
+    # Mistral
+    "mistral.mistral-7b-instruct-v0:2": 0.0002,
+    "mistral.mistral-large-2402-v1:0": 0.012,
+    "mistral.mistral-large-2407-v1:0": 0.006,
+    "mistral.mistral-small-2402-v1:0": 0.003,
+    "mistral.mixtral-7x8b-instruct-v0:1": 0.0007,
+    # Amazon - Titan
+    "amazon.titan-text-express-v1": 0.0006,
+    "amazon.titan-text-lite-v1": 0.0002,
+    "amazon.titan-text-premier-v1:0": 0.0015,
+}
\ No newline at end of file
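
Usage sketch (illustrative, not part of the patch): the call site added in base_graph.py suggests the new manager is driven roughly as below. The ChatOpenAI model, the "gpt-4o-mini" model id, and the prompt are hypothetical stand-ins; per the new llm_callback_manager.py, anything that is not an OpenAI/Azure chat model or a Claude model on Bedrock falls through to get_custom_callback and the model_costs.py tables.

    # Illustrative only: exercises CustomLLMCallbackManager outside the graph.
    # Assumes an OpenAI-backed model; "gpt-4o-mini" is a stand-in choice.
    from langchain_openai import ChatOpenAI
    from scrapegraphai.utils import CustomLLMCallbackManager

    llm_model = ChatOpenAI(model="gpt-4o-mini")
    llm_model_name = getattr(llm_model, "model_name", None)

    callback_manager = CustomLLMCallbackManager()

    # exclusive_get_callback selects get_openai_callback, get_bedrock_anthropic_callback,
    # or the new get_custom_callback based on the model, and yields None when another
    # thread already holds the lock.
    with callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
        response = llm_model.invoke("Say hello in one word.")

    if cb is not None:
        print(cb)  # tokens used, prompt/completion split, and estimated cost in USD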