diff --git a/core/compiler.cpp b/core/compiler.cpp index e22f57bc71..35ad0e204f 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -347,6 +347,21 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C if (cfg.partition_info.enabled) { return CompileGraphWithFallback(mod, cfg); } + auto device_spec = cfg.convert_info.engine_settings.device; + + // GPU default WS size : 1 GB + // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X. + auto workspace_size = cfg.convert_info.engine_settings.workspace_size; + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_spec.gpu_id); + if (workspace_size == 0) { + if (device_prop.major < 6) { + cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20); + } else { + cfg.convert_info.engine_settings.workspace_size = 1 << 30; + } + } + // TODO: Should be doing a functional transform but need PR #31978 // [jit] More robust mangling // torch::jit::script::Module new_mod = mod.clone(); @@ -357,7 +372,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C if (method.name().compare("forward") == 0) { auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg); auto new_g = std::make_shared(); - auto device_spec = cfg.convert_info.engine_settings.device; auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type); AddEngineToGraph(new_mod, new_g, engine, cuda_device); auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g); diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 1b86bba4f2..2496da4ebe 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -58,7 +58,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) net = make_trt( builder->createNetworkV2(1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))); - LOG_DEBUG(build_settings); + LOG_INFO(settings); cfg = make_trt(builder->createBuilderConfig()); for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {