diff --git a/.gitignore b/.gitignore index 47d38ba..697cac9 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ log/* logs/ parts/* json_results/* +pdfs/ +results/ \ No newline at end of file diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 882fb5d..29ccd58 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -4,6 +4,7 @@ import math import random import re +import logging from .utils import * import os from concurrent.futures import ThreadPoolExecutor, as_completed @@ -322,10 +323,70 @@ def toc_transformer(toc_content, model=None): if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) - last_complete = json.loads(last_complete) - - cleaned_response=convert_page_to_int(last_complete['table_of_contents']) - return cleaned_response + # Use extract_json instead of json.loads for better error handling + try: + # First, try to get the JSON content properly + if isinstance(last_complete, str): + # Clean the response to extract just the JSON part + last_complete = last_complete.strip() + + # Debug: log what we're trying to parse + logging.info(f"Attempting to parse JSON: {last_complete[:500]}...") # Log first 500 chars + + # Try using extract_json first + last_complete_json = extract_json(last_complete) + + if isinstance(last_complete_json, dict) and 'table_of_contents' in last_complete_json: + cleaned_response = convert_page_to_int(last_complete_json['table_of_contents']) + return cleaned_response + else: + logging.warning(f"extract_json returned unexpected format: {type(last_complete_json)}") + logging.warning(f"Keys available: {list(last_complete_json.keys()) if isinstance(last_complete_json, dict) else 'Not a dict'}") + + # Fallback: try direct JSON parsing + last_complete_json = json.loads(last_complete) + if 'table_of_contents' in last_complete_json: + cleaned_response = convert_page_to_int(last_complete_json['table_of_contents']) + return cleaned_response + else: + logging.error(f"JSON 
doesn't contain 'table_of_contents' key. Available keys: {list(last_complete_json.keys())}") + + except json.JSONDecodeError as e: + logging.error(f"JSON parsing error in toc_transformer: {e}") + logging.error(f"Problematic JSON content (first 1000 chars): {last_complete[:1000]}") + + except Exception as e: + logging.error(f"Unexpected error in toc_transformer: {e}") + logging.error(f"Content type: {type(last_complete)}") + + # Final fallback: try to extract JSON manually + try: + # Find the start and end of JSON structure + start_idx = last_complete.find('{') + if start_idx != -1: + # Find the matching closing brace + brace_count = 0 + for i, char in enumerate(last_complete[start_idx:], start_idx): + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0: + clean_json = last_complete[start_idx:i+1] + logging.info(f"Manually extracted JSON: {clean_json}") + last_complete_json = json.loads(clean_json) + if 'table_of_contents' in last_complete_json: + cleaned_response = convert_page_to_int(last_complete_json['table_of_contents']) + return cleaned_response + else: + logging.error(f"Manually extracted JSON doesn't contain 'table_of_contents'. 
Keys: {list(last_complete_json.keys())}") + + logging.error("Could not extract valid JSON with table_of_contents") + return [] + + except Exception as fallback_error: + logging.error(f"Fallback JSON parsing also failed: {fallback_error}") + return [] @@ -1058,6 +1119,10 @@ async def tree_parser(page_list, opt, doc=None, logger=None): def page_index_main(doc, opt=None): logger = JsonLogger(doc) + # Set up cost tracking with logger + from .utils import set_global_logger + set_global_logger(logger) + is_valid_pdf = ( (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or isinstance(doc, BytesIO) @@ -1097,7 +1162,13 @@ async def page_index_builder(): 'structure': structure, } - return asyncio.run(page_index_builder()) + result = asyncio.run(page_index_builder()) + + # Log final cost summary to file + from .utils import log_final_cost_summary + log_final_cost_summary() + + return result def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, diff --git a/pageindex/utils.py b/pageindex/utils.py index dc7acd8..273f140 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -1,5 +1,6 @@ import tiktoken import openai +from openai import AzureOpenAI import logging import os from datetime import datetime @@ -17,7 +18,132 @@ from pathlib import Path from types import SimpleNamespace as config -CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +# Load Azure OpenAI-specific configurations +OPENAI_API_TYPE = os.getenv("OPENAI_API_TYPE", "azure") +OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") +OPENAI_API_VERSION = os.getenv("OPENAI_API_VERSION") +OPENAI_API_ENGINE = os.getenv("OPENAI_API_ENGINE") + +# Note: Using modern AzureOpenAI client instead of global configuration + +# Global cost tracking +total_cost = 0.0 +call_count = 0 +global_logger = None # Add global logger reference + +# Azure OpenAI pricing per 1K tokens (update these based on 
your specific pricing) +# These are example rates - adjust based on your actual Azure OpenAI pricing +PRICING = { + "gpt-4": { + "input": 0.03, # per 1K input tokens + "output": 0.06 # per 1K output tokens + }, + "gpt-4-turbo": { + "input": 0.01, + "output": 0.03 + }, + "gpt-4o": { + "input": 0.005, + "output": 0.015 + }, + "gpt-4.1-nano": { # Add your specific model + "input": 0.005, # Adjust based on actual pricing + "output": 0.015 # Adjust based on actual pricing + }, + "default": { + "input": 0.01, + "output": 0.03 + } +} + +def set_global_logger(logger): + """Set the global logger for cost tracking""" + global global_logger + global_logger = logger + +def calculate_cost(model_name, input_tokens, output_tokens): + """Calculate cost based on token usage and model pricing""" + # Get pricing for the model or use default + model_pricing = PRICING.get(model_name, PRICING["default"]) + + input_cost = (input_tokens / 1000) * model_pricing["input"] + output_cost = (output_tokens / 1000) * model_pricing["output"] + total_cost = input_cost + output_cost + + return total_cost, input_cost, output_cost + +def log_cost(model_name, input_tokens, output_tokens, prompt_preview=""): + """Log cost information for the API call""" + global total_cost, call_count, global_logger + + call_cost, input_cost, output_cost = calculate_cost(model_name, input_tokens, output_tokens) + total_cost += call_cost + call_count += 1 + + # Create cost data structure + cost_data = { + "type": "api_cost", + "call_number": call_count, + "model": model_name, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "input_cost": round(input_cost, 6), + "output_cost": round(output_cost, 6), + "call_cost": round(call_cost, 6), + "running_total": round(total_cost, 6), + "prompt_preview": prompt_preview[:100] + "..." 
if len(prompt_preview) > 100 else prompt_preview, + "timestamp": datetime.now().isoformat() + } + + # Console output + print(f"💰 API Call #{call_count} Cost:") + print(f" Model: {model_name}") + print(f" Input tokens: {input_tokens} (${input_cost:.4f})") + print(f" Output tokens: {output_tokens} (${output_cost:.4f})") + print(f" Call cost: ${call_cost:.4f}") + print(f" Running total: ${total_cost:.4f}") + if prompt_preview: + print(f" Prompt preview: {prompt_preview[:100]}...") + print("-" * 50) + + # Log to file if logger is available + if global_logger: + global_logger.log("INFO", cost_data) + + return call_cost + +def log_final_cost_summary(): + """Log final cost summary to both console and file""" + global total_cost, call_count, global_logger + + summary_data = { + "type": "cost_summary", + "total_api_calls": call_count, + "total_cost": round(total_cost, 6), + "average_cost_per_call": round(total_cost/call_count, 6) if call_count > 0 else 0, + "timestamp": datetime.now().isoformat() + } + + # Console output + print("\n" + "="*60) + print("💰 FINAL COST SUMMARY") + print("="*60) + print(f"Total API calls: {call_count}") + print(f"Total cost: ${total_cost:.4f}") + print(f"Average cost per call: ${total_cost/call_count:.4f}" if call_count > 0 else "No API calls made") + print("="*60) + + # Log to file if logger is available + if global_logger: + global_logger.log("INFO", summary_data) + + return summary_data + +def get_total_cost(): + """Get the total cost accumulated so far""" + return total_cost, call_count def count_tokens(text, model=None): if not text: @@ -26,9 +152,9 @@ def count_tokens(text, model=None): tokens = enc.encode(text) return len(tokens) -def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +# Update the OpenAI client initialization to use modern Azure OpenAI client +def ChatGPT_API_with_finish_reason(model, prompt, chat_history=None): max_retries = 10 - client = openai.OpenAI(api_key=api_key) for i
in range(max_retries): try: if chat_history: @@ -37,15 +163,32 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_ else: messages = [{"role": "user", "content": prompt}] + # Count input tokens + input_text = prompt if not chat_history else "\n".join([msg.get("content", "") for msg in messages]) + input_tokens = count_tokens(input_text, model) + + client = AzureOpenAI( + api_key=OPENAI_API_KEY, + api_version=OPENAI_API_VERSION, + azure_endpoint=OPENAI_API_BASE + ) response = client.chat.completions.create( - model=model, + model=OPENAI_API_ENGINE, messages=messages, temperature=0, ) + + # Get response content and count output tokens + response_content = response.choices[0].message.content + output_tokens = count_tokens(response_content, model) + + # Log cost information + log_cost(OPENAI_API_ENGINE, input_tokens, output_tokens, prompt) + if response.choices[0].finish_reason == "length": - return response.choices[0].message.content, "max_output_reached" + return response_content, "max_output_reached" else: - return response.choices[0].message.content, "finished" + return response_content, "finished" except Exception as e: print('************* Retrying *************') @@ -54,13 +197,12 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_ time.sleep(1) # Wait for 1秒 before retrying else: logging.error('Max retries reached for prompt: ' + prompt) - return "Error" + return "Error", "error" -def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +def ChatGPT_API(model, prompt, chat_history=None): max_retries = 10 - client = openai.OpenAI(api_key=api_key) for i in range(max_retries): try: if chat_history: @@ -69,13 +211,29 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): else: messages = [{"role": "user", "content": prompt}] + # Count input tokens + input_text = prompt if not chat_history else "\n".join([msg.get("content", "") for msg in messages]) + 
input_tokens = count_tokens(input_text, model) + + client = AzureOpenAI( + api_key=OPENAI_API_KEY, + api_version=OPENAI_API_VERSION, + azure_endpoint=OPENAI_API_BASE + ) response = client.chat.completions.create( - model=model, + model=OPENAI_API_ENGINE, messages=messages, temperature=0, ) + + # Get response content and count output tokens + response_content = response.choices[0].message.content + output_tokens = count_tokens(response_content, model) + + # Log cost information + log_cost(OPENAI_API_ENGINE, input_tokens, output_tokens, prompt) - return response.choices[0].message.content + return response_content except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") @@ -86,18 +244,34 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): return "Error" -async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY): +async def ChatGPT_API_async(model, prompt, api_key=OPENAI_API_KEY): max_retries = 10 messages = [{"role": "user", "content": prompt}] for i in range(max_retries): try: - async with openai.AsyncOpenAI(api_key=api_key) as client: - response = await client.chat.completions.create( - model=model, - messages=messages, - temperature=0, - ) - return response.choices[0].message.content + # Count input tokens + input_tokens = count_tokens(prompt, model) + + # For Azure OpenAI with the new client + client = AzureOpenAI( + api_key=api_key, + api_version=OPENAI_API_VERSION, + azure_endpoint=OPENAI_API_BASE + ) + response = client.chat.completions.create( + model=OPENAI_API_ENGINE, + messages=messages, + temperature=0, + ) + + # Get response content and count output tokens + response_content = response.choices[0].message.content + output_tokens = count_tokens(response_content, model) + + # Log cost information + log_cost(OPENAI_API_ENGINE, input_tokens, output_tokens, prompt) + + return response_content except Exception as e: print('************* Retrying *************') logging.error(f"Error: 
{e}") diff --git a/run_pageindex.py b/run_pageindex.py index 1070245..e14f22b 100644 --- a/run_pageindex.py +++ b/run_pageindex.py @@ -77,6 +77,13 @@ json.dump(toc_with_page_number, f, indent=2) print(f'Tree structure saved to: {output_file}') + + # Final cost summary is already logged in page_index_main + # Display it here for console output + from pageindex.utils import get_total_cost + total_cost, call_count = get_total_cost() + if call_count > 0: + print(f"\n💰 Total cost: ${total_cost:.4f} for {call_count} API calls") elif args.md_path: # Validate Markdown file