def _verify_files_dont_exist(
    paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
) -> None:
    """
    Verifies all paths in 'paths' don't exist, optionally clearing them.

    :param paths: An iterable of strs or pathlib.Paths.
    :param remove_if_exist=False: If True, whatever already exists at a path
        (file, symlink or directory tree) is removed instead of raising.
    :returns: None
    :raises FileExistsError: On the first path found that already exists,
        when 'remove_if_exist' is False.
    """
    for candidate in paths:
        path = Path(candidate)
        # Path.exists() follows symlinks, so a dangling symlink reports False
        # even though the name is taken; check is_symlink() as well.
        if not (path.is_symlink() or path.exists()):
            continue
        if not remove_if_exist:
            raise FileExistsError(f"Error: File '{path}' already exists.")
        if path.is_dir() and not path.is_symlink():
            # Only a real directory needs a recursive delete.
            shutil.rmtree(path)
        else:
            # Regular files and symlinks (including links to directories and
            # dangling links): remove the entry itself, never the target.
            path.unlink()
Downloads a dataset from the given url, decompresses it if necessary. If not using jupyterlite, will extract to @@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> :param url: The URL to download the dataset from. :param path: The path the dataset will be available at. Current working directory by default. + :param verbose=True: Prints saved path if True. + :param overwrite=False: Overwrites any existing files at destination if they exist. :raise InvalidURLException: When URL is invalid. :raise FileExistsError: it raises this when a file to be symlinked already exists. :raise ValueError: When requested path is in /tmp, or cannot be saved to path. @@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> path / child.name for child in map(Path, tf.getnames()) if len(child.parents) == 1 and _is_file_to_symlink(child) - ] + ], + overwrite, ) # Only check if top-level fileobject pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers())) pbar.set_description(f"Extracting {filename}") @@ -253,7 +269,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> path / child.name for child in map(Path, zf.namelist()) if len(child.parents) == 1 and _is_file_to_symlink(child) - ] + ], + overwrite, ) pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist())) pbar.set_description(f"Extracting {filename}") @@ -261,7 +278,7 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> zf.extract(member=member, path=extract_dir) tmp_download_file.unlink() else: - _verify_files_dont_exist([path / filename]) + _verify_files_dont_exist([path / filename], overwrite) shutil.move(tmp_download_file, extract_dir / filename) # If in jupyterlite environment, the extract_dir = path, so the files are already there. 
def setup() -> None:
    """
    Applies import-time environment tweaks.

    * In a jupyterlite environment, sets ``tqdm.monitor_interval`` to 0.
    * When ``get_ipython`` is defined (i.e. running under IPython), replaces
      the shell's ``showtraceback`` with a version that displays only the
      exception type and message, without the stack trace or chained cause.

    When ``get_ipython`` is not defined, error reporting is left untouched.

    :returns: None
    """
    if _is_jupyterlite():
        tqdm.monitor_interval = 0

    import sys  # pyright: ignore

    # Keep the try body to the single lookup that signals "not in IPython",
    # so an unrelated NameError raised further down is not silently swallowed.
    try:
        ipython = get_ipython()  # pyright: ignore
    except NameError:
        return
    if ipython is None:
        return

    default_showtraceback = ipython.showtraceback

    def hide_traceback(
        exc_tuple=None,
        filename=None,
        tb_offset=None,
        exception_only=False,
        running_compiled_code=False,
    ):
        # Honour an explicitly supplied (type, value, traceback) triple and
        # only fall back to the exception currently being handled.
        etype, value, _ = sys.exc_info() if exc_tuple is None else exc_tuple
        if value is None:
            # No exception to condense; defer to the handler that was
            # installed before this one rather than failing on None.
            return default_showtraceback(
                exc_tuple=exc_tuple,
                filename=filename,
                tb_offset=tb_offset,
                exception_only=exception_only,
                running_compiled_code=running_compiled_code,
            )
        value.__cause__ = None  # suppress chained exceptions
        return ipython._showtraceback(
            etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
        )

    ipython.showtraceback = hide_traceback


setup()

# For backwards compatibility
download_dataset = download