diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 52a9b47c12..253738b434 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata)
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy resource_allocation_strategy)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
           target_platform,
           hardware_compatible,
           requires_output_allocator,
-          serialized_metadata) {}
+          serialized_metadata,
+          resource_allocation_strategy) {}

 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           Platform(serialized_info[TARGET_PLATFORM_IDX]),
           static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
           static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
-          serialized_info[SERIALIZED_METADATA_IDX]) {}
+          serialized_info[SERIALIZED_METADATA_IDX],
+          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}

 TRTEngine::TRTEngine(
     const std::string& mod_name,
@@ -94,7 +97,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata) {
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy resource_allocation_strategy) {
   TORCHTRT_CHECK(
       is_supported_on_current_platform(target_platform),
       "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +128,14 @@ TRTEngine::TRTEngine(
     cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
   }

-  exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  this->resource_allocation_strategy = resource_allocation_strategy;
+  LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
+  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
+    this->exec_ctx =
+        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+  } else {
+    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  }
   TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
@@ -393,6 +404,7 @@ std::string TRTEngine::to_str() const {
   ss << "  Device: " << device_info << std::endl;
   ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   ss << "  Target Platform: " << target_platform << std::endl;
+  ss << "  Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
"Dynamic" : "Static") << std::endl; // clang-format on return ss.str(); } @@ -436,7 +448,8 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]), std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), - std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX])); + std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), + std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])); } std::vector TRTEngine::serialize() { @@ -459,6 +472,7 @@ std::vector TRTEngine::serialize() { serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0"; serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata; serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); + serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; return serialized_info; } @@ -467,6 +481,20 @@ void TRTEngine::reset_captured_graph() { cudagraph.reset(); } +void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { + if (new_strategy != this->resource_allocation_strategy) { + this->resource_allocation_strategy = new_strategy; + if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { + LOG_DEBUG("Setting resource allocation strategy to dynamic"); + this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + } else { + LOG_DEBUG("Setting resource allocation strategy to static"); + this->exec_ctx = make_trt( + cuda_engine->createExecutionContext()); + } + } +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 15d723ce4e..2ed07f0bcc 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -29,7 +29,8 @@ using FlattenedState = std::tuple< std::tuple, // HW compatibility std::tuple, // requires_output_allocator std::tuple, // serialized metadata - std::tuple>; // Platform + std::tuple, // Platform + std::tuple>; // Resource Allocation Strategy struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator { }; struct TRTEngine : torch::CustomClassHolder { + // Resource Allocation Strategy + typedef enum { kStatic = 0, kDynamic } ResourceAllocationStrategy; // Each engine needs it's own runtime object std::shared_ptr rt; std::shared_ptr cuda_engine; @@ -128,7 +131,9 @@ struct TRTEngine : torch::CustomClassHolder { const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = ""); + const std::string& serialized_metadata = "", + const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = + TRTEngine::ResourceAllocationStrategy::kStatic); TRTEngine(std::vector serialized_info); @@ -141,7 +146,9 @@ struct TRTEngine : torch::CustomClassHolder { const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = ""); + const std::string& serialized_metadata = "", + 
+      const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
+          TRTEngine::ResourceAllocationStrategy::kStatic);

   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
@@ -200,6 +207,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string cuda_graph_debug_path;
   std::mutex mu;
   std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
+  ResourceAllocationStrategy resource_allocation_strategy = kStatic;
+  void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
+  ResourceAllocationStrategy get_resource_allocation_strategy();
 };

 } // namespace runtime
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 64b111750f..d36cc98c80 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -201,6 +201,12 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
 }

 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  torch::Tensor dynamic_workspace;
+  if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+    dynamic_workspace = torch::empty(compiled_engine->cuda_engine->getDeviceMemorySizeV2(), {torch::kCUDA});
+    compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr());
+  }
+
   auto run_standard_execution = [&]() {
     bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
     bool shape_changed = _validate_shapes(inputs, compiled_engine);
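At the TensorRT level, the static path keeps an engine-owned activation workspace alive for the lifetime of the execution context, while the dynamic path above sizes and binds a caller-owned buffer on every call. Below is a rough Python-bindings sketch of the same idea; the names `ExecutionContextAllocationStrategy.USER_MANAGED`, `device_memory_size_v2`, and the `device_memory` setter are assumptions mirroring the C++ calls in this patch, not part of it:

```python
import tensorrt as trt
import torch

# `engine` stands in for a deserialized trt.ICudaEngine. A user-managed
# context is created without its own activation workspace, mirroring
# nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED above.
context = engine.create_execution_context(
    trt.ExecutionContextAllocationStrategy.USER_MANAGED
)

def infer_once(bind_and_run):
    # Allocate the activation workspace for this call only; PyTorch's caching
    # allocator reclaims it once `workspace` goes out of scope, which is what
    # keeps idle engines from pinning GPU memory between calls.
    workspace = torch.empty(
        engine.device_memory_size_v2, dtype=torch.uint8, device="cuda"
    )
    context.device_memory = workspace.data_ptr()  # mirrors setDeviceMemory()
    return bind_and_run(context)
```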
"Dynamic" : "Static")); TRTEngine::verify_serialization_fmt(serialized_info); return c10::make_intrusive(serialized_info); }); @@ -135,6 +143,7 @@ TORCH_LIBRARY(tensorrt, m) { m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; }); m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); + m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 894df55bfe..233b4bb274 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -38,6 +38,7 @@ typedef enum { SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, + RESOURCE_ALLOCATION_STRATEGY_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -45,6 +46,9 @@ std::string base64_encode(const std::string& in); std::string base64_decode(const std::string& in); std::string serialize_bindings(const std::vector& bindings); +std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy); +TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str); + c10::optional get_most_compatible_device( const RTDevice& target_device, const RTDevice& curr_device = RTDevice(), diff --git a/examples/dynamo/dynamic_memory_allocation.py b/examples/dynamo/dynamic_memory_allocation.py new file mode 100644 index 0000000000..d609a83045 --- /dev/null +++ b/examples/dynamo/dynamic_memory_allocation.py @@ -0,0 +1,42 @@ +# %% +import numpy as np +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +import time +import gc + +np.random.seed(5) +torch.manual_seed(5) +inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] + +settings = { + "ir": "dynamo", + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "immutable_weights": False, + "lazy_engine_init": True, + "dynamically_allocate_resources": True + +} + +model = models.resnet152(pretrained=True).eval().to("cuda") +compiled_module = torch_trt.compile(model, inputs=inputs, **settings) +print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3) +compiled_module(*inputs) + +time.sleep(30) +with torch_trt.dynamo.runtime.ResourceAllocationStrategy(compiled_module, dynamically_allocate_resources=False): + print( + "Memory used (GB):", + (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3, + ) + compiled_module(*inputs) + gc.collect() + torch.cuda.empty_cache() + time.sleep(30) + print( + "Memory used (GB):", + (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3, + ) + compiled_module(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 74cab980c4..cebbb88273 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -103,6 +103,7 @@ def cross_compile_for_windows( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, **kwargs: Any, ) -> 
diff --git a/examples/dynamo/dynamic_memory_allocation.py b/examples/dynamo/dynamic_memory_allocation.py
new file mode 100644
index 0000000000..d609a83045
--- /dev/null
+++ b/examples/dynamo/dynamic_memory_allocation.py
@@ -0,0 +1,47 @@
+# %%
+import gc
+import time
+
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+
+np.random.seed(5)
+torch.manual_seed(5)
+inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+
+settings = {
+    "ir": "dynamo",
+    "use_python_runtime": False,
+    "enabled_precisions": {torch.float32},
+    "immutable_weights": False,
+    "lazy_engine_init": True,
+    "dynamically_allocate_resources": True,
+}
+
+model = models.resnet152(pretrained=True).eval().to("cuda")
+compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
+print(
+    "Memory used (GB):",
+    (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+)
+compiled_module(*inputs)
+
+time.sleep(30)
+with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
+    compiled_module, dynamically_allocate_resources=False
+):
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+    compiled_module(*inputs)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(30)
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+    compiled_module(*inputs)
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 74cab980c4..cebbb88273 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -103,6 +103,7 @@ def cross_compile_for_windows(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -177,6 +178,7 @@ def cross_compile_for_windows(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -340,6 +342,7 @@ def cross_compile_for_windows(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "dynamically_allocate_resources": dynamically_allocate_resources,
     }

     # disable the following settings is not supported for cross compilation for windows feature
@@ -440,6 +443,7 @@ def compile(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -517,6 +521,7 @@ def compile(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
+        dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -690,6 +695,7 @@ def compile(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
+        "dynamically_allocate_resources": dynamically_allocate_resources,
     }

     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index de970ecd81..b58d0a528b 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,7 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+DYNAMICALLY_ALLOCATE_RESOURCES = False

 if platform.system() == "Linux":
     import pwd
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index d8f6809eae..e9f5174e2c 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
     DLA_SRAM_SIZE,
+    DYNAMICALLY_ALLOCATE_RESOURCES,
     DRYRUN,
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -97,6 +98,8 @@ class CompilationSettings:
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
+        dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +143,7 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES

     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
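The new flag follows the standard settings plumbing: a `_defaults` constant, a keyword argument on `compile()`/`cross_compile_for_windows()`, and a `CompilationSettings` field that the runtime modules read later. A minimal check of that wiring (illustrative only):

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

# Defaults to False, i.e. static allocation, the pre-existing behavior ...
assert not CompilationSettings().dynamically_allocate_resources

# ... and can be overridden per compilation, exactly as
# torch_trt.compile(..., dynamically_allocate_resources=True) does.
settings = CompilationSettings(dynamically_allocate_resources=True)
assert settings.dynamically_allocate_resources
```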
+ """ + + def __init__( + self, + compiled_module: torch.nn.Module, + dynamically_allocate_resources: bool = True + ) -> None: + super(ResourceAllocationStrategy, self).__init__() + self.compiled_module = compiled_module + self.dynamically_allocate_resources = dynamically_allocate_resources + + def __enter__(self) -> None: + print("Entering resource allocator context") + for name, submodule in self.compiled_module.named_modules(): + if "_run_on_acc" in name: + submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources) + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + for name, submodule in self.compiled_module.named_modules(): + if "_run_on_acc" in name: + submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..9c279396d7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -50,7 +50,10 @@ REQUIRES_OUTPUT_ALLOCATOR_IDX = ( torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX() ) # 9 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 10 + RESOURCE_ALLOCATION_STRATEGY_IDX = ( + torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() + ) # 10 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 11 @for_all_methods(needs_torch_tensorrt_runtime) @@ -139,6 +142,7 @@ def __init__( self.serialized_engine = serialized_engine self.engine = None self.requires_output_allocator = requires_output_allocator + self.dynamically_allocate_resources = settings.dynamically_allocate_resources if ( serialized_engine @@ -184,6 +188,11 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str( int(self.requires_output_allocator) ) + print(f"PROVIDED RESOURCE ALLOCATION STRATEGY: {self.dynamically_allocate_resources}") + engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( + int(self.dynamically_allocate_resources) + ) + print(engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX]) return engine_info @@ -212,6 +221,10 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: def _reset_captured_graph(self) -> None: self.engine.reset_captured_graph() + def use_dynamically_allocated_resources(self, dynamically_allocate_resources: bool = False) -> None: + self.dynamically_allocate_resources = dynamically_allocate_resources + self.engine.use_dynamically_allocated_resources(self.dynamically_allocate_resources) + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. 
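The context manager only touches TRT engine submodules; graph-break segments that fall back to eager PyTorch are left alone. The `_run_on_acc` filter relies on the naming convention the dynamo partitioner uses for accelerated subgraphs, which a compiled module exposes directly (illustrative; `compiled_module` is a placeholder and the exact names depend on the partitioning):

```python
# TRT engines live in submodules whose names contain "_run_on_acc";
# everything else in the compiled graph runs in eager PyTorch.
trt_parts = [
    name for name, _ in compiled_module.named_modules() if "_run_on_acc" in name
]
print(trt_parts)  # e.g. ["_run_on_acc_0", "_run_on_acc_1"]
```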
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 95f1581881..9c279396d7 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -50,7 +50,10 @@
     REQUIRES_OUTPUT_ALLOCATOR_IDX = (
         torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX()
     )  # 9
-    SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 10
+    RESOURCE_ALLOCATION_STRATEGY_IDX = (
+        torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX()
+    )  # 10
+    SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 11


 @for_all_methods(needs_torch_tensorrt_runtime)
@@ -139,6 +142,7 @@ def __init__(
         self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
+        self.dynamically_allocate_resources = settings.dynamically_allocate_resources

         if (
             serialized_engine
@@ -184,6 +188,9 @@ def _pack_engine_info(self) -> List[str | bytes]:
         engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(
             int(self.requires_output_allocator)
         )
+        engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(
+            int(self.dynamically_allocate_resources)
+        )

         return engine_info

@@ -212,6 +221,10 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
     def _reset_captured_graph(self) -> None:
         self.engine.reset_captured_graph()

+    def use_dynamically_allocated_resources(self, dynamically_allocate_resources: bool = False) -> None:
+        self.dynamically_allocate_resources = dynamically_allocate_resources
+        self.engine.use_dynamically_allocated_resources(self.dynamically_allocate_resources)
+
     def setup_engine(self) -> None:
         """
         Setup engine for a module which has deferred engine setup.
diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py
index de47d942e9..0eb66b24b0 100644
--- a/py/torch_tensorrt/dynamo/runtime/__init__.py
+++ b/py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -2,6 +2,9 @@
 from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (  # noqa: F401
     PythonTorchTensorRTModule,
 )
+from torch_tensorrt.dynamo.runtime._ResourceAllocator import (  # noqa: F401
+    ResourceAllocationStrategy,
+)
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (  # noqa: F401
     TorchTensorRTModule,
 )
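Beyond the context manager, `use_dynamically_allocated_resources` can also be called directly on an individual TRT submodule to flip its allocation mode for the rest of the session (sketch; `trt_submodule` stands in for any `_run_on_acc` submodule of a compiled graph):

```python
# Flip one engine to per-call (dynamic) workspaces ...
trt_submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=True)
out = trt_submodule(*sample_inputs)

# ... and back to an engine-owned (static) workspace.
trt_submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=False)
```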
diff --git a/uv.lock b/uv.lock
index 18b5f3d7ed..79621781be 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2526,7 +2526,7 @@ sdist = { url = "https://pypi.nvidia.com/tensorrt/tensorrt-10.3.0.tar.gz", hash

 [[package]]
 name = "tensorrt"
-version = "10.11.0.33"
+version = "10.12.0.36"
 source = { registry = "https://pypi.nvidia.com/" }
 resolution-markers = [
     "python_full_version >= '3.12' and platform_machine != 'aarch64' and 'tegra' not in platform_release and sys_platform == 'linux'",
     "python_full_version < '3.10' and platform_machine != 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows'",
 ]
 dependencies = [
-    { name = "tensorrt-cu12", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
+    { name = "tensorrt-cu12", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
 ]
-sdist = { url = "https://pypi.nvidia.com/tensorrt/tensorrt-10.11.0.33.tar.gz", hash = "sha256:a3d6048f86e11ea5202d473646194d3be866c0c8d578ac0b7eeb91d923f65d0b" }
+sdist = { url = "https://pypi.nvidia.com/tensorrt/tensorrt-10.12.0.36.tar.gz", hash = "sha256:b246a830c26713e097b73151917e101cfb81aa0e7274c3c3b4c1f9f8b886be2e" }

 [[package]]
 name = "tensorrt-cu12"
@@ -2573,7 +2573,7 @@ sdist = { url = "https://pypi.nvidia.com/tensorrt-cu12/tensorrt-cu12-10.3.0.tar.gz", hash

 [[package]]
 name = "tensorrt-cu12"
-version = "10.11.0.33"
+version = "10.12.0.36"
 source = { registry = "https://pypi.nvidia.com/" }
 resolution-markers = [
     "python_full_version >= '3.12' and platform_machine != 'aarch64' and 'tegra' not in platform_release and sys_platform == 'linux'",
     "python_full_version < '3.10' and platform_machine != 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows'",
 ]
 dependencies = [
-    { name = "tensorrt-cu12-bindings", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
-    { name = "tensorrt-cu12-libs", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
+    { name = "tensorrt-cu12-bindings", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
+    { name = "tensorrt-cu12-libs", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" },
 ]
-sdist = { url = "https://pypi.nvidia.com/tensorrt-cu12/tensorrt_cu12-10.11.0.33.tar.gz", hash = "sha256:7e29c8b16771c025320035ba9609c2a074767d9a8c05696a30c9d5c0fdfb37df" }
+sdist = { url = "https://pypi.nvidia.com/tensorrt-cu12/tensorrt_cu12-10.12.0.36.tar.gz", hash = "sha256:aedeee0195c042592ac6b0536b19bc8cdbb1a548f35e09d24fbe78e1c76217c5" }
"sha256:e7b7a5b80174f8c4ddd8a63bc9fa97cad3320409eafad79428bc2b1e15884068" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp311-none-manylinux_2_31_aarch64.whl", hash = "sha256:492e3e91d7c1083bff1f7c15fdd8f5fb09a782dcfa6d1d0f8d9034b2e3b38cad" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp312-none-manylinux_2_28_x86_64.whl", hash = "sha256:a8f374f6d752ce4b0d4a8303d29c3ba9904eb29da0dc95b4db6b75c501997e4a" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp312-none-manylinux_2_31_aarch64.whl", hash = "sha256:6a3b768cea69b153ed0c2eb50130d150406d5c1498fdb0bf6c8a1be160137a6a" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp313-none-manylinux_2_28_x86_64.whl", hash = "sha256:1ceda290d1ed79b6107b0eb29eeb178f569d007c1506b72caae8248975d57662" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp313-none-manylinux_2_31_aarch64.whl", hash = "sha256:3c27e0d6e36a3b1f06e1dc8b735e34f04f5b8aac3e7d9b21762b8264496e825f" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp39-none-manylinux_2_28_x86_64.whl", hash = "sha256:9a801886f389b75f92e69fc6be40308392ec7746dbf4de4a2b76585d591960f0" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.11.0.33-cp39-none-manylinux_2_31_aarch64.whl", hash = "sha256:42e9b3cc2e3c6bcc0785c9c96b4dd25cd7043ff95e4fd09c8d35331f63ce9634" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp310-none-manylinux_2_28_x86_64.whl", hash = "sha256:7ecdb6fc2555caed7d4fbbd8158ed7ced64e230c125484f62a5369c40dcc70e5" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp310-none-manylinux_2_31_aarch64.whl", hash = "sha256:d8548ab5976ca5c91279c68ee77f4c892e03460709cfa3fbd2a22aa8123cb731" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp311-none-manylinux_2_28_x86_64.whl", hash = "sha256:58cf45605bb330e86f8ad49bc8997ed68cfdf5b09da229534fb7f84aa3fe5bf4" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp311-none-manylinux_2_31_aarch64.whl", hash = "sha256:ae0866a89caaeada1c16776de85413a523f78f53b1fd83f1b903c39eed264d82" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp312-none-manylinux_2_28_x86_64.whl", hash = "sha256:fb3a2ce96c7472a46bbee2030ce6a54fd6a32deda401c1c67d9de057550e0171" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp312-none-manylinux_2_31_aarch64.whl", hash = "sha256:f5128b8b2a379e65c09745ba97df58abf3a418cbfd6508d37f76121d9bdd3bc8" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp313-none-manylinux_2_28_x86_64.whl", hash = "sha256:0eb8d3e41279b1d0d329b85372d5d720c8d2ff1228f6273142d717b44d75935b" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp313-none-manylinux_2_31_aarch64.whl", hash = "sha256:a850992cad842340e6fed41fe74f529064064ff61881d50ef5a2be1816526f9b" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp39-none-manylinux_2_28_x86_64.whl", hash = "sha256:986cb86202ef9541279b59d4e254743aff43bae1def87d14dd06e02369107c8b" }, + { url = 
"https://pypi.nvidia.com/tensorrt-cu12-bindings/tensorrt_cu12_bindings-10.12.0.36-cp39-none-manylinux_2_31_aarch64.whl", hash = "sha256:c5b86638ae5e3a2101755d469ac2ce831d4bdece1d20fa2bd546c05c554b5952" }, ] [[package]] @@ -2677,7 +2677,7 @@ dependencies = [ [[package]] name = "tensorrt-cu12-libs" -version = "10.11.0.33" +version = "10.12.0.36" source = { registry = "https://pypi.nvidia.com/" } resolution-markers = [ "python_full_version >= '3.12' and platform_machine != 'aarch64' and 'tegra' not in platform_release and sys_platform == 'linux'", @@ -2706,8 +2706,8 @@ dependencies = [ { name = "nvidia-cuda-runtime-cu12", version = "12.9.79", source = { registry = "https://download.pytorch.org/whl/nightly/cu129" }, marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://pypi.nvidia.com/tensorrt-cu12-libs/tensorrt_cu12_libs-10.11.0.33-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:81ace8d3284fdbef0804c444a4d7555343ee079370e79c93cb328c7d9b08f968" }, - { url = "https://pypi.nvidia.com/tensorrt-cu12-libs/tensorrt_cu12_libs-10.11.0.33-py2.py3-none-manylinux_2_31_aarch64.whl", hash = "sha256:b6846dbc32d717a5031d9757f16293dd9e25de8a1c4aae8c00701d52351ef173" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-libs/tensorrt_cu12_libs-10.12.0.36-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:3910039e1d49de0edfdc8bf273e40ad4b85a9d57c7c383fe0e22f75417df9610" }, + { url = "https://pypi.nvidia.com/tensorrt-cu12-libs/tensorrt_cu12_libs-10.12.0.36-py2.py3-none-manylinux_2_31_aarch64.whl", hash = "sha256:1c117effa7318b65508457e9a11e67941859c8e5c346b59fd0090f66be28f2f4" }, ] [[package]] @@ -2886,13 +2886,13 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and 'tegra' not in platform_release and sys_platform == 'linux') or (python_full_version >= '3.10' and 'tegra' not in platform_release and sys_platform == 'windows')" }, { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "tensorrt", version = "10.3.0", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'linux') or (platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows')" }, - { name = "tensorrt", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, + { name = "tensorrt", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, { name = "tensorrt-cu12", version = "10.3.0", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'linux') or (platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows')" }, - { name = "tensorrt-cu12", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = 
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, + { name = "tensorrt-cu12", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, { name = "tensorrt-cu12-bindings", version = "10.3.0", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'linux') or (platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows')" }, - { name = "tensorrt-cu12-bindings", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, + { name = "tensorrt-cu12-bindings", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, { name = "tensorrt-cu12-libs", version = "10.3.0", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'linux') or (platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows')" }, - { name = "tensorrt-cu12-libs", version = "10.11.0.33", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, + { name = "tensorrt-cu12-libs", version = "10.12.0.36", source = { registry = "https://pypi.nvidia.com/" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'linux') or (platform_machine == 'aarch64' and 'tegra' in platform_release and sys_platform == 'windows')" }, { name = "torch", version = "2.9.0.dev20250701+cu129", source = { registry = "https://download.pytorch.org/whl/nightly/cu129" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and sys_platform == 'windows') or ('tegra' not in platform_release and sys_platform == 'linux') or ('tegra' not in platform_release and sys_platform == 'windows')" }, { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 
'windows'" }, @@ -2940,13 +2940,13 @@ requires-dist = [ { name = "numpy", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = "<2.0.0" }, { name = "nvidia-modelopt", extras = ["all"], marker = "extra == 'quantization'", specifier = ">=0.27.1" }, { name = "packaging", specifier = ">=23" }, - { name = "tensorrt", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.11.0,<10.12.0" }, + { name = "tensorrt", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.12.0,<10.13.0" }, { name = "tensorrt", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = ">=10.3.0,<10.4.0" }, - { name = "tensorrt-cu12", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.11.0,<10.12.0" }, + { name = "tensorrt-cu12", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.12.0,<10.13.0" }, { name = "tensorrt-cu12", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = ">=10.3.0,<10.4.0" }, - { name = "tensorrt-cu12-bindings", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.11.0,<10.12.0" }, + { name = "tensorrt-cu12-bindings", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.12.0,<10.13.0" }, { name = "tensorrt-cu12-bindings", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = ">=10.3.0,<10.4.0" }, - { name = "tensorrt-cu12-libs", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.11.0,<10.12.0" }, + { name = "tensorrt-cu12-libs", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=10.12.0,<10.13.0" }, { name = "tensorrt-cu12-libs", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = ">=10.3.0,<10.4.0" }, { name = "torch", marker = "platform_machine != 'aarch64' or 'tegra' not in platform_release", specifier = ">=2.9.0.dev0,<2.10.0", index = "https://download.pytorch.org/whl/nightly/cu129" }, { name = "torch", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release", specifier = ">=2.7.0,<2.8.0" },