Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
author = "Bradley Steinfeld, Sam Prokopchuk, James Reeve"

# The full version, including alpha/beta/rc tags
release = "0.20.3"
release = "0.20.4"


# -- General configuration ---------------------------------------------------
Expand Down
63 changes: 54 additions & 9 deletions skillsnetwork/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,27 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
raise Exception(f"Failed to read dataset at {url}") from None


def _verify_files_dont_exist(paths: Iterable[Union[str, Path]]) -> None:
def _verify_files_dont_exist(
paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
) -> None:
"""
Verifies all paths in 'paths' don't exist.
:param paths: A iterable of strs or pathlib.Paths.
:param remove_if_exist=False: Removes file at path if they already exist.
:returns: None
:raises FileExistsError: On the first path found that already exists.
"""
for path in paths:
if Path(path).exists():
raise FileExistsError(f"Error: File '{path}' already exists.")
path = Path(path)
if path.exists():
if remove_if_exist:
if path.is_symlink():
realpath = path.resolve()
path.unlink(realpath)
else:
shutil.rmtree(path)
else:
raise FileExistsError(f"Error: File '{path}' already exists.")


def _is_file_to_symlink(path: Path) -> bool:
Expand Down Expand Up @@ -188,7 +199,9 @@ async def read(url: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bytes:
return b"".join([chunk async for chunk in _get_chunks(url, chunk_size)])


async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> None:
async def prepare(
url: str, path: Optional[str] = None, verbose: bool = True, overwrite: bool = False
) -> None:
"""
Prepares a dataset for learners. Downloads a dataset from the given url,
decompresses it if necessary. If not using jupyterlite, will extract to
Expand All @@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->

:param url: The URL to download the dataset from.
:param path: The path the dataset will be available at. Current working directory by default.
:param verbose=True: Prints saved path if True.
:param overwrite=False: Overwrites any existing files at destination if they exist.
:raise InvalidURLException: When URL is invalid.
:raise FileExistsError: it raises this when a file to be symlinked already exists.
:raise ValueError: When requested path is in /tmp, or cannot be saved to path.
Expand Down Expand Up @@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
path / child.name
for child in map(Path, tf.getnames())
if len(child.parents) == 1 and _is_file_to_symlink(child)
]
],
overwrite,
) # Only check if top-level fileobject
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
pbar.set_description(f"Extracting {filename}")
Expand All @@ -253,15 +269,16 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
path / child.name
for child in map(Path, zf.namelist())
if len(child.parents) == 1 and _is_file_to_symlink(child)
]
],
overwrite,
)
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
pbar.set_description(f"Extracting {filename}")
for member in pbar:
zf.extract(member=member, path=extract_dir)
tmp_download_file.unlink()
else:
_verify_files_dont_exist([path / filename])
_verify_files_dont_exist([path / filename], overwrite)
shutil.move(tmp_download_file, extract_dir / filename)

# If in jupyterlite environment, the extract_dir = path, so the files are already there.
Expand All @@ -274,8 +291,36 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
print(f"Saved to '{relpath(path.resolve())}'")


if _is_jupyterlite():
tqdm.monitor_interval = 0
def setup() -> None:
if _is_jupyterlite():
tqdm.monitor_interval = 0

try:
import sys # pyright: ignore

ipython = get_ipython()

def hide_traceback(
exc_tuple=None,
filename=None,
tb_offset=None,
exception_only=False,
running_compiled_code=False,
):
etype, value, tb = sys.exc_info()
value.__cause__ = None # suppress chained exceptions
return ipython._showtraceback(
etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
)

ipython.showtraceback = hide_traceback

except NameError:
pass


setup()


# For backwards compatibility
download_dataset = download
Expand Down