Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy& resource_allocation_strategy)
const ResourceAllocationStrategy resource_allocation_strategy)
: TRTEngine(
"deserialized_trt",
serialized_engine,
Expand All @@ -86,7 +86,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
serialized_info[SERIALIZED_METADATA_IDX],
resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {}
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}

TRTEngine::TRTEngine(
const std::string& mod_name,
Expand All @@ -98,7 +98,7 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy& resource_allocation_strategy) {
const ResourceAllocationStrategy resource_allocation_strategy) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
Expand Down Expand Up @@ -128,9 +128,11 @@ TRTEngine::TRTEngine(
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
}

this->resource_allocation_strategy = resource_allocation_strategy;
LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
Expand Down Expand Up @@ -402,6 +404,7 @@ std::string TRTEngine::to_str() const {
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
// clang-format on
return ss.str();
}
Expand Down Expand Up @@ -469,8 +472,7 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
resource_allocation_strategy_to_string(this->resource_allocation_strategy);
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";

return serialized_info;
}
Expand All @@ -483,11 +485,12 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
if (new_strategy != this->resource_allocation_strategy) {
this->resource_allocation_strategy = new_strategy;
if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
std::cout << "Setting resource allocation strategy to dynamic" << std::endl;
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
LOG_DEBUG("Setting resource allocation strategy to dynamic");
this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
LOG_DEBUG("Setting resource allocation strategy to static");
this->exec_ctx = make_trt(
cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
cuda_engine->createExecutionContext());
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {

struct TRTEngine : torch::CustomClassHolder {
// Resource Allocation Strategy
enum ResourceAllocationStrategy { kStatic, kDynamic };
typedef enum { kStatic = 0, kDynamic } ResourceAllocationStrategy;
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
Expand Down Expand Up @@ -132,7 +132,7 @@ struct TRTEngine : torch::CustomClassHolder {
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine(std::vector<std::string> serialized_info);
Expand All @@ -147,7 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine& operator=(const TRTEngine& other);
Expand Down
18 changes: 2 additions & 16 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ std::string serialize_bindings(const std::vector<std::string>& bindings) {
return serialized_binding_info;
}

// Converts a resource allocation strategy into its serialized string form:
// "kDynamic" for ResourceAllocationStrategy::kDynamic, "kStatic" for any other value.
std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy) {
  const bool is_dynamic = (strategy == TRTEngine::ResourceAllocationStrategy::kDynamic);
  return std::string(is_dynamic ? "kDynamic" : "kStatic");
}

// Parses the serialized string form of a resource allocation strategy.
// Only the exact string "kDynamic" maps to kDynamic; every other input
// (including unrecognized text) falls back to kStatic.
TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) {
  const bool is_dynamic = (str == "kDynamic");
  return is_dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
                    : TRTEngine::ResourceAllocationStrategy::kStatic;
}

static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //=
std::string base64_encode(const std::string& in) {
std::string out;
Expand Down Expand Up @@ -106,7 +91,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def(
"_use_dynamically_allocated_resources",
"use_dynamically_allocated_resources",
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
self->set_resource_allocation_strategy(
dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
Expand All @@ -124,6 +109,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
});
Expand Down
16 changes: 11 additions & 5 deletions examples/dynamo/dynamic_memory_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from diffusers import DiffusionPipeline
import time
import gc

np.random.seed(5)
torch.manual_seed(5)
Expand All @@ -14,23 +15,28 @@
"use_python_runtime": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
"lazy_engine_init": True,
"dynamically_allocate_resources": True

}

model = models.resnet152(pretrained=True).eval().to("cuda")
compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
compiled_module(*inputs)

breakpoint()
with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
time.sleep(30)
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(compiled_module, dynamically_allocate_resources=False):
print(
"Memory used (GB):",
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
breakpoint()
compiled_module(*inputs)
gc.collect()
torch.cuda.empty_cache()
time.sleep(30)
print(
"Memory used (GB):",
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
breakpoint()
compiled_module(*inputs)
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def cross_compile_for_windows(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
Expand Down Expand Up @@ -177,6 +178,7 @@ def cross_compile_for_windows(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -340,6 +342,7 @@ def cross_compile_for_windows(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"dynamically_allocate_resources": dynamically_allocate_resources,
}

# Disable the following settings, which are not supported by the cross-compile-for-Windows feature
Expand Down Expand Up @@ -440,6 +443,7 @@ def compile(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -517,6 +521,7 @@ def compile(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -690,6 +695,7 @@ def compile(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"dynamically_allocate_resources": dynamically_allocate_resources,
}

settings = CompilationSettings(**compilation_options)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
DYNAMICALLY_ALLOCATE_RESOURCES = False

if platform.system() == "Linux":
import pwd
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DYNAMICALLY_ALLOCATE_RESOURCES,
DRYRUN,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
Expand Down Expand Up @@ -97,6 +98,8 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -140,6 +143,7 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand Down
12 changes: 7 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch


class ResourceAllocatorContext(torch.nn.Module): # type: ignore[misc]
class ResourceAllocationStrategy(torch.nn.Module): # type: ignore[misc]
"""
ResourceAllocatorContext is a context manager module that temporarily enables dynamic resource allocation
ResourceAllocationStrategy is a context manager module that temporarily enables dynamic resource allocation
for all TRT submodules of the given compiled_module. When entering the context,
it sets these submodules to use dynamically allocated resources. Upon exiting, it restores them to their
original (static) resource allocation mode.
Expand All @@ -14,17 +14,19 @@ class ResourceAllocatorContext(torch.nn.Module): # type: ignore[misc]
def __init__(
self,
compiled_module: torch.nn.Module,
dynamically_allocate_resources: bool = True
) -> None:
super(ResourceAllocatorContext, self).__init__()
super(ResourceAllocationStrategy, self).__init__()
self.compiled_module = compiled_module
self.dynamically_allocate_resources = dynamically_allocate_resources

def __enter__(self) -> None:
print("Entering resource allocator context")
for name, submodule in self.compiled_module.named_modules():
if "_run_on_acc" in name:
submodule.use_dynamically_allocated_resources(dynamic=True)
submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
for name, submodule in self.compiled_module.named_modules():
if "_run_on_acc" in name:
submodule.use_dynamically_allocated_resources(dynamic=False)
submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
self.serialized_engine = serialized_engine
self.engine = None
self.requires_output_allocator = requires_output_allocator
self.resource_allocation_strategy = 0 # Default to static allocation TODO: Make this configurable with the context manager
self.dynamically_allocate_resources = settings.dynamically_allocate_resources

if (
serialized_engine
Expand Down Expand Up @@ -188,9 +188,11 @@ def _pack_engine_info(self) -> List[str | bytes]:
engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(
int(self.requires_output_allocator)
)
print(f"PROVIDED RESOURCE ALLOCATION STRATEGY: {self.dynamically_allocate_resources}")
engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(
int(self.resource_allocation_strategy)
int(self.dynamically_allocate_resources)
)
print(engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX])

return engine_info

Expand Down Expand Up @@ -219,8 +221,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
def _reset_captured_graph(self) -> None:
self.engine.reset_captured_graph()

def use_dynamically_allocated_resources(self, dynamic: bool = False) -> None:
self.engine._use_dynamically_allocated_resources(dynamic)
def use_dynamically_allocated_resources(self, dynamically_allocate_resources: bool = False) -> None:
self.dynamically_allocate_resources = dynamically_allocate_resources
self.engine.use_dynamically_allocated_resources(self.dynamically_allocate_resources)

def setup_engine(self) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PythonTorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime._ResourceAllocator import ( # noqa: F401
ResourceAllocatorContext,
ResourceAllocationStrategy,
)
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
TorchTensorRTModule,
Expand Down
Loading
Loading