Commit 5195dc9

andrasfe committed:
GitbookLoader now inherits from BaseLoader and uses WebBaseLoader in a thread-safe way; SSRF (CVE) mitigation added; pyproject.toml and uv.lock reverted to master.
1 parent 10d6ad5 commit 5195dc9
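
The diff below refactors langchain_community's GitBook loader and adds a keyword-only allowed_domains parameter that acts as an allow-list against SSRF. As a rough usage sketch (the site URL is a placeholder; only the parameters shown in the diff are assumed):

    from langchain_community.document_loaders import GitbookLoader

    # Placeholder URL. load_all_paths=True crawls every page listed in the
    # GitBook sitemap; allowed_domains restricts which hosts may be fetched.
    loader = GitbookLoader(
        "https://docs.example.io",
        load_all_paths=True,
        allowed_domains={"docs.example.io"},
    )

    for doc in loader.lazy_load():
        print(doc.metadata["source"], doc.metadata["title"])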

File tree: 5 files changed (+463, -194 lines)

Lines changed: 206 additions & 67 deletions
@@ -1,14 +1,15 @@
 import warnings
-from typing import Any, AsyncIterator, Iterator, List, Optional, Set
-from urllib.parse import urljoin, urlparse
+from typing import Any, AsyncIterator, Iterator, List, Optional, Set, Union
+from urllib.parse import urlparse

 from bs4 import BeautifulSoup
 from langchain_core.documents import Document

+from langchain_community.document_loaders.base import BaseLoader
 from langchain_community.document_loaders.web_base import WebBaseLoader


-class GitbookLoader(WebBaseLoader):
+class GitbookLoader(BaseLoader):
     """Load `GitBook` data.

     1. load from either a single page, or
@@ -25,6 +26,7 @@ def __init__(
         show_progress: bool = True,
         *,
         sitemap_url: Optional[str] = None,
+        allowed_domains: Optional[Set[str]] = None,
     ):
         """Initialize with web page and whether to load all paths.

@@ -44,25 +46,86 @@ def __init__(
             show_progress: whether to show a progress bar while loading. Default: True
             sitemap_url: Custom sitemap URL to use when load_all_paths is True.
                 Defaults to "{base_url}/sitemap.xml".
+            allowed_domains: Optional set of allowed domains to fetch from.
+                If provided, only URLs from these domains will be processed.
+                Helps prevent SSRF vulnerabilities in server environments.
+                Defaults to None (all domains allowed).
         """
         self.base_url = base_url or web_page
         if self.base_url.endswith("/"):
             self.base_url = self.base_url[:-1]

-        if load_all_paths:
-            # set web_path to the sitemap if we want to crawl all paths
-            if sitemap_url:
-                web_page = sitemap_url
-            else:
-                web_page = f"{self.base_url}/sitemap.xml"
-
-        super().__init__(
-            web_paths=(web_page,),
-            continue_on_failure=continue_on_failure,
-            show_progress=show_progress,
-        )
+        self.web_page = web_page
         self.load_all_paths = load_all_paths
         self.content_selector = content_selector
+        self.continue_on_failure = continue_on_failure
+        self.show_progress = show_progress
+        self.allowed_domains = allowed_domains
+
+        # If allowed_domains is not specified, extract domain from web_page as default
+        if self.allowed_domains is None:
+            initial_domain = urlparse(web_page).netloc
+            if initial_domain:
+                self.allowed_domains = {initial_domain}
+
+        # Determine the starting URL (either a sitemap or a direct page)
+        if load_all_paths:
+            self.start_url = sitemap_url or f"{self.base_url}/sitemap.xml"
+        else:
+            self.start_url = web_page
+
+        # Validate the start_url is allowed
+        if not self._is_url_allowed(self.start_url):
+            raise ValueError(
+                f"Domain in {self.start_url} is not in the allowed domains list: "
+                f"{self.allowed_domains}"
+            )
+
+    def _is_url_allowed(self, url: str) -> bool:
+        """Check if a URL's domain is allowed for processing.
+
+        Args:
+            url: The URL to check
+
+        Returns:
+            bool: True if the domain is allowed or if no allowed_domains set is defined
+        """
+        if self.allowed_domains is None:
+            return True
+
+        netloc = urlparse(url).netloc
+        return netloc in self.allowed_domains
+
+    def _safe_add_url(
+        self, url_list: List[str], url: str, url_type: str = "URL"
+    ) -> bool:
+        """Safely add a URL to a list if it's from an allowed domain.
+
+        Args:
+            url_list: The list to add the URL to
+            url: The URL to add
+            url_type: Type of URL for warning message (e.g., "sitemap", "content")
+
+        Returns:
+            bool: True if URL was added, False if skipped
+        """
+        if self._is_url_allowed(url):
+            url_list.append(url)
+            return True
+        else:
+            warnings.warn(f"Skipping disallowed {url_type} URL: {url}")
+            return False
+
+    def _create_web_loader(self, url_or_urls: Union[str, List[str]]) -> WebBaseLoader:
+        """Create a new WebBaseLoader instance for the given URL(s).
+
+        This ensures each operation gets its own isolated WebBaseLoader.
+        """
+        return WebBaseLoader(
+            web_path=url_or_urls,
+            continue_on_failure=self.continue_on_failure,
+            show_progress=self.show_progress,
+        )

     def _is_sitemap_index(self, soup: BeautifulSoup) -> bool:
         """Check if the soup contains a sitemap index."""
@@ -71,19 +134,30 @@ def _is_sitemap_index(self, soup: BeautifulSoup) -> bool:
     def _extract_sitemap_urls(self, soup: BeautifulSoup) -> List[str]:
         """Extract sitemap URLs from a sitemap index."""
         sitemap_tags = soup.find_all("sitemap")
-        urls = []
+        urls: List[str] = []
         for sitemap in sitemap_tags:
             loc = sitemap.find("loc")
             if loc and loc.text:
-                urls.append(loc.text)
+                self._safe_add_url(urls, loc.text, "sitemap")
         return urls

     def _process_sitemap(
-        self, soup: BeautifulSoup, processed_urls: Optional[Set[str]] = None
+        self,
+        soup: BeautifulSoup,
+        processed_urls: Set[str],
+        web_loader: Optional[WebBaseLoader] = None,
     ) -> List[str]:
-        """Process a sitemap, handling both direct content URLs and sitemap indexes."""
-        if processed_urls is None:
-            processed_urls = set()
+        """Process a sitemap, handling both direct content URLs and sitemap indexes.
+
+        Args:
+            soup: The BeautifulSoup object of the sitemap
+            processed_urls: Set of already processed URLs to avoid cycles
+            web_loader: WebBaseLoader instance to reuse for all requests,
+                created if None
+        """
+        # Create a loader if not provided
+        if web_loader is None:
+            web_loader = self._create_web_loader(self.start_url)

         # If it's a sitemap index, recursively process each sitemap URL
         if self._is_sitemap_index(soup):
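
For context on what the sitemap-index handling consumes, here is a made-up minimal sitemap index (a sketch only; parsing with the "xml" parser assumes lxml is installed, and the expected output follows from _safe_add_url skipping the off-domain entry):

    from bs4 import BeautifulSoup

    from langchain_community.document_loaders import GitbookLoader

    sitemap_index = """<?xml version="1.0" encoding="UTF-8"?>
    <sitemapindex xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
      <sitemap><loc>https://docs.example.io/sitemap-pages.xml</loc></sitemap>
      <sitemap><loc>https://elsewhere.example.com/sitemap.xml</loc></sitemap>
    </sitemapindex>"""

    soup = BeautifulSoup(sitemap_index, "xml")
    loader = GitbookLoader("https://docs.example.io")

    # The off-domain <loc> should be skipped with a "Skipping disallowed sitemap URL"
    # warning, leaving only the docs.example.io entry.
    print(loader._extract_sitemap_urls(soup))
    # ['https://docs.example.io/sitemap-pages.xml']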
@@ -99,13 +173,20 @@ def _process_sitemap(

                 processed_urls.add(sitemap_url)
                 try:
-                    # We need to temporarily set the web_paths to the sitemap URL
-                    original_web_paths = self.web_paths
-                    self.web_paths = [sitemap_url]
-                    sitemap_soup = self.scrape(parser="xml")
+                    # Temporarily override the web_path of the loader
+                    original_web_paths = web_loader.web_paths
+                    web_loader.web_paths = [sitemap_url]
+
+                    # Reuse the same loader for the next sitemap
+                    sitemap_soup = web_loader.scrape(parser="xml")
+
                     # Restore original web_paths
-                    self.web_paths = original_web_paths
-                    content_urls = self._process_sitemap(sitemap_soup, processed_urls)
+                    web_loader.web_paths = original_web_paths
+
+                    # Recursive call with the same loader
+                    content_urls = self._process_sitemap(
+                        sitemap_soup, processed_urls, web_loader
+                    )
                     all_content_urls.extend(content_urls)
                 except Exception as e:
                     if self.continue_on_failure:
@@ -122,28 +203,49 @@ async def _aprocess_sitemap(
         self,
         soup: BeautifulSoup,
         base_url: str,
-        processed_urls: Optional[Set[str]] = None,
+        processed_urls: Set[str],
+        web_loader: Optional[WebBaseLoader] = None,
     ) -> List[str]:
-        """Async version of _process_sitemap."""
-        if processed_urls is None:
-            processed_urls = set()
+        """Async version of _process_sitemap.
+
+        Args:
+            soup: The BeautifulSoup object of the sitemap
+            base_url: The base URL for relative paths
+            processed_urls: Set of already processed URLs to avoid cycles
+            web_loader: WebBaseLoader instance to reuse for all requests,
+                created if None
+        """
+        # Create a loader if not provided
+        if web_loader is None:
+            web_loader = self._create_web_loader(self.start_url)

         # If it's a sitemap index, recursively process each sitemap URL
         if self._is_sitemap_index(soup):
             sitemap_urls = self._extract_sitemap_urls(soup)
             all_content_urls = []

-            # Use base class's ascrape_all for efficient parallel fetching
-            soups = await self.ascrape_all(
-                [url for url in sitemap_urls if url not in processed_urls], parser="xml"
-            )
-            for sitemap_url, sitemap_soup in zip(
-                [url for url in sitemap_urls if url not in processed_urls], soups
-            ):
+            # Filter out already processed URLs
+            new_urls = [url for url in sitemap_urls if url not in processed_urls]
+
+            if not new_urls:
+                return []
+
+            # Update the web_paths of the loader to fetch all sitemaps at once
+            original_web_paths = web_loader.web_paths
+            web_loader.web_paths = new_urls
+
+            # Use the same WebBaseLoader's ascrape_all for efficient parallel fetching
+            soups = await web_loader.ascrape_all(new_urls, parser="xml")
+
+            # Restore original web_paths
+            web_loader.web_paths = original_web_paths
+
+            for sitemap_url, sitemap_soup in zip(new_urls, soups):
                 processed_urls.add(sitemap_url)
                 try:
+                    # Recursive call with the same loader
                     content_urls = await self._aprocess_sitemap(
-                        sitemap_soup, base_url, processed_urls
+                        sitemap_soup, base_url, processed_urls, web_loader
                     )
                     all_content_urls.extend(content_urls)
                 except Exception as e:
@@ -159,53 +261,84 @@ async def _aprocess_sitemap(

     def lazy_load(self) -> Iterator[Document]:
         """Fetch text from one single GitBook page or recursively from sitemap."""
-        if self.load_all_paths:
-            # Get initial sitemap
-            soup_info = self.scrape()
+        if not self.load_all_paths:
+            # Simple case: load a single page
+            temp_loader = self._create_web_loader(self.web_page)
+            soup = temp_loader.scrape()
+            doc = self._get_document(soup, self.web_page)
+            if doc:
+                yield doc
+        else:
+            # Get initial sitemap using the recursive method
+            temp_loader = self._create_web_loader(self.start_url)
+            soup_info = temp_loader.scrape(parser="xml")

             # Process sitemap(s) recursively to get all content URLs
-            relative_paths = self._process_sitemap(soup_info)
+            processed_urls: Set[str] = set()
+            relative_paths = self._process_sitemap(soup_info, processed_urls)
+
             if not relative_paths and self.show_progress:
-                warnings.warn(
-                    f"No content URLs found in sitemap at {self.web_paths[0]}"
-                )
+                warnings.warn(f"No content URLs found in sitemap at {self.start_url}")
+
+            # Build full URLs from relative paths
+            urls: List[str] = []
+            for url in relative_paths:
+                # URLs are now already absolute from _get_paths
+                self._safe_add_url(urls, url, "content")

-            urls = [urljoin(self.base_url, path) for path in relative_paths]
+            if not urls:
+                return

-            # Use base class's scrape_all to efficiently fetch all pages
-            soup_infos = self.scrape_all(urls)
+            # Create a loader for content pages
+            content_loader = self._create_web_loader(urls)
+
+            # Use WebBaseLoader to fetch all pages
+            soup_infos = content_loader.scrape_all(urls)

             for soup_info, url in zip(soup_infos, urls):
                 doc = self._get_document(soup_info, url)
                 if doc:
                     yield doc
-        else:
-            # Use base class functionality directly for single page
-            for doc in super().lazy_load():
-                yield doc

     async def alazy_load(self) -> AsyncIterator[Document]:
         """Asynchronously fetch text from GitBook page(s)."""
         if not self.load_all_paths:
-            # For single page case, use the parent class implementation
-            async for doc in super().alazy_load():
+            # Simple case: load a single page asynchronously
+            temp_loader = self._create_web_loader(self.web_page)
+            soups = await temp_loader.ascrape_all([self.web_page])
+            soup_info = soups[0]
+            doc = self._get_document(soup_info, self.web_page)
+            if doc:
                 yield doc
         else:
-            # Fetch initial sitemap using base class's functionality
-            soups = await self.ascrape_all(self.web_paths, parser="xml")
+            # Get initial sitemap - web_loader will be created in _aprocess_sitemap
+            temp_loader = self._create_web_loader(self.start_url)
+            soups = await temp_loader.ascrape_all([self.start_url], parser="xml")
             soup_info = soups[0]

             # Process sitemap(s) recursively to get all content URLs
-            relative_paths = await self._aprocess_sitemap(soup_info, self.base_url)
+            processed_urls: Set[str] = set()
+            relative_paths = await self._aprocess_sitemap(
+                soup_info, self.base_url, processed_urls
+            )
+
             if not relative_paths and self.show_progress:
-                warnings.warn(
-                    f"No content URLs found in sitemap at {self.web_paths[0]}"
-                )
+                warnings.warn(f"No content URLs found in sitemap at {self.start_url}")
+
+            # Build full URLs from relative paths
+            urls: List[str] = []
+            for url in relative_paths:
+                # URLs are now already absolute from _get_paths
+                self._safe_add_url(urls, url, "content")
+
+            if not urls:
+                return

-            urls = [urljoin(self.base_url, path) for path in relative_paths]
+            # Create a loader for content pages
+            content_loader = self._create_web_loader(urls)

-            # Use base class's ascrape_all for efficient parallel fetching
-            soup_infos = await self.ascrape_all(urls)
+            # Use WebBaseLoader's ascrape_all for efficient parallel fetching
+            soup_infos = await content_loader.ascrape_all(urls)

             for soup_info, url in zip(soup_infos, urls):
                 maybe_doc = self._get_document(soup_info, url)
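
The async path mirrors the sync one: each call builds fresh WebBaseLoader instances, fetches the sitemap(s), then pulls the allowed content URLs in parallel via ascrape_all. A minimal driver sketch (placeholder URL):

    import asyncio

    from langchain_community.document_loaders import GitbookLoader


    async def main() -> None:
        loader = GitbookLoader("https://docs.example.io", load_all_paths=True)
        async for doc in loader.alazy_load():
            print(doc.metadata["source"])


    asyncio.run(main())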
@@ -222,9 +355,15 @@ def _get_document(
         content = page_content_raw.get_text(separator="\n").strip()
         title_if_exists = page_content_raw.find("h1")
         title = title_if_exists.text if title_if_exists else ""
-        metadata = {"source": custom_url or self.web_path, "title": title}
+        metadata = {"source": custom_url or self.web_page, "title": title}
         return Document(page_content=content, metadata=metadata)

     def _get_paths(self, soup: Any) -> List[str]:
-        """Fetch all relative paths in the sitemap."""
-        return [urlparse(loc.text).path for loc in soup.find_all("loc")]
+        """Fetch all URLs in the sitemap."""
+        urls = []
+        for loc in soup.find_all("loc"):
+            if loc.text:
+                # Instead of extracting just the path, keep the full URL
+                # to preserve domain information
+                urls.append(loc.text)
+        return urls
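
Because every load path now builds its own WebBaseLoader through _create_web_loader instead of mutating state inherited from WebBaseLoader, concurrent loads no longer race on a shared web_paths. A rough sketch of what that enables (placeholder site URLs; GitbookLoader.load() comes from BaseLoader and simply materializes lazy_load()):

    from concurrent.futures import ThreadPoolExecutor
    from typing import List

    from langchain_core.documents import Document

    from langchain_community.document_loaders import GitbookLoader


    def load_site(url: str) -> List[Document]:
        # Each call builds its own isolated WebBaseLoader instances internally,
        # so parallel loads of different sites do not interfere.
        return GitbookLoader(url, load_all_paths=True).load()


    sites = ["https://docs.example.io", "https://handbook.example.org"]
    with ThreadPoolExecutor(max_workers=2) as pool:
        results = list(pool.map(load_site, sites))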
