@@ -347,6 +347,21 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
347
347
if (cfg.partition_info .enabled ) {
348
348
return CompileGraphWithFallback (mod, cfg);
349
349
}
350
+ auto device_spec = cfg.convert_info .engine_settings .device ;
351
+
352
+ // GPU default WS size : 1 GB
353
+ // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
354
+ auto workspace_size = cfg.convert_info .engine_settings .workspace_size ;
355
+ cudaDeviceProp device_prop;
356
+ cudaGetDeviceProperties (&device_prop, device_spec.gpu_id );
357
+ if (workspace_size == 0 ) {
358
+ if (device_prop.major < 6 ) {
359
+ cfg.convert_info .engine_settings .workspace_size = 256 * (1 << 20 );
360
+ } else {
361
+ cfg.convert_info .engine_settings .workspace_size = 1 << 30 ;
362
+ }
363
+ }
364
+
350
365
// TODO: Should be doing a functional transform but need PR #31978
351
366
// [jit] More robust mangling
352
367
// torch::jit::script::Module new_mod = mod.clone();
@@ -357,7 +372,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
357
372
if (method.name ().compare (" forward" ) == 0 ) {
358
373
auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
359
374
auto new_g = std::make_shared<torch::jit::Graph>();
360
- auto device_spec = cfg.convert_info .engine_settings .device ;
361
375
auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
362
376
AddEngineToGraph (new_mod, new_g, engine, cuda_device);
363
377
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
0 commit comments