Skip to content

Commit 6232dab

Browse files
feat(app): better errors when scanning models with picklescan
Differentiate between malware detection and scan error.
1 parent 28d3356 commit 6232dab

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

invokeai/app/services/model_load/model_load_default.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ def load_model_from_path(
8585

8686
def torch_load_file(checkpoint: Path) -> AnyModel:
8787
scan_result = scan_file_path(checkpoint)
88-
if scan_result.infected_files != 0 or scan_result.scan_err:
89-
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
88+
if scan_result.infected_files != 0:
89+
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
90+
if scan_result.scan_err:
91+
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
92+
9093
result = torch_load(checkpoint, map_location="cpu")
9194
return result
9295

invokeai/backend/model_manager/probe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,10 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
483483
"""
484484
# scan model
485485
scan_result = scan_file_path(checkpoint)
486-
if scan_result.infected_files != 0 or scan_result.scan_err:
487-
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
486+
if scan_result.infected_files != 0:
487+
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
488+
if scan_result.scan_err:
489+
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
488490

489491

490492
# Probing utilities

invokeai/backend/model_manager/util/model_util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
5858
else:
5959
if scan:
6060
scan_result = scan_file_path(path)
61-
if scan_result.infected_files != 0 or scan_result.scan_err:
62-
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
61+
if scan_result.infected_files != 0:
62+
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
63+
if scan_result.scan_err:
64+
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
65+
6366
checkpoint = torch.load(path, map_location=torch.device("meta"))
6467
return checkpoint
6568

0 commit comments

Comments
 (0)