138 changes: 65 additions & 73 deletions skillsnetwork/core.py
@@ -89,7 +89,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
                 pbar.update(len(value))
             pbar.close()
         except JsException:
-            raise Exception(f"Failed to read dataset at {url}") from None
+            raise Exception(f"Failed to read dataset at '{url}'.") from None
     else:
         import requests  # pyright: ignore
         from requests.exceptions import ConnectionError  # pyright: ignore
@@ -99,7 +99,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
                 # If requests.get fails, it will return readable error
                 if response.status_code >= 400:
                     raise Exception(
-                        f"received status code {response.status_code} from {url}"
+                        f"received status code {response.status_code} from '{url}'."
                     )
                 pbar = tqdm(
                     miniters=1,
@@ -111,28 +111,36 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
                     pbar.update(len(chunk))
             pbar.close()
         except ConnectionError:
-            raise Exception(f"Failed to read dataset at {url}") from None
+            raise Exception(f"Failed to read dataset at '{url}'.") from None


+def _rmrf(path: Path) -> None:
+    if path.is_dir():
+        shutil.rmtree(path)
+    else:
+        path.unlink()
+
+
 def _verify_files_dont_exist(
-    paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
+    paths: Iterable[Path], remove_if_exist: bool = False
 ) -> None:
     """
     Verifies all paths in 'paths' don't exist.
-    :param paths: A iterable of strs or pathlib.Paths.
-    :param remove_if_exist=False: Removes file at path if they already exist.
+    :param paths: An iterable of pathlib.Path objects.
+    :param remove_if_exist=False: Remove the file at each path if it already exists.
     :returns: None
-    :raises FileExistsError: On the first path found that already exists.
+    :raises FileExistsError: On the first path found that already exists, if remove_if_exist is False.
     """
     for path in paths:
-        path = Path(path)
-        if path.exists():
+        # Could be a broken symlink => path.exists() is False
+        if path.exists() or path.is_symlink():
             if remove_if_exist:
-                if path.is_symlink():
-                    realpath = path.resolve()
-                    path.unlink(realpath)
-                else:
-                    shutil.rmtree(path)
+                while path.is_symlink():
+                    temp = path.readlink()
+                    path.unlink(missing_ok=True)
+                    path = temp
+                if path.exists():
+                    _rmrf(path)
             else:
                 raise FileExistsError(f"Error: File '{path}' already exists.")

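Context for the hunk above: Path.exists() follows symlinks, so it returns False for a link whose target has been deleted; the new `path.exists() or path.is_symlink()` check also catches those broken links, and the while-loop walks a chain of links with readlink() (Python 3.9+) before removing the final target. A minimal sketch of the behavior being guarded against, using hypothetical paths:

    from pathlib import Path

    target = Path("/tmp/example-target")  # hypothetical paths, illustration only
    link = Path("/tmp/example-link")

    target.touch()
    link.symlink_to(target)
    target.unlink()  # break the link by deleting its target

    print(link.exists())      # False: exists() resolves the now-dangling link
    print(link.is_symlink())  # True: the link itself is still on disk
    link.unlink()             # cleanup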
@@ -224,14 +232,13 @@ async def prepare(
     path = Path.cwd() if path is None else Path(path)
     # Check if path contains /tmp
     if Path("/tmp") in path.parents:
-        raise ValueError("path must not be in /tmp")
+        raise ValueError("path must not be in /tmp.")
     elif path.is_file():
-        raise ValueError("Datasets must be prepared to directories, not files")
+        raise ValueError("Datasets must be prepared to directories, not files.")
     # Create the target path if it doesn't exist yet
     path.mkdir(exist_ok=True)

     # For avoiding collisions with any other files the user may have downloaded to /tmp/
-
     dname = f"skills-network-{hash(url)}"
     # The file to extract data to. If not jupyterlite, to be symlinked to as well
     extract_dir = path if _is_jupyterlite() else Path(f"/tmp/{dname}")
@@ -247,44 +254,52 @@ async def prepare(
         shutil.rmtree(extract_dir)
     extract_dir.mkdir()

-    if tarfile.is_tarfile(tmp_download_file):
-        with tarfile.open(tmp_download_file) as tf:
-            _verify_files_dont_exist(
-                [
-                    path / child.name
-                    for child in map(Path, tf.getnames())
-                    if len(child.parents) == 1 and _is_file_to_symlink(child)
-                ],
-                overwrite,
-            )  # Only check if top-level fileobject
-            pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
-            pbar.set_description(f"Extracting {filename}")
-            for member in pbar:
-                tf.extract(member=member, path=extract_dir)
-        tmp_download_file.unlink()
-    elif zipfile.is_zipfile(tmp_download_file):
-        with zipfile.ZipFile(tmp_download_file) as zf:
-            _verify_files_dont_exist(
-                [
-                    path / child.name
-                    for child in map(Path, zf.namelist())
-                    if len(child.parents) == 1 and _is_file_to_symlink(child)
-                ],
-                overwrite,
-            )
-            pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
-            pbar.set_description(f"Extracting {filename}")
-            for member in pbar:
-                zf.extract(member=member, path=extract_dir)
-        tmp_download_file.unlink()
-    else:
-        _verify_files_dont_exist([path / filename], overwrite)
-        shutil.move(tmp_download_file, extract_dir / filename)
+    try:
+        if tarfile.is_tarfile(tmp_download_file):
+            with tarfile.open(tmp_download_file) as tf:
+                _verify_files_dont_exist(
+                    [
+                        path / child.name
+                        for child in map(Path, tf.getnames())
+                        if len(child.parents) == 1 and _is_file_to_symlink(child)
+                    ],  # Only check if top-level fileobject
+                    remove_if_exist=overwrite,
+                )
+                pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
+                pbar.set_description(f"Extracting {filename}")
+                for member in pbar:
+                    tf.extract(member=member, path=extract_dir)
+            tmp_download_file.unlink()
+        elif zipfile.is_zipfile(tmp_download_file):
+            with zipfile.ZipFile(tmp_download_file) as zf:
+                _verify_files_dont_exist(
+                    [
+                        path / child.name
+                        for child in map(Path, zf.namelist())
+                        if len(child.parents) == 1 and _is_file_to_symlink(child)
+                    ],  # Only check if top-level fileobject
+                    remove_if_exist=overwrite,
+                )
+                pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
+                pbar.set_description(f"Extracting {filename}")
+                for member in pbar:
+                    zf.extract(member=member, path=extract_dir)
+            tmp_download_file.unlink()
+        else:
+            _verify_files_dont_exist([path / filename], remove_if_exist=overwrite)
+            shutil.move(tmp_download_file, extract_dir / filename)
+    except FileExistsError as e:
+        raise FileExistsError(
+            str(e)
+            + "\nIf you want to overwrite any existing files, use prepare(..., overwrite=True)."
+        ) from None

     # If in jupyterlite environment, the extract_dir = path, so the files are already there.
     if not _is_jupyterlite():
         # If not in jupyterlite environment, symlink top-level file objects in extract_dir
         for child in filter(_is_file_to_symlink, extract_dir.iterdir()):
+            if (path / child.name).is_symlink() and overwrite:
+                (path / child.name).unlink()
             (path / child.name).symlink_to(child, target_is_directory=child.is_dir())

     if verbose:
@@ -295,29 +310,6 @@ def setup() -> None:
     if _is_jupyterlite():
         tqdm.monitor_interval = 0

-    try:
-        import sys  # pyright: ignore
-
-        ipython = get_ipython()
-
-        def hide_traceback(
-            exc_tuple=None,
-            filename=None,
-            tb_offset=None,
-            exception_only=False,
-            running_compiled_code=False,
-        ):
-            etype, value, tb = sys.exc_info()
-            value.__cause__ = None  # suppress chained exceptions
-            return ipython._showtraceback(
-                etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
-            )
-
-        ipython.showtraceback = hide_traceback
-
-    except NameError:
-        pass


 setup()

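The try/except added around the extraction branches funnels every FileExistsError through one place, so the message can point users at the overwrite flag. A sketch of the intended call pattern from a notebook; the URL is a placeholder, and prepare_dataset is the public entry point the tests below exercise:

    import skillsnetwork

    # First call downloads and extracts the dataset next to the notebook.
    await skillsnetwork.prepare_dataset("https://example.com/data/test.tar.gz")

    # Repeating the call with the same files present raises FileExistsError
    # with the overwrite hint; overwrite=True replaces the existing files.
    await skillsnetwork.prepare_dataset(
        "https://example.com/data/test.tar.gz", overwrite=True
    )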
73 changes: 73 additions & 0 deletions tests/test_skillsnetwork.py
@@ -134,3 +134,76 @@ async def test_prepare_non_compressed_dataset_with_path(httpserver):
         await skillsnetwork.prepare_dataset(httpserver.url_for(url), path=path)
     assert expected_path.exists()
     expected_path.unlink()
+
+
+@pytest.mark.asyncio
+async def test_prepare_non_compressed_dataset_no_path_with_overwrite(httpserver):
+    url = "/test.csv"
+    expected_path = Path("./test.csv")
+    with open("tests/test.csv", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url))
+    assert expected_path.exists()
+    httpserver.clear()
+    with open("tests/test.csv", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
+    assert expected_path.exists()
+    assert Path(expected_path).stat().st_size == 540
+    expected_path.unlink()
+
+
+@pytest.mark.asyncio
+async def test_prepare_dataset_tar_no_path_with_overwrite(httpserver):
+    url = "/test.tar.gz"
+    expected_directory = Path("test")
+    try:
+        shutil.rmtree(expected_directory)  # clean up any previous test
+    except FileNotFoundError as e:
+        print(e)
+        pass
+
+    with open("tests/test.tar.gz", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url))
+
+    assert os.path.isdir(expected_directory)
+    with open(expected_directory / "1.txt") as f:
+        assert "I am the first test file" in f.read()
+    httpserver.clear()
+
+    with open("tests/test.tar.gz", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
+    assert os.path.isdir(expected_directory)
+    with open(expected_directory / "1.txt") as f:
+        assert "I am the first test file" in f.read()
+    expected_directory.unlink()
+
+
+@pytest.mark.asyncio
+async def test_prepare_dataset_zip_no_path_with_overwrite(httpserver):
+    url = "/test.zip"
+    expected_directory = Path("test")
+    try:
+        shutil.rmtree(expected_directory)  # clean up any previous test
+    except FileNotFoundError as e:
+        print(e)
+        pass
+
+    with open("tests/test.zip", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url))
+
+    assert os.path.isdir(expected_directory)
+    with open(expected_directory / "1.txt") as f:
+        assert "I am the first test file" in f.read()
+    httpserver.clear()
+
+    with open("tests/test.zip", "rb") as expected_data:
+        httpserver.expect_request(url).respond_with_data(expected_data)
+        await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
+    assert os.path.isdir(expected_directory)
+    with open(expected_directory / "1.txt") as f:
+        assert "I am the first test file" in f.read()
+    expected_directory.unlink()
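A note on the cleanup in these tests: outside JupyterLite, prepare_dataset extracts under /tmp and symlinks the top-level entries into the working directory, so ./test is a symlink rather than a real directory, and Path.unlink() is the correct cleanup even though os.path.isdir() reports True. A short illustration with hypothetical paths:

    import shutil
    from pathlib import Path

    extract_dir = Path("/tmp/skills-network-demo/test")  # hypothetical target
    extract_dir.mkdir(parents=True, exist_ok=True)

    link = Path("test")
    link.symlink_to(extract_dir, target_is_directory=True)

    print(link.is_dir())      # True: is_dir() follows the symlink
    print(link.is_symlink())  # True: but it is only a link
    link.unlink()             # removes the link, not the extracted data
    shutil.rmtree(extract_dir.parent)  # clean up the hypothetical target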