Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 55 additions & 22 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
Expand Down Expand Up @@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:
return Path(bundle_dir)


@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4")
@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.5")
def download(
name: str | None = None,
version: str | None = None,
Expand Down Expand Up @@ -375,8 +376,9 @@ def download(
)


@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("return_state_dict", since="1.3", removed="1.5")
def load(
name: str,
model: torch.nn.Module | None = None,
Expand All @@ -395,8 +397,10 @@ def load(
workflow_name: str | BundleWorkflow | None = None,
args_file: str | None = None,
copy_model_args: dict | None = None,
return_state_dict: bool = True,
net_override: dict | None = None,
net_name: str | None = None,
**net_override: Any,
**net_kwargs: Any,
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
"""
Load model weights or TorchScript module of a bundle.
Expand Down Expand Up @@ -441,7 +445,12 @@ def load(
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
net_override: id-value pairs to override the parameters in the network of the bundle.
return_state_dict: whether to return the state dict. If True, return the state dict; otherwise a corresponding network
from `_workflow.network_def` will be instantiated and the loaded weights copied into it.
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
net_name: if not `None`, a corresponding network will be instantiated and the loaded weights copied into it.
This argument only works when loading weights.
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.

Returns:
1. If `load_ts_module` is `False` and `model` is `None`,
Expand All @@ -452,9 +461,15 @@ def load(
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.
4. If `return_state_dict` is True, return the model weights; this is only used for compatibility
when `model` and `net_name` are both `None`.

"""
if return_state_dict and (model is not None or net_name is not None):
warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.")

bundle_dir_ = _process_bundle_dir(bundle_dir)
net_override = {} if net_override is None else net_override
copy_model_args = {} if copy_model_args is None else copy_model_args

if device is None:
Expand All @@ -466,7 +481,7 @@ def load(
if remove_prefix:
name = _remove_ngc_prefix(name, prefix=remove_prefix)
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path) or model is None:
if not os.path.exists(full_path):
download(
name=name,
version=version,
Expand All @@ -477,34 +492,52 @@ def load(
progress=progress,
args_file=args_file,
)
train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
if train_config_file.is_file():
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
_workflow = create_workflow(
workflow_name=workflow_name,
args_file=args_file,
config_file=str(train_config_file),
workflow_type=workflow_type,
**_net_override,
)
else:
_workflow = None

# loading with `torch.jit.load`
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))

if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
model_dict = get_state_dict(model_dict)

if model is None and _workflow is None:
if return_state_dict:
return model_dict
model = _workflow.network_def if model is None else model
model.to(device)

copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args)
_workflow = None
if model is None and net_name is None:
bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
if bundle_config_file.is_file():
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
_workflow = create_workflow(
workflow_name=workflow_name,
args_file=args_file,
config_file=str(bundle_config_file),
workflow_type=workflow_type,
**_net_override,
)
else:
warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.")
return model_dict
if _workflow is not None:
if not hasattr(_workflow, "network_def"):
warnings.warn("No available network definition in the bundle, return state dict instead.")
return model_dict
else:
model = _workflow.network_def
elif net_name is not None:
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate() # type: ignore

model.to(device) # type: ignore

copy_model_state(
dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args # type: ignore
)

return model


Expand Down
13 changes: 12 additions & 1 deletion monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class BundleWorkflow(ABC):
or "infer", "inference", "eval", "evaluation" for an inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for an inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.

"""

Expand All @@ -56,7 +60,8 @@ class BundleWorkflow(ABC):
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
def __init__(self, workflow_type: str | None = None):
def __init__(self, workflow_type: str | None = None, workflow: str | None = None):
workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is None:
self.properties = copy(MetaProperties)
self.workflow_type = None
Expand Down Expand Up @@ -198,6 +203,10 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for an inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for an inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``

Expand All @@ -221,8 +230,10 @@ def __init__(
final_id: str = "finalize",
tracking: str | dict | None = None,
workflow_type: str | None = None,
workflow: str | None = None,
**override: Any,
) -> None:
workflow_type = workflow if workflow is not None else workflow_type
super().__init__(workflow_type=workflow_type)
if config_file is not None:
_config_files = ensure_tuple(config_file)
Expand Down
7 changes: 6 additions & 1 deletion tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))

model = load(
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
name=bundle_name,
source="ngc",
version=version,
bundle_dir=tempdir,
remove_prefix=remove_prefix,
return_state_dict=False,
)
assert_allclose(
model.state_dict()[TESTCASE_WEIGHTS["key"]],
Expand Down
31 changes: 28 additions & 3 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
source="github",
progress=False,
device=device,
return_state_dict=True,
)

# prepare network
Expand Down Expand Up @@ -174,21 +175,44 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
bundle_dir=tempdir,
progress=False,
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

# test compatibility with return_state_dict=True.
model_3 = load(
name=bundle_name,
model_file=model_file,
bundle_dir=tempdir,
progress=False,
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
**net_args,
)
model_3.eval()
output_3 = model_3.forward(input_tensor)
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand([TEST_CASE_7])
@skip_if_quick
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
with skip_if_downloading_fails():
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
# load weights
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)
model = load(
name=bundle_name,
bundle_dir=tempdir,
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
)

# prepare data and test
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
Expand All @@ -209,7 +233,8 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
source="monaihosting",
progress=False,
device=device,
**net_override,
return_state_dict=False,
net_override=net_override,
)

# prepare data and test
Expand Down