|
16 | 16 | from torch_tensorrt.dynamo._defaults import ( |
17 | 17 | DEBUG, |
18 | 18 | DEVICE, |
| 19 | + DISABLE_TF32, |
| 20 | + DLA_GLOBAL_DRAM_SIZE, |
| 21 | + DLA_LOCAL_DRAM_SIZE, |
| 22 | + DLA_SRAM_SIZE, |
19 | 23 | ENABLE_EXPERIMENTAL_DECOMPOSITIONS, |
| 24 | + ENGINE_CAPABILITY, |
20 | 25 | MAX_AUX_STREAMS, |
21 | 26 | MIN_BLOCK_SIZE, |
| 27 | + NUM_AVG_TIMING_ITERS, |
22 | 28 | OPTIMIZATION_LEVEL, |
23 | 29 | PASS_THROUGH_BUILD_FAILURES, |
24 | 30 | PRECISION, |
| 31 | + REFIT, |
25 | 32 | REQUIRE_FULL_COMPILATION, |
| 33 | + SPARSE_WEIGHTS, |
26 | 34 | TRUNCATE_LONG_AND_DOUBLE, |
27 | 35 | USE_FAST_PARTITIONER, |
28 | 36 | USE_PYTHON_RUNTIME, |
@@ -51,17 +59,18 @@ def compile( |
51 | 59 | inputs: Any, |
52 | 60 | *, |
53 | 61 | device: Optional[Union[Device, torch.device, str]] = DEVICE, |
54 | | - disable_tf32: bool = False, |
55 | | - sparse_weights: bool = False, |
| 62 | + disable_tf32: bool = DISABLE_TF32, |
| 63 | + sparse_weights: bool = SPARSE_WEIGHTS, |
56 | 64 | enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), |
57 | | - refit: bool = False, |
| 65 | + engine_capability: EngineCapability = ENGINE_CAPABILITY, |
| 66 | + refit: bool = REFIT, |
58 | 67 | debug: bool = DEBUG, |
59 | 68 | capability: EngineCapability = EngineCapability.default, |
60 | | - num_avg_timing_iters: int = 1, |
| 69 | + num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, |
61 | 70 | workspace_size: int = WORKSPACE_SIZE, |
62 | | - dla_sram_size: int = 1048576, |
63 | | - dla_local_dram_size: int = 1073741824, |
64 | | - dla_global_dram_size: int = 536870912, |
| 71 | + dla_sram_size: int = DLA_SRAM_SIZE, |
| 72 | + dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, |
| 73 | + dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, |
65 | 74 | calibrator: object = None, |
66 | 75 | truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, |
67 | 76 | require_full_compilation: bool = REQUIRE_FULL_COMPILATION, |
@@ -200,6 +209,13 @@ def compile( |
200 | 209 | "use_fast_partitioner": use_fast_partitioner, |
201 | 210 | "enable_experimental_decompositions": enable_experimental_decompositions, |
202 | 211 | "require_full_compilation": require_full_compilation, |
| 212 | + "disable_tf32": disable_tf32, |
| 213 | + "sparse_weights": sparse_weights, |
| 214 | + "refit": refit, |
| 215 | + "engine_capability": engine_capability, |
| 216 | + "dla_sram_size": dla_sram_size, |
| 217 | + "dla_local_dram_size": dla_local_dram_size, |
| 218 | + "dla_global_dram_size": dla_global_dram_size, |
203 | 219 | } |
204 | 220 |
|
205 | 221 | settings = CompilationSettings(**compilation_options) |
|
0 commit comments