|
12 | 12 | partition, |
13 | 13 | get_submod_inputs, |
14 | 14 | ) |
| 15 | +from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs |
15 | 16 | from torch_tensorrt.dynamo.backend.conversion import convert_module |
16 | 17 |
|
17 | 18 | from torch._dynamo.backends.common import fake_tensor_unsupported |
|
25 | 26 | @td.register_backend(name="torch_tensorrt") |
26 | 27 | @fake_tensor_unsupported |
27 | 28 | def torch_tensorrt_backend( |
28 | | - gm: torch.fx.GraphModule, |
29 | | - sample_inputs: Sequence[torch.Tensor], |
30 | | - settings: CompilationSettings = CompilationSettings(), |
| 29 | + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs |
31 | 30 | ): |
32 | 31 | DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend |
33 | 32 |
|
34 | | - return DEFAULT_BACKEND(gm, sample_inputs, settings=settings) |
| 33 | + return DEFAULT_BACKEND(gm, sample_inputs, **kwargs) |
35 | 34 |
|
36 | 35 |
|
37 | 36 | @td.register_backend(name="aot_torch_tensorrt_aten") |
38 | 37 | @fake_tensor_unsupported |
39 | 38 | def aot_torch_tensorrt_aten_backend( |
40 | | - gm: torch.fx.GraphModule, |
41 | | - sample_inputs: Sequence[torch.Tensor], |
42 | | - settings: CompilationSettings = CompilationSettings(), |
| 39 | + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs |
43 | 40 | ): |
| 41 | + settings = parse_dynamo_kwargs(kwargs) |
| 42 | + |
| 43 | + # Enable debug/verbose mode if requested |
| 44 | + if settings.debug: |
| 45 | + logger.setLevel(logging.DEBUG) |
| 46 | + |
44 | 47 | custom_backend = partial( |
45 | 48 | _pretraced_backend, |
46 | 49 | settings=settings, |
|
0 commit comments