Skip to content

Commit 1926973

Browse files
kyteinskybigcat88
andauthored
feat: store model download paths in "path" key (#274)
Models downloaded wrt `models_to_fetch` dict have the output path of the downloaded model files in the "path" key for each model. Signed-off-by: Anupam Kumar <[email protected]> Co-authored-by: Alexander Piskun <[email protected]>
1 parent 1bcd2b9 commit 1926973

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,26 @@ def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_s
127127
percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
128128
for model in models:
129129
if model.startswith(("http://", "https://")):
130-
__fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model])
130+
models[model]["path"] = __fetch_model_as_file(
131+
current_progress, percent_for_each, nc, model, models[model]
132+
)
131133
else:
132-
__fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, models[model])
134+
models[model]["path"] = __fetch_model_as_snapshot(
135+
current_progress, percent_for_each, nc, model, models[model]
136+
)
133137
current_progress += percent_for_each
134138
nc.set_init_status(100)
135139

136140

137141
def __fetch_model_as_file(
138142
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
139-
) -> None:
143+
) -> str | None:
140144
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
141145
try:
142146
with httpx.stream("GET", model_path, follow_redirects=True) as response:
143147
if not response.is_success:
144148
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.")
145-
return
149+
return None
146150
downloaded_size = 0
147151
linked_etag = ""
148152
for each_history in response.history:
@@ -163,7 +167,7 @@ def __fetch_model_as_file(
163167
sha256_hash.update(byte_block)
164168
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
165169
nc.set_init_status(min(current_progress + progress_for_task, 99))
166-
return
170+
return None
167171

168172
with builtins.open(result_path, "wb") as file:
169173
last_progress = current_progress
@@ -174,13 +178,17 @@ def __fetch_model_as_file(
174178
if new_progress != last_progress:
175179
nc.set_init_status(new_progress)
176180
last_progress = new_progress
181+
182+
return result_path
177183
except Exception as e: # noqa pylint: disable=broad-exception-caught
178184
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}")
179185

186+
return None
187+
180188

181189
def __fetch_model_as_snapshot(
182190
current_progress: int, progress_for_task, nc: NextcloudApp, mode_name: str, download_options: dict
183-
) -> None:
191+
) -> str:
184192
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
185193
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
186194

@@ -191,7 +199,9 @@ def display(self, msg=None, pos=None):
191199

192200
workers = download_options.pop("max_workers", 2)
193201
cache = download_options.pop("cache_dir", persistent_storage())
194-
snapshot_download(mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache)
202+
return snapshot_download(
203+
mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache
204+
)
195205

196206

197207
def __nc_app(request: HTTPConnection) -> dict:

0 commit comments

Comments
 (0)