Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
054a2b8
Added TRTWrapper
borisfom Aug 5, 2024
3ab9c83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
4ec1d3b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 5, 2024
fe71030
Addressing code review comments, adding docustrings, cleanup
borisfom Aug 5, 2024
6a9727f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
29d9725
Added TRT 10.3RC to Dockerfile
borisfom Aug 6, 2024
5b8b4f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
f31d6dd
Workaround for format check
borisfom Aug 6, 2024
9303c32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
c1d0b19
More format check workarounds
borisfom Aug 6, 2024
63c4b70
More format check workarounds
borisfom Aug 6, 2024
9a3d6a6
More format check workarounds
borisfom Aug 6, 2024
8bf0300
Using optional exports for trt_utils
borisfom Aug 6, 2024
c03e49b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
39c94c2
Fixing lint errors
borisfom Aug 6, 2024
35dffcc
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
9d867a7
Format fixed
borisfom Aug 6, 2024
6e2733a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
848a42d
Fixing flake errors
borisfom Aug 6, 2024
9ade6af
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
cf2c3b1
Fixing CI
borisfom Aug 6, 2024
e8b51f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
ddb5bc8
Fixed mypy, Engine refactor
borisfom Aug 7, 2024
79014d7
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 7, 2024
511081f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 7, 2024
b188237
Merged cast_utils, copyrights fixed.
borisfom Aug 8, 2024
60cdd74
Added unit test
borisfom Aug 8, 2024
778a44a
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
0ab5d26
TRTWrapper moved to networks
borisfom Aug 9, 2024
a948bfb
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
3a72c76
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
7d449f5
Refactored TRTWrapper args
borisfom Aug 10, 2024
6846fd4
Added docstring for precision
borisfom Aug 10, 2024
d598590
Fixed comments, reordered args
borisfom Aug 11, 2024
9109d3f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 12, 2024
517c111
Reduced test assert accuracy
borisfom Aug 12, 2024
4739756
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 14, 2024
ed0d93d
Addressing code review comments
borisfom Aug 14, 2024
2ec8e53
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 15, 2024
fdcf118
Added Torch-TRT option, cleaned up engine save method
borisfom Aug 15, 2024
1009dc5
Added trt_wrap adapter
borisfom Aug 16, 2024
763f769
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 16, 2024
fd679c0
Refined trt_wrap
borisfom Aug 16, 2024
dc13b52
Used tempdir for ONNX
borisfom Aug 17, 2024
779de92
Refactored trt wrapper, added trt handler
borisfom Aug 18, 2024
6504dc9
Adjusted refactor for use in config
borisfom Aug 18, 2024
c1be72c
Added fold constant threshold param
borisfom Aug 20, 2024
0f16b8b
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
5c495b6
Logger refactoring
borisfom Aug 20, 2024
5d1ebc2
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
48b85ce
Addressing code review comments
borisfom Aug 22, 2024
1244c49
Added multiple submodules option to trt_wrap
borisfom Aug 22, 2024
a603f13
Added polygraphy to more places, torch-tensorrt option debugging
borisfom Aug 23, 2024
f5be0cc
Renamed trt_wrap -> trt_compile
borisfom Aug 23, 2024
b96ebb4
Reformatted for CI
borisfom Aug 23, 2024
73be701
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 23, 2024
85140e2
Fixed alias issue
borisfom Aug 23, 2024
fa4c182
Fixed base in Dockerfile
borisfom Aug 23, 2024
78a3ef3
Fixed CI test failures
borisfom Aug 23, 2024
267c125
Addressed code review comments
borisfom Aug 23, 2024
9adc035
Added dictionary return option
borisfom Aug 26, 2024
a017fcd
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 26, 2024
7f1c0c1
Fixed return_dict issue
borisfom Aug 26, 2024
a242a64
Implemented https://github.com/Project-MONAI/MONAI/issues/8044
borisfom Aug 26, 2024
5afc912
Generalizing merge logic, adding test case and doc
borisfom Aug 27, 2024
ceff018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
e294968
Addressing code review comments
borisfom Aug 27, 2024
55cf7fa
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 27, 2024
b6d9179
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
652448a
doc build fixed
borisfom Aug 28, 2024
b793eb2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
5c4f63a
Fixed formatting
borisfom Aug 28, 2024
dd91183
Fixed formatting
borisfom Aug 28, 2024
c41cb5a
Updated base container to 24.08
borisfom Aug 28, 2024
7e440fc
Renaming trt_wrapper -> trt_compiler, adding TRT handler test
borisfom Aug 28, 2024
329d024
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
84de860
fixing CI error
borisfom Aug 28, 2024
875d1a8
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
b84cec4
Fixing min test error, addressing comments
borisfom Aug 28, 2024
9481d9f
optional propagation of dynamo arg fixed, onnx_graphsurgeon package a…
borisfom Aug 28, 2024
6a11581
add vista test cases
yiheng-wang-nv Aug 28, 2024
6e8bd6b
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 28, 2024
3d221cb
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 28, 2024
792a721
Code review input addressed
borisfom Aug 29, 2024
73ac717
Fixed torch-tensorrt path of trt_compile, added test
borisfom Aug 29, 2024
ea879f2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 29, 2024
1e7e76d
Fixing tests
borisfom Aug 29, 2024
47e676e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
dd4d2d6
Merge branch 'dev' into trt-wrappers
binliunls Aug 29, 2024
6b47a8b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/*
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
ENV PATH=${PATH}:/opt/tools
ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
WORKDIR /opt/monai
42 changes: 42 additions & 0 deletions docs/source/config_syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Content:
- [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
- [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
- [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
- [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
- [The command line interface](#the-command-line-interface)
- [Recommendations](#recommendations)

Expand Down Expand Up @@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

## Multiple config files

_Description:_ Multiple config files may be specified on the command line.
The content of those config files is being merged. When same keys are specifiled in more than one config file,
the value associated with the key is being overridden, in the order config files are specified.
If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`.
The value types for the merged contents must match and be both of `dict` or both of `list` type.
`dict` values will be merged via update(), `list` values - concatenated via extend().
Here's an example. In this case, "amp" value will be overridden by extra_config.json.
`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`:

config.json:
```json
{
"amp": "$True"
"imports": [
"$import torch"
],
"preprocessing": {
"_target_": "Compose",
"transforms": [
"$@t1",
"$@t2"
]
},
}
```

extra_config.json:
```json
{
"amp": "$False"
"+imports": [
"$from monai.networks import trt_compile"
],
"+preprocessing#transforms": [
"$@t3"
]
}
```

## The command line interface

In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
Expand Down
8 changes: 5 additions & 3 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, optional_import
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
Expand Down Expand Up @@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
if isinstance(files, str) and not Path(files).is_file() and "," in files:
files = files.split(",")
for i in ensure_tuple(files):
for k, v in (cls.load_config_file(i, **kwargs)).items():
parser[k] = v
config_dict = cls.load_config_file(i, **kwargs)
for k, v in config_dict.items():
merge_kv(parser, k, v)

return parser.get() # type: ignore

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
Expand Down Expand Up @@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
if isinstance(v, dict) and isinstance(args_.get(k), dict):
args_[k] = update_kwargs(args_[k], ignore_none, **v)
else:
args_[k] = v
merge_kv(args_, k, v)
return args_


Expand Down
36 changes: 35 additions & 1 deletion monai/bundle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import json
import os
import warnings
import zipfile
from typing import Any

Expand All @@ -21,12 +22,21 @@

yaml, _ = optional_import("yaml")

__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
__all__ = [
"ID_REF_KEY",
"ID_SEP_KEY",
"EXPR_KEY",
"MACRO_KEY",
"MERGE_KEY",
"DEFAULT_MLFLOW_SETTINGS",
"DEFAULT_EXP_MGMT_SETTINGS",
]

ID_REF_KEY = "@" # start of a reference to a ConfigItem
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
EXPR_KEY = "$" # start of a ConfigExpression
MACRO_KEY = "%" # start of a macro of a config
MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.

_conf_values = get_config_values()

Expand Down Expand Up @@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
parser.read_config(f=cdata)

return parser


def merge_kv(args: dict | Any, k: str, v: Any) -> None:
"""
Update the `args` dict-like object with the key/value pair `k` and `v`.
"""
if k.startswith(MERGE_KEY):
"""
Both values associated with `+`-prefixed key pair must be of `dict` or `list` type.
`dict` values will be merged, `list` values - concatenated.
"""
id = k[1:]
if id in args:
if isinstance(v, dict) and isinstance(args[id], dict):
args[id].update(v)
elif isinstance(v, list) and isinstance(args[id], list):
args[id].extend(v)
else:
raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}."))
else:
warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
args[id] = v
else:
args[k] = v
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
from .trt_handler import TrtHandler
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .validation_handler import ValidationHandler
61 changes: 61 additions & 0 deletions monai/handlers/trt_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.networks import trt_compile
from monai.utils import min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class TrtHandler:
"""
TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
Usage example::
handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
handler.attach(engine)
engine.run()
"""

def __init__(self, model, base_path, args=None, submodule=None):
"""
Args:
base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
args: passed to trt_compile(). See trt_compile() for details.
submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
"""
self.model = model
self.base_path = base_path
self.args = args
self.submodule = submodule

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
self.logger = engine.logger
engine.add_event_handler(Events.STARTED, self)

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)
2 changes: 2 additions & 0 deletions monai/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

from __future__ import annotations

from .trt_compiler import trt_compile
from .utils import (
add_casts_around_norms,
convert_to_onnx,
convert_to_torchscript,
convert_to_trt,
Expand Down
8 changes: 4 additions & 4 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape):
)

def forward(self, x_in):
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
self._check_input_size(x_in.shape[2:])
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
Expand Down Expand Up @@ -1046,14 +1046,14 @@ def __init__(

def proj_out(self, x, normalize=False):
if normalize:
x_shape = x.size()
x_shape = x.shape
# Force trace() to generate a constant by casting to int
ch = int(x_shape[1])
if len(x_shape) == 5:
n, ch, d, h, w = x_shape
x = rearrange(x, "n c d h w -> n d h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n d h w c -> n c d h w")
elif len(x_shape) == 4:
n, ch, h, w = x_shape
x = rearrange(x, "n c h w -> n h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n h w c -> n c h w")
Expand Down
Loading