Skip to content

Commit 28643c8

Browse files
committed
feat: return model_paths in fetch_models_task
Signed-off-by: Anupam Kumar <[email protected]>
1 parent 22cbea9 commit 28643c8

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,29 +120,35 @@ def __map_app_static_folders(fast_api_app: FastAPI):
120120
fast_api_app.mount(f"/{mnt_dir}", staticfiles.StaticFiles(directory=mnt_dir_path), name=mnt_dir)
121121

122122

123-
def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_start_value: int) -> None:
123+
def fetch_models_task(
124+
nc: NextcloudApp, models: dict[str, dict], progress_init_start_value: int
125+
) -> dict[str, str | None]:
124126
"""Use for cases when you want to define custom `/init` but still need to easy download models."""
127+
model_paths = {}
125128
if models:
126129
current_progress = progress_init_start_value
127130
percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
128131
for model in models:
129132
if model.startswith(("http://", "https://")):
130-
__fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model])
133+
model_paths[model] = __fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model])
131134
else:
132-
__fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, models[model])
135+
model_paths[model] = __fetch_model_as_snapshot(
136+
current_progress, percent_for_each, nc, model, models[model]
137+
)
133138
current_progress += percent_for_each
134139
nc.set_init_status(100)
140+
return model_paths
135141

136142

137143
def __fetch_model_as_file(
138144
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
139-
) -> None:
145+
) -> str | None:
140146
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
141147
try:
142148
with httpx.stream("GET", model_path, follow_redirects=True) as response:
143149
if not response.is_success:
144150
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.")
145-
return
151+
return None
146152
downloaded_size = 0
147153
linked_etag = ""
148154
for each_history in response.history:
@@ -163,7 +169,7 @@ def __fetch_model_as_file(
163169
sha256_hash.update(byte_block)
164170
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
165171
nc.set_init_status(min(current_progress + progress_for_task, 99))
166-
return
172+
return None
167173

168174
with builtins.open(result_path, "wb") as file:
169175
last_progress = current_progress
@@ -174,13 +180,15 @@ def __fetch_model_as_file(
174180
if new_progress != last_progress:
175181
nc.set_init_status(new_progress)
176182
last_progress = new_progress
183+
184+
return result_path
177185
except Exception as e: # noqa pylint: disable=broad-exception-caught
178186
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}")
179187

180188

181189
def __fetch_model_as_snapshot(
182190
current_progress: int, progress_for_task, nc: NextcloudApp, mode_name: str, download_options: dict
183-
) -> None:
191+
) -> str:
184192
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
185193
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
186194

@@ -191,7 +199,9 @@ def display(self, msg=None, pos=None):
191199

192200
workers = download_options.pop("max_workers", 2)
193201
cache = download_options.pop("cache_dir", persistent_storage())
194-
snapshot_download(mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache)
202+
return snapshot_download(
203+
mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache
204+
)
195205

196206

197207
def __nc_app(request: HTTPConnection) -> dict:

0 commit comments

Comments
 (0)