diff --git a/skillsnetwork/core.py b/skillsnetwork/core.py
index b513e18..4b02748 100644
--- a/skillsnetwork/core.py
+++ b/skillsnetwork/core.py
@@ -89,7 +89,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
                 pbar.update(len(value))
             pbar.close()
         except JsException:
-            raise Exception(f"Failed to read dataset at {url}") from None
+            raise Exception(f"Failed to read dataset at '{url}'.") from None
     else:
         import requests  # pyright: ignore
         from requests.exceptions import ConnectionError  # pyright: ignore
@@ -99,7 +99,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
             # If requests.get fails, it will return readable error
             if response.status_code >= 400:
                 raise Exception(
-                    f"received status code {response.status_code} from {url}"
+                    f"received status code {response.status_code} from '{url}'."
                 )
             pbar = tqdm(
                 miniters=1,
@@ -111,28 +111,36 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
                 pbar.update(len(chunk))
             pbar.close()
         except ConnectionError:
-            raise Exception(f"Failed to read dataset at {url}") from None
+            raise Exception(f"Failed to read dataset at '{url}'.") from None
+
+
+def _rmrf(path: Path) -> None:
+    if path.is_dir():
+        shutil.rmtree(path)
+    else:
+        path.unlink()
 
 
 def _verify_files_dont_exist(
-    paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
+    paths: Iterable[Path], remove_if_exist: bool = False
 ) -> None:
     """
     Verifies all paths in 'paths' don't exist.
-    :param paths: A iterable of strs or pathlib.Paths.
-    :param remove_if_exist=False: Removes file at path if they already exist.
+    :param paths: An iterable of pathlib.Path objects.
+    :param remove_if_exist=False: Remove the file at each path if it already exists.
     :returns: None
-    :raises FileExistsError: On the first path found that already exists.
+    :raises FileExistsError: On the first path that already exists, if remove_if_exist is False.
     """
     for path in paths:
-        path = Path(path)
-        if path.exists():
+        # Could be a broken symlink, in which case path.exists() is False
+        if path.exists() or path.is_symlink():
             if remove_if_exist:
-                if path.is_symlink():
-                    realpath = path.resolve()
-                    path.unlink(realpath)
-                else:
-                    shutil.rmtree(path)
+                while path.is_symlink():
+                    temp = path.readlink()
+                    path.unlink(missing_ok=True)
+                    path = temp
+                if path.exists():
+                    _rmrf(path)
             else:
                 raise FileExistsError(f"Error: File '{path}' already exists.")
@@ -224,14 +232,13 @@ async def prepare(
     path = Path.cwd() if path is None else Path(path)
     # Check if path contains /tmp
    if Path("/tmp") in path.parents:
-        raise ValueError("path must not be in /tmp")
+        raise ValueError("path must not be in /tmp.")
     elif path.is_file():
-        raise ValueError("Datasets must be prepared to directories, not files")
+        raise ValueError("Datasets must be prepared to directories, not files.")
     # Create the target path if it doesn't exist yet
     path.mkdir(exist_ok=True)
     # For avoiding collisions with any other files the user may have downloaded to /tmp/
-
     dname = f"skills-network-{hash(url)}"
     # The file to extract data to. If not jupyterlite, to be symlinked to as well
     extract_dir = path if _is_jupyterlite() else Path(f"/tmp/{dname}")
@@ -247,44 +254,52 @@
         shutil.rmtree(extract_dir)
     extract_dir.mkdir()
 
-    if tarfile.is_tarfile(tmp_download_file):
-        with tarfile.open(tmp_download_file) as tf:
-            _verify_files_dont_exist(
-                [
-                    path / child.name
-                    for child in map(Path, tf.getnames())
-                    if len(child.parents) == 1 and _is_file_to_symlink(child)
-                ],
-                overwrite,
-            )  # Only check if top-level fileobject
-            pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
-            pbar.set_description(f"Extracting {filename}")
-            for member in pbar:
-                tf.extract(member=member, path=extract_dir)
-        tmp_download_file.unlink()
-    elif zipfile.is_zipfile(tmp_download_file):
-        with zipfile.ZipFile(tmp_download_file) as zf:
-            _verify_files_dont_exist(
-                [
-                    path / child.name
-                    for child in map(Path, zf.namelist())
-                    if len(child.parents) == 1 and _is_file_to_symlink(child)
-                ],
-                overwrite,
-            )
-            pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
-            pbar.set_description(f"Extracting {filename}")
-            for member in pbar:
-                zf.extract(member=member, path=extract_dir)
-        tmp_download_file.unlink()
-    else:
-        _verify_files_dont_exist([path / filename], overwrite)
-        shutil.move(tmp_download_file, extract_dir / filename)
+    try:
+        if tarfile.is_tarfile(tmp_download_file):
+            with tarfile.open(tmp_download_file) as tf:
+                _verify_files_dont_exist(
+                    [
+                        path / child.name
+                        for child in map(Path, tf.getnames())
+                        if len(child.parents) == 1 and _is_file_to_symlink(child)
+                    ],  # Only check top-level file objects
+                    remove_if_exist=overwrite,
+                )
+                pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
+                pbar.set_description(f"Extracting {filename}")
+                for member in pbar:
+                    tf.extract(member=member, path=extract_dir)
+            tmp_download_file.unlink()
+        elif zipfile.is_zipfile(tmp_download_file):
+            with zipfile.ZipFile(tmp_download_file) as zf:
+                _verify_files_dont_exist(
+                    [
+                        path / child.name
+                        for child in map(Path, zf.namelist())
+                        if len(child.parents) == 1 and _is_file_to_symlink(child)
+                    ],  # Only check top-level file objects
+                    remove_if_exist=overwrite,
+                )
+                pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
+                pbar.set_description(f"Extracting {filename}")
+                for member in pbar:
+                    zf.extract(member=member, path=extract_dir)
+            tmp_download_file.unlink()
+        else:
+            _verify_files_dont_exist([path / filename], remove_if_exist=overwrite)
+            shutil.move(tmp_download_file, extract_dir / filename)
+    except FileExistsError as e:
+        raise FileExistsError(
+            str(e)
+            + "\nIf you want to overwrite any existing files, use prepare(..., overwrite=True)."
+        ) from None
 
     # If in jupyterlite environment, the extract_dir = path, so the files are already there.
     if not _is_jupyterlite():
         # If not in jupyterlite environment, symlink top-level file objects in extract_dir
         for child in filter(_is_file_to_symlink, extract_dir.iterdir()):
+            if (path / child.name).is_symlink() and overwrite:
+                (path / child.name).unlink()
             (path / child.name).symlink_to(child, target_is_directory=child.is_dir())
 
     if verbose:
@@ -295,29 +310,6 @@ def setup() -> None:
     if _is_jupyterlite():
         tqdm.monitor_interval = 0
 
-    try:
-        import sys  # pyright: ignore
-
-        ipython = get_ipython()
-
-        def hide_traceback(
-            exc_tuple=None,
-            filename=None,
-            tb_offset=None,
-            exception_only=False,
-            running_compiled_code=False,
-        ):
-            etype, value, tb = sys.exc_info()
-            value.__cause__ = None  # suppress chained exceptions
-            return ipython._showtraceback(
-                etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
-            )
-
-        ipython.showtraceback = hide_traceback
-
-    except NameError:
-        pass
-
 
 setup()
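A note on the new removal logic above, before the test changes. The old overwrite path called path.unlink(realpath), but Path.unlink() accepts no target argument (only missing_ok), and shutil.rmtree() raises on plain files; the new code instead walks the symlink chain with readlink(), unlinking each link, then hands the real path to _rmrf, which picks rmtree for directories and unlink for files. A minimal standalone sketch of the same idea (the helper name is ours, not part of the library, and it assumes absolute link targets, as with the /tmp paths used here):

    from pathlib import Path
    import shutil

    def remove_through_symlinks(p: Path) -> None:
        # Unlink every link in a (possibly broken) chain of symlinks.
        while p.is_symlink():
            target = p.readlink()      # raw target; may not exist
            p.unlink(missing_ok=True)
            p = target
        # Remove whatever the chain pointed at, if anything survived.
        if p.is_dir():
            shutil.rmtree(p)
        elif p.exists():
            p.unlink()

Checking p.is_symlink() before p.exists() matters because exists() follows links and returns False for a broken symlink, which is exactly the stale state the overwrite flag is meant to clean up.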
"1.txt") as f: + assert "I am the first test file" in f.read() + httpserver.clear() + + with open("tests/test.zip", "rb") as expected_data: + httpserver.expect_request(url).respond_with_data(expected_data) + await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True) + assert os.path.isdir(expected_directory) + with open(expected_directory / "1.txt") as f: + assert "I am the first test file" in f.read() + expected_directory.unlink()