Patch summary
=============
Adds six new test modules under tests/ (every hunk is a new-file hunk of the
form "-0,0 +1,N"):

  tests/test_chromium.py               TestChromiumLoader (pytest, @pytest.mark.asyncio)
  tests/test_generate_answer_node.py   TestGenerateAnswerNode (unittest.TestCase)
  tests/test_parse_node.py             TestParseNode (pytest)
  tests/test_split_text_into_chunks.py TestSplitTextIntoChunks (pytest)
  tests/test_tokenizer.py              TestTokenizer (pytest)
  tests/test_tokenizer_openai.py       TestTokenizerOpenAI (unittest.TestCase)

Restoration notes
=================
  * This patch was restored from a whitespace-collapsed copy: one added line
    per "+" marker. The per-file line counts match the hunk headers
    (289, 196, 223, 220, 19, 30).
  * Indentation and the spacing before inline comments were re-derived from
    Python syntax; they may differ from the original bytes, so the blob ids on
    the "index" lines may no longer match.
  * The markup inside two fixtures had been stripped and was reconstructed from
    what the tests assert on; confirm against the original commit:
      - test_extract_relative_urls: page_content (anchor + image tags)
      - test_parse_html_content:    html_content (h1 + p)
  * Runs of consecutive spaces inside string literals could not be recovered
    (affects test_split_text_with_multiple_spaces, whose text and " " checks
    presumably used wider gaps).

Review notes (payload intentionally left unchanged)
===================================================
  * test_split_text_with_newlines asserts result == expected_chunks (where
    expected_chunks[1] is "with multiple") and then asserts "\n" in result[1];
    both cannot hold at once.
  * test_execute_with_custom_schema stops after defining CustomSchema (last
    line is the comment "# Prepare a mock L"); it has no act/assert step.
  * test_large_chunk_size_with_semchunk expects chunk() to receive
    token_counter=mock_num_tokens.return_value, i.e. the int 5 rather than a
    callable; confirm against split_text_into_chunks.
  * Imported but never referenced by name: asyncio, time (test_chromium.py);
    RunnableParallel (test_generate_answer_node.py); Html2TextTransformer
    (test_parse_node.py); pytest, Mock (test_split_text_into_chunks.py);
    pytest (test_tokenizer.py).

diff --git a/tests/test_chromium.py b/tests/test_chromium.py
new file mode 100644
index 00000000..bd589d9a
--- /dev/null
+++ b/tests/test_chromium.py
@@ -0,0 +1,289 @@
+import asyncio
+import pytest
+import time
+
+from langchain_core.documents import Document
+from scrapegraphai.docloaders.chromium import ChromiumLoader
+from scrapegraphai.utils import Proxy
+from unittest.mock import AsyncMock, MagicMock, patch
+
+class TestChromiumLoader:
+    @pytest.mark.asyncio
+    async def test_scrape_with_js_support(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "JavaScript rendered content"
+
+        # Mock the playwright and its components
+        with patch("playwright.async_api.async_playwright") as mock_playwright:
+            # Set up the mock chain
+            mock_browser = AsyncMock()
+            mock_context = AsyncMock()
+            mock_page = AsyncMock()
+
+            mock_playwright.return_value.__aenter__.return_value.chromium.launch.return_value = mock_browser
+            mock_browser.new_context.return_value = mock_context
+            mock_context.new_page.return_value = mock_page
+            mock_page.content.return_value = expected_content
+
+            # Create the ChromiumLoader instance
+            loader = ChromiumLoader([url], requires_js_support=True)
+
+            # Act
+            documents = [doc async for doc in loader.alazy_load()]
+
+            # Assert
+            assert len(documents) == 1
+            assert documents[0].page_content == expected_content
+            assert documents[0].metadata["source"] == url
+
+            # Verify that the correct methods were called
+            mock_page.goto.assert_called_once_with(url, wait_until="networkidle")
+            mock_page.content.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_ascrape_playwright_scroll(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "Scrolled content"
+
+        # Mock the playwright and its components
+        with patch("playwright.async_api.async_playwright") as mock_playwright, \
+             patch("undetected_playwright.Malenia") as mock_malenia, \
+             patch("time.sleep", return_value=None) as mock_sleep:
+
+            # Set up the mock chain
+            mock_browser = AsyncMock()
+            mock_context = AsyncMock()
+            mock_page = AsyncMock()
+
+            mock_playwright.return_value.__aenter__.return_value.chromium.launch.return_value = mock_browser
+            mock_browser.new_context.return_value = mock_context
+            mock_context.new_page.return_value = mock_page
+
+            # Mock the page methods
+            mock_page.evaluate.side_effect = [1000, 2000, 2000]  # Simulate scrolling
+            mock_page.content.return_value = expected_content
+
+            # Create the ChromiumLoader instance
+            loader = ChromiumLoader([url])
+
+            # Act
+            result = await loader.ascrape_playwright_scroll(url, timeout=10, scroll=5000, sleep=1)
+
+            # Assert
+            assert result == expected_content
+
+            # Verify that the correct methods were called
+            mock_page.goto.assert_called_once_with(url, wait_until="domcontentloaded")
+            mock_page.wait_for_load_state.assert_called_once()
+            assert mock_page.evaluate.call_count == 3  # Called for each scroll attempt
+            mock_page.mouse.wheel.assert_called_with(0, 5000)
+            mock_sleep.assert_called_with(1)
+            mock_page.content.assert_called_once()
+            mock_browser.close.assert_awaited_once()
+
+            # Verify that Malenia.apply_stealth was called
+            mock_malenia.apply_stealth.assert_called_once_with(mock_context)
+
+    def test_lazy_load(self):
+        # Arrange
+        urls = ["https://example1.com", "https://example2.com"]
+        expected_content = "Test content"
+
+        # Mock the ascrape_playwright method
+        with patch.object(ChromiumLoader, 'ascrape_playwright', new_callable=AsyncMock) as mock_scrape:
+            mock_scrape.return_value = expected_content
+
+            # Create the ChromiumLoader instance
+            loader = ChromiumLoader(urls)
+
+            # Act
+            documents = list(loader.lazy_load())
+
+            # Assert
+            assert len(documents) == 2
+            for i, doc in enumerate(documents):
+                assert isinstance(doc, Document)
+                assert doc.page_content == expected_content
+                assert doc.metadata["source"] == urls[i]
+
+            # Verify that the mocked method was called for each URL
+            assert mock_scrape.call_count == 2
+            mock_scrape.assert_any_call(urls[0])
+            mock_scrape.assert_any_call(urls[1])
+
+    @pytest.mark.asyncio
+    async def test_ascrape_undetected_chromedriver(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "Selenium scraped content"
+
+        # Mock undetected_chromedriver and Selenium components
+        with patch("undetected_chromedriver.Chrome") as mock_chrome, \
+             patch("selenium.webdriver.chrome.options.Options") as mock_options:
+
+            # Set up the mock chain
+            mock_driver = MagicMock()
+            mock_driver.page_source = expected_content
+            mock_chrome.return_value = mock_driver
+
+            # Create the ChromiumLoader instance
+            loader = ChromiumLoader([url], backend="selenium", browser_name="chromium")
+
+            # Act
+            result = await loader.ascrape_undetected_chromedriver(url)
+
+            # Assert
+            assert result == expected_content
+
+            # Verify that the correct methods were called
+            mock_options.assert_called_once()
+            mock_chrome.assert_called_once()
+            mock_driver.get.assert_called_once_with(url)
+            mock_driver.quit.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_ascrape_playwright_with_firefox(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "Firefox scraped content"
+
+        # Mock the playwright and its components
+        with patch("playwright.async_api.async_playwright") as mock_playwright, \
+             patch("undetected_playwright.Malenia") as mock_malenia:
+
+            # Set up the mock chain
+            mock_browser = AsyncMock()
+            mock_context = AsyncMock()
+            mock_page = AsyncMock()
+
+            mock_playwright.return_value.__aenter__.return_value.firefox.launch.return_value = mock_browser
+            mock_browser.new_context.return_value = mock_context
+            mock_context.new_page.return_value = mock_page
+            mock_page.content.return_value = expected_content
+
+            # Create the ChromiumLoader instance with Firefox
+            loader = ChromiumLoader([url], browser_name="firefox")
+
+            # Act
+            result = await loader.ascrape_playwright(url, browser_name="firefox")
+
+            # Assert
+            assert result == expected_content
+
+            # Verify that the correct methods were called
+            mock_playwright.return_value.__aenter__.return_value.firefox.launch.assert_called_once()
+            mock_browser.new_context.assert_called_once()
+            mock_context.new_page.assert_called_once()
+            mock_page.goto.assert_called_once_with(url, wait_until="domcontentloaded")
+            mock_page.wait_for_load_state.assert_called_once()
+            mock_page.content.assert_called_once()
+            mock_browser.close.assert_awaited_once()
+
+            # Verify that Malenia.apply_stealth was called
+            mock_malenia.apply_stealth.assert_called_once_with(mock_context)
+
+    @pytest.mark.asyncio
+    async def test_ascrape_playwright_scroll_to_bottom(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "Scrolled to bottom content"
+
+        # Mock the playwright and its components
+        with patch("playwright.async_api.async_playwright") as mock_playwright, \
+             patch("undetected_playwright.Malenia") as mock_malenia, \
+             patch("time.sleep", return_value=None) as mock_sleep, \
+             patch("time.time", side_effect=[0, 5, 10, 15]):  # Simulate time passing
+
+            # Set up the mock chain
+            mock_browser = AsyncMock()
+            mock_context = AsyncMock()
+            mock_page = AsyncMock()
+
+            mock_playwright.return_value.__aenter__.return_value.chromium.launch.return_value = mock_browser
+            mock_browser.new_context.return_value = mock_context
+            mock_context.new_page.return_value = mock_page
+
+            # Mock the page methods
+            mock_page.evaluate.side_effect = [1000, 2000, 3000, 3000]  # Simulate scrolling until bottom
+            mock_page.content.return_value = expected_content
+
+            # Create the ChromiumLoader instance
+            loader = ChromiumLoader([url])
+
+            # Act
+            result = await loader.ascrape_playwright_scroll(url, timeout=30, scroll=15000, sleep=2, scroll_to_bottom=True)
+
+            # Assert
+            assert result == expected_content
+
+            # Verify that the correct methods were called
+            mock_page.goto.assert_called_once_with(url, wait_until="domcontentloaded")
+            mock_page.wait_for_load_state.assert_called_once()
+            assert mock_page.evaluate.call_count == 4  # Called for each scroll attempt
+            assert mock_page.mouse.wheel.call_count == 3  # Called until bottom is reached
+            mock_page.mouse.wheel.assert_called_with(0, 15000)
+            assert mock_sleep.call_count == 3  # Called after each scroll
+            mock_page.content.assert_called_once()
+            mock_browser.close.assert_awaited_once()
+
+            # Verify that Malenia.apply_stealth was called
+            mock_malenia.apply_stealth.assert_called_once_with(mock_context)
+
+    @pytest.mark.asyncio
+    async def test_chromium_loader_with_custom_proxy(self):
+        # Arrange
+        url = "https://example.com"
+        expected_content = "Proxied content"
+        custom_proxy = Proxy(
+            server="http://proxy.example.com:8080",
+            username="user",
+            password="pass"
+        )
+
+        with patch("scrapegraphai.docloaders.chromium.parse_or_search_proxy") as mock_parse_proxy, \
+             patch("playwright.async_api.async_playwright") as mock_playwright, \
+             patch("undetected_playwright.Malenia") as mock_malenia:
+
+            # Set up the mock chain
+            mock_parse_proxy.return_value = custom_proxy
+            mock_browser = AsyncMock()
+            mock_context = AsyncMock()
+            mock_page = AsyncMock()
+
+            mock_playwright.return_value.__aenter__.return_value.chromium.launch.return_value = mock_browser
+            mock_browser.new_context.return_value = mock_context
+            mock_context.new_page.return_value = mock_page
+            mock_page.content.return_value = expected_content
+
+            # Create the ChromiumLoader instance with custom proxy
+            loader = ChromiumLoader([url], proxy=custom_proxy)
+
+            # Act
+            result = await loader.ascrape_playwright(url)
+
+            # Assert
+            assert result == expected_content
+
+            # Verify that the proxy was correctly parsed and used
+            mock_parse_proxy.assert_called_once_with(custom_proxy)
+            mock_playwright.return_value.__aenter__.return_value.chromium.launch.assert_called_once_with(
+                headless=True,
+                proxy={
+                    "server": "http://proxy.example.com:8080",
+                    "username": "user",
+                    "password": "pass"
+                }
+            )
+
+            # Verify other method calls
+            mock_browser.new_context.assert_called_once()
+            mock_context.new_page.assert_called_once()
+            mock_page.goto.assert_called_once_with(url, wait_until="domcontentloaded")
+            mock_page.wait_for_load_state.assert_called_once()
+            mock_page.content.assert_called_once()
+            mock_browser.close.assert_awaited_once()
+
+            # Verify that Malenia.apply_stealth was called
+            mock_malenia.apply_stealth.assert_called_once_with(mock_context)
\ No newline at end of file
diff --git a/tests/test_generate_answer_node.py b/tests/test_generate_answer_node.py
new file mode 100644
index 00000000..d264d93d
--- /dev/null
+++ b/tests/test_generate_answer_node.py
@@ -0,0 +1,196 @@
+import unittest
+
+from langchain_core.runnables import RunnableParallel
+from pydantic import BaseModel, Field
+from requests.exceptions import Timeout
+from scrapegraphai.nodes.generate_answer_node import GenerateAnswerNode
+from unittest.mock import MagicMock, patch
+
+class TestGenerateAnswerNode(unittest.TestCase):
+    def setUp(self):
+        self.node_config = {
+            "llm_model": MagicMock(),
+            "verbose": False,
+            "force": False,
+            "script_creator": False,
+            "is_md_scraper": False,
+            "timeout": 10
+        }
+        self.node = GenerateAnswerNode("input", ["output"], self.node_config)
+
+    @patch.object(GenerateAnswerNode, 'invoke_with_timeout')
+    def test_execute_timeout(self, mock_invoke):
+        # Simulate a timeout during execution
+        mock_invoke.side_effect = Timeout("Response timeout exceeded")
+
+        # Prepare input state
+        input_state = {
+            "input": "test question",
+            "chunks": ["chunk1", "chunk2"]
+        }
+
+        # Execute the node
+        result_state = self.node.execute(input_state)
+
+        # Check if the output contains the expected error message
+        self.assertIn("output", result_state)
+        self.assertIsInstance(result_state["output"], dict)
+        self.assertIn("error", result_state["output"])
+        self.assertEqual(result_state["output"]["error"], "Response timeout exceeded")
+
+        # Verify that invoke_with_timeout was called
+        mock_invoke.assert_called()
+
+    @patch.object(GenerateAnswerNode, 'invoke_with_timeout')
+    def test_execute_single_chunk(self, mock_invoke):
+        # Prepare a mock LLM model
+        mock_llm = MagicMock()
+
+        # Configure the node
+        node_config = {
+            "llm_model": mock_llm,
+            "verbose": False,
+            "force": False,
+            "script_creator": False,
+            "is_md_scraper": False,
+            "timeout": 10
+        }
+        node = GenerateAnswerNode("input", ["output"], node_config)
+
+        # Prepare input state with a single document chunk
+        input_state = {
+            "input": "test question",
+            "chunks": ["single chunk content"]
+        }
+
+        # Mock the response from invoke_with_timeout
+        mock_invoke.return_value = "Mocked answer for single chunk"
+
+        # Execute the node
+        result_state = node.execute(input_state)
+
+        # Assert that invoke_with_timeout was called once
+        mock_invoke.assert_called_once()
+
+        # Check if the output contains the expected answer
+        self.assertIn("output", result_state)
+        self.assertEqual(result_state["output"], "Mocked answer for single chunk")
+
+        # Verify that the correct template was used (TEMPLATE_NO_CHUNKS_MD)
+        call_args = mock_invoke.call_args[0]
+        self.assertIn("context", call_args[1])
+        self.assertEqual(call_args[1]["context"], ["single chunk content"])
+
+    @patch('scrapegraphai.nodes.generate_answer_node.RunnableParallel')
+    @patch.object(GenerateAnswerNode, 'invoke_with_timeout')
+    def test_execute_multiple_chunks(self, mock_invoke, mock_runnable_parallel):
+        # Prepare a mock LLM model
+        mock_llm = MagicMock()
+
+        # Configure the node
+        node_config = {
+            "llm_model": mock_llm,
+            "verbose": False,
+            "force": False,
+            "script_creator": False,
+            "is_md_scraper": False,
+            "timeout": 10
+        }
+        node = GenerateAnswerNode("input", ["output"], node_config)
+
+        # Prepare input state with multiple document chunks
+        input_state = {
+            "input": "test question",
+            "chunks": ["chunk1 content", "chunk2 content", "chunk3 content"]
+        }
+
+        # Mock the response from RunnableParallel
+        mock_runnable_parallel.return_value = MagicMock()
+        mock_runnable_parallel.return_value.invoke.return_value = {
+            "chunk1": "Result from chunk1",
+            "chunk2": "Result from chunk2",
+            "chunk3": "Result from chunk3"
+        }
+
+        # Mock the response from invoke_with_timeout for the merge step
+        mock_invoke.side_effect = [
+            mock_runnable_parallel.return_value.invoke.return_value,
+            "Merged result from all chunks"
+        ]
+
+        # Execute the node
+        result_state = node.execute(input_state)
+
+        # Assert that invoke_with_timeout was called twice (once for chunks, once for merge)
+        self.assertEqual(mock_invoke.call_count, 2)
+
+        # Check if the output contains the expected merged answer
+        self.assertIn("output", result_state)
+        self.assertEqual(result_state["output"], "Merged result from all chunks")
+
+        # Verify that RunnableParallel was created with the correct number of chains
+        mock_runnable_parallel.assert_called_once()
+        self.assertEqual(len(mock_runnable_parallel.call_args[1]), 3)  # 3 chunks
+
+        # Verify that the merge step was called with the correct context
+        merge_call_args = mock_invoke.call_args_list[1][0]
+        self.assertIn("context", merge_call_args[1])
+        self.assertEqual(merge_call_args[1]["context"], {
+            "chunk1": "Result from chunk1",
+            "chunk2": "Result from chunk2",
+            "chunk3": "Result from chunk3"
+        })
+
+    @patch.object(GenerateAnswerNode, 'invoke_with_timeout')
+    def test_execute_with_additional_info(self, mock_invoke):
+        # Prepare a mock LLM model
+        mock_llm = MagicMock()
+
+        # Configure the node with additional_info
+        node_config = {
+            "llm_model": mock_llm,
+            "verbose": False,
+            "force": False,
+            "script_creator": False,
+            "is_md_scraper": False,
+            "timeout": 10,
+            "additional_info": "This is additional information: "
+        }
+        node = GenerateAnswerNode("input", ["output"], node_config)
+
+        # Prepare input state with a single document chunk
+        input_state = {
+            "input": "test question",
+            "chunks": ["single chunk content"]
+        }
+
+        # Mock the response from invoke_with_timeout
+        mock_invoke.return_value = "Mocked answer with additional info"
+
+        # Execute the node
+        result_state = node.execute(input_state)
+
+        # Assert that invoke_with_timeout was called once
+        mock_invoke.assert_called_once()
+
+        # Check if the output contains the expected answer
+        self.assertIn("output", result_state)
+        self.assertEqual(result_state["output"], "Mocked answer with additional info")
+
+        # Verify that the correct template was used and includes additional_info
+        call_args = mock_invoke.call_args[0]
+        self.assertIn("This is additional information: ", str(call_args[0]))
+
+        # Verify that the context is correct
+        self.assertIn("context", call_args[1])
+        self.assertEqual(call_args[1]["context"], ["single chunk content"])
+
+    @patch('scrapegraphai.nodes.generate_answer_node.get_pydantic_output_parser')
+    @patch.object(GenerateAnswerNode, 'invoke_with_timeout')
+    def test_execute_with_custom_schema(self, mock_invoke, mock_get_parser):
+        # Define a custom schema
+        class CustomSchema(BaseModel):
+            answer: str = Field(description="The generated answer")
+            confidence: float = Field(description="Confidence score of the answer")
+
+        # Prepare a mock L
\ No newline at end of file
diff --git a/tests/test_parse_node.py b/tests/test_parse_node.py
new file mode 100644
index 00000000..2bfb17e2
--- /dev/null
+++ b/tests/test_parse_node.py
@@ -0,0 +1,223 @@
+import pytest
+
+from langchain_community.document_transformers import Html2TextTransformer
+from langchain_core.documents import Document
+from scrapegraphai.nodes.parse_node import ParseNode
+from unittest.mock import MagicMock, patch
+
+class TestParseNode:
+
+    def test_extract_relative_urls(self):
+        # Setup
+        node_config = {
+            "verbose": True,
+            "parse_html": False,
+            "parse_urls": True,
+            "chunk_size": 1000
+        }
+        parse_node = ParseNode(
+            input="input,source",
+            output=["chunks", "link_urls", "img_urls"],
+            node_config=node_config
+        )
+
+        # Mock input data
+        mock_document = Document(page_content="Check out this <a href='/relative/path'>relative link</a> and this <img src='/images/photo.jpg'>")
+        mock_source = "https://example.com"
+
+        mock_state = {
+            "input": [mock_document],
+            "source": mock_source
+        }
+
+        # Execute
+        result_state = parse_node.execute(mock_state)
+
+        # Assert
+        assert "chunks" in result_state
+        assert "link_urls" in result_state
+        assert "img_urls" in result_state
+
+        assert "https://example.com/relative/path" in result_state["link_urls"]
+        assert "https://example.com/images/photo.jpg" in result_state["img_urls"]
+
+    def test_parse_html_content(self):
+        # Setup
+        node_config = {
+            "verbose": True,
+            "parse_html": True,
+            "parse_urls": False,
+            "chunk_size": 1000
+        }
+        parse_node = ParseNode(
+            input="input",
+            output=["chunks"],
+            node_config=node_config
+        )
+
+        # Mock input data
+        html_content = "<html><body><h1>Test Header</h1><p>This is a test paragraph.</p></body></html>"
+        mock_document = Document(page_content=html_content)
+        mock_state = {
+            "input": [mock_document]
+        }
+
+        # Mock Html2TextTransformer
+        mock_transformed_doc = Document(page_content="Test Header\n\nThis is a test paragraph.")
+        mock_transformer = MagicMock()
+        mock_transformer.transform_documents.return_value = [mock_transformed_doc]
+
+        # Execute with mocked Html2TextTransformer
+        with patch('scrapegraphai.nodes.parse_node.Html2TextTransformer', return_value=mock_transformer):
+            result_state = parse_node.execute(mock_state)
+
+        # Assert
+        assert "chunks" in result_state
+        assert len(result_state["chunks"]) > 0
+        assert "Test Header" in result_state["chunks"][0]
+        assert "This is a test paragraph" in result_state["chunks"][0]
+
+    def test_parse_urls_without_html_parsing(self):
+        # Setup
+        node_config = {
+            "verbose": True,
+            "parse_html": False,
+            "parse_urls": True,
+            "chunk_size": 1000
+        }
+        parse_node = ParseNode(
+            input="input,source",
+            output=["chunks", "link_urls", "img_urls"],
+            node_config=node_config
+        )
+
+        # Mock input data
+        text_content = "Check out https://example.com and /relative/path.html. Also, see image.jpg"
+        mock_document = Document(page_content=text_content)
+        mock_source = "https://sourcesite.com"
+
+        mock_state = {
+            "input": [mock_document],
+            "source": mock_source
+        }
+
+        # Execute
+        result_state = parse_node.execute(mock_state)
+
+        # Assert
+        assert "chunks" in result_state
+        assert "link_urls" in result_state
+        assert "img_urls" in result_state
+
+        assert "https://example.com" in result_state["link_urls"]
+        assert "https://sourcesite.com/relative/path.html" in result_state["link_urls"]
+        assert "https://sourcesite.com/image.jpg" in result_state["img_urls"]
+        assert len(result_state["chunks"]) > 0
+        assert text_content in result_state["chunks"][0]
+
+    def test_large_document_chunking(self):
+        # Setup
+        node_config = {
+            "verbose": False,
+            "parse_html": False,
+            "parse_urls": False,
+            "chunk_size": 100  # Small chunk size for testing
+        }
+        parse_node = ParseNode(
+            input="input",
+            output=["chunks"],
+            node_config=node_config
+        )
+
+        # Create a large document
+        large_content = "This is a test sentence. " * 50  # 1250 characters
+        mock_document = Document(page_content=large_content)
+        mock_state = {
+            "input": [mock_document]
+        }
+
+        # Execute
+        result_state = parse_node.execute(mock_state)
+
+        # Assert
+        assert "chunks" in result_state
+        chunks = result_state["chunks"]
+
+        # Check if we have the expected number of chunks
+        expected_chunks = 13  # 1250 / 100 = 12.5, rounded up to 13
+        assert len(chunks) == expected_chunks, f"Expected {expected_chunks} chunks, but got {len(chunks)}"
+
+        # Check if each chunk is approximately the right size
+        for chunk in chunks[:-1]:  # All but the last chunk
+            assert 80 <= len(chunk) <= 100, f"Chunk size {len(chunk)} is out of expected range"
+
+        # Check if the content is preserved
+        reconstructed_content = " ".join(chunks)
+        assert reconstructed_content.strip() == large_content.strip(), "Content was not preserved after chunking"
+
+    def test_missing_input_keys(self):
+        # Setup
+        node_config = {
+            "verbose": False,
+            "parse_html": True,
+            "parse_urls": True,
+            "chunk_size": 1000
+        }
+        parse_node = ParseNode(
+            input="input,source",
+            output=["chunks", "link_urls", "img_urls"],
+            node_config=node_config
+        )
+
+        # Create a state with missing input keys
+        incomplete_state = {}
+
+        # Execute and assert
+        with pytest.raises(KeyError) as excinfo:
+            parse_node.execute(incomplete_state)
+
+        # Check if the error message contains information about missing keys
+        assert "input" in str(excinfo.value) or "source" in str(excinfo.value), \
+            "KeyError should mention missing input keys"
+
+    def test_invalid_url_handling(self):
+        # Setup
+        node_config = {
+            "verbose": False,
+            "parse_html": False,
+            "parse_urls": True,
+            "chunk_size": 1000
+        }
+        parse_node = ParseNode(
+            input="input,source",
+            output=["chunks", "link_urls", "img_urls"],
+            node_config=node_config
+        )
+
+        # Mock input data with an invalid URL
+        invalid_url = "http://invalid[url].com"
+        text_content = f"Check out this invalid URL: {invalid_url}"
+        mock_document = Document(page_content=text_content)
+        mock_source = "https://example.com"
+
+        mock_state = {
+            "input": [mock_document],
+            "source": mock_source
+        }
+
+        # Execute
+        result_state = parse_node.execute(mock_state)
+
+        # Assert
+        assert "chunks" in result_state
+        assert "link_urls" in result_state
+        assert "img_urls" in result_state
+
+        # Check that the invalid URL is not in the extracted URLs
+        assert invalid_url not in result_state["link_urls"]
+
+        # Check that the chunk still contains the original text
+        assert text_content in result_state["chunks"][0]
+
+        # Verify that at least the valid source URL is extracted
+        assert mock_source in result_state["link_urls"]
\ No newline at end of file
diff --git a/tests/test_split_text_into_chunks.py b/tests/test_split_text_into_chunks.py
new file mode 100644
index 00000000..789b8382
--- /dev/null
+++ b/tests/test_split_text_into_chunks.py
@@ -0,0 +1,220 @@
+import pytest
+
+from scrapegraphai.utils.split_text_into_chunks import split_text_into_chunks
+from unittest.mock import Mock, patch
+
+class TestSplitTextIntoChunks:
+    def test_split_text_without_semchunk(self):
+        # Test splitting text without using semchunk
+        text = "This is a test sentence. It should be split into chunks based on token count."
+        chunk_size = 10
+        expected_chunks = [
+            "This is a test sentence.",
+            "It should be split into chunks",
+            "based on token count."
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.chunk')
+    def test_split_text_with_semchunk(self, mock_chunk):
+        # Test splitting text using semchunk
+        text = "This is a test sentence. It should be split into chunks based on token count."
+        chunk_size = 10
+        expected_chunks = [
+            "This is a test sentence.",
+            "It should be split into chunks",
+            "based on token count."
+        ]
+
+        # Mock the semchunk.chunk function to return our expected chunks
+        mock_chunk.return_value = expected_chunks
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=True)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        mock_chunk.assert_called_once()
+
+    def test_short_text_no_splitting(self):
+        # Test when the input text is shorter than the chunk size
+        text = "This is a short text."
+        chunk_size = 50  # Larger than the token count of the input text
+        expected_chunks = [text]  # The entire text should be returned as a single chunk
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == 1, f"Expected 1 chunk, but got {len(result)}"
+
+    def test_very_small_chunk_size(self):
+        # Test splitting text with a very small chunk size
+        text = "This is a test with small chunks."
+        chunk_size = 3  # Very small chunk size
+        expected_chunks = [
+            "This",
+            "is",
+            "a",
+            "test",
+            "with",
+            "small",
+            "chunks."
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == len(expected_chunks), f"Expected {len(expected_chunks)} chunks, but got {len(result)}"
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_chunk_size_equal_to_token_count(self, mock_num_tokens):
+        # Test when chunk size is exactly equal to the token count of the input text
+        text = "This is a test sentence with exact token count."
+        chunk_size = 10  # Set to match the mocked token count
+        expected_chunks = [text]  # The entire text should be returned as a single chunk
+
+        # Mock num_tokens_calculus to always return the chunk_size
+        mock_num_tokens.return_value = chunk_size
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == 1, f"Expected 1 chunk, but got {len(result)}"
+        mock_num_tokens.assert_called_with(text)
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.chunk')
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_large_chunk_size_with_semchunk(self, mock_num_tokens, mock_chunk):
+        # Test splitting text using semchunk with a large chunk size
+        text = "This is a test sentence for large chunk size."
+        large_chunk_size = 1000
+        expected_adjusted_chunk_size = int(large_chunk_size * 0.9)
+        expected_chunks = ["This is a test sentence", "for large chunk size."]
+
+        # Mock num_tokens_calculus to return a consistent value
+        mock_num_tokens.return_value = 5
+
+        # Mock the semchunk.chunk function to return our expected chunks
+        mock_chunk.return_value = expected_chunks
+
+        result = split_text_into_chunks(text, large_chunk_size, use_semchunk=True)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        mock_chunk.assert_called_once_with(
+            text=text,
+            chunk_size=expected_adjusted_chunk_size,
+            token_counter=mock_num_tokens.return_value,
+            memoize=False
+        )
+        assert mock_num_tokens.call_count > 0, "num_tokens_calculus should have been called"
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_split_text_with_long_words(self, mock_num_tokens):
+        # Test splitting text with very long words
+        text = "Short verylongwordthatexceedschunksize another extremelylongwordthatalsosexceedschunksize end."
+        chunk_size = 10
+
+        # Mock num_tokens_calculus to return the length of each word
+        mock_num_tokens.side_effect = lambda word: len(word)
+
+        expected_chunks = [
+            "Short",
+            "verylongwordthatexceedschunksize",
+            "another",
+            "extremelylongwordthatalsosexceedschunksize",
+            "end."
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == len(expected_chunks), f"Expected {len(expected_chunks)} chunks, but got {len(result)}"
+
+        # Verify that num_tokens_calculus was called for each word
+        assert mock_num_tokens.call_count == len(text.split())
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_split_text_with_newlines(self, mock_num_tokens):
+        # Test splitting text that contains newline characters
+        text = "This is a test\nwith multiple lines.\nIt should split correctly."
+        chunk_size = 15
+
+        # Mock num_tokens_calculus to return a fixed value for simplicity
+        mock_num_tokens.return_value = 1
+
+        expected_chunks = [
+            "This is a test",
+            "with multiple",
+            "lines. It should",
+            "split correctly."
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == len(expected_chunks), f"Expected {len(expected_chunks)} chunks, but got {len(result)}"
+
+        # Verify that num_tokens_calculus was called for each word
+        assert mock_num_tokens.call_count == len(text.split())
+
+        # Check that newlines are preserved within chunks
+        assert "\n" in result[1], "Newline should be preserved in the second chunk"
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_split_text_with_unicode_characters(self, mock_num_tokens):
+        # Test splitting text that contains Unicode characters
+        text = "This is a test with Unicode: 你好世界! Здравствуй, мир! 🌍🌎🌏"
+        chunk_size = 10
+
+        # Mock num_tokens_calculus to return a fixed value for simplicity
+        mock_num_tokens.return_value = 1
+
+        expected_chunks = [
+            "This is a test",
+            "with Unicode:",
+            "你好世界!",
+            "Здравствуй,",
+            "мир! 🌍🌎🌏"
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == len(expected_chunks), f"Expected {len(expected_chunks)} chunks, but got {len(result)}"
+
+        # Verify that num_tokens_calculus was called for each word or character
+        assert mock_num_tokens.call_count == len(text.split())
+
+        # Check that Unicode characters are preserved in chunks
+        assert "你好世界!" in result, "Chinese characters should be preserved in a chunk"
+        assert "Здравствуй," in result, "Cyrillic characters should be preserved in a chunk"
+        assert "🌍🌎🌏" in result[-1], "Emojis should be preserved in the last chunk"
+
+    @patch('scrapegraphai.utils.split_text_into_chunks.num_tokens_calculus')
+    def test_split_text_with_multiple_spaces(self, mock_num_tokens):
+        # Test splitting text that contains multiple spaces between words
+        text = "This is a test with multiple spaces."
+        chunk_size = 10
+
+        # Mock num_tokens_calculus to return 1 for each word, ignoring spaces
+        mock_num_tokens.side_effect = lambda word: 1 if word.strip() else 0
+
+        expected_chunks = [
+            "This is a",
+            "test with",
+            "multiple",
+            "spaces."
+        ]
+
+        result = split_text_into_chunks(text, chunk_size, use_semchunk=False)
+
+        assert result == expected_chunks, f"Expected {expected_chunks}, but got {result}"
+        assert len(result) == len(expected_chunks), f"Expected {len(expected_chunks)} chunks, but got {len(result)}"
+
+        # Verify that num_tokens_calculus was called for each word and space
+        assert mock_num_tokens.call_count == len(text.split()) + text.count(' ')
+
+        # Check that multiple spaces are preserved within chunks
+        assert " " in result[0], "Multiple spaces should be preserved in the first chunk"
+        assert " " in result[1], "Multiple spaces should be preserved in the second chunk"
\ No newline at end of file
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 00000000..b7d182e0
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,19 @@
+import pytest
+
+from scrapegraphai.utils.tokenizer import num_tokens_calculus
+from unittest.mock import patch
+
+class TestTokenizer:
+    @patch('scrapegraphai.utils.tokenizer.num_tokens_openai')
+    def test_num_tokens_calculus_calls_openai_tokenizer(self, mock_num_tokens_openai):
+        # Arrange
+        test_string = "This is a test string"
+        expected_tokens = 5
+        mock_num_tokens_openai.return_value = expected_tokens
+
+        # Act
+        result = num_tokens_calculus(test_string)
+
+        # Assert
+        mock_num_tokens_openai.assert_called_once_with(test_string)
+        assert result == expected_tokens
\ No newline at end of file
diff --git a/tests/test_tokenizer_openai.py b/tests/test_tokenizer_openai.py
new file mode 100644
index 00000000..09edfec9
--- /dev/null
+++ b/tests/test_tokenizer_openai.py
@@ -0,0 +1,30 @@
+import unittest
+
+from scrapegraphai.utils.tokenizers.tokenizer_openai import num_tokens_openai
+from unittest.mock import MagicMock, patch
+
+class TestTokenizerOpenAI(unittest.TestCase):
+    @patch('scrapegraphai.utils.tokenizers.tokenizer_openai.get_logger')
+    @patch('scrapegraphai.utils.tokenizers.tokenizer_openai.tiktoken')
+    def test_num_tokens_openai_simple_input(self, mock_tiktoken, mock_get_logger):
+        # Arrange
+        mock_logger = MagicMock()
+        mock_get_logger.return_value = mock_logger
+
+        mock_encoding = MagicMock()
+        mock_encoding.encode.return_value = [1, 2, 3, 4, 5]  # Simulating 5 tokens
+        mock_tiktoken.encoding_for_model.return_value = mock_encoding
+
+        test_text = "This is a test sentence."
+
+        # Act
+        result = num_tokens_openai(test_text)
+
+        # Assert
+        self.assertEqual(result, 5)
+        mock_logger.debug.assert_called_once_with(f"Counting tokens for text of {len(test_text)} characters")
+        mock_tiktoken.encoding_for_model.assert_called_once_with("gpt-4o")
+        mock_encoding.encode.assert_called_once_with(test_text)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file