@@ -52,7 +52,7 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
5252 } else {
5353 trtorch::logging::log (
5454 trtorch::logging::Level::kERROR ,
55- " Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]" );
55+ " Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ], found: " + str );
5656 return trtorch::CompileSpec::TensorFormat::kUnknown ;
5757 }
5858}
@@ -73,7 +73,7 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
7373 } else {
7474 trtorch::logging::log (
7575 trtorch::logging::Level::kERROR ,
76- " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]" );
76+ " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str );
7777 return trtorch::CompileSpec::DataType::kUnknown ;
7878 }
7979}
@@ -221,8 +221,8 @@ int main(int argc, char** argv) {
221221 " type" ,
222222 " The type of device the engine should be built for [ gpu | dla ] (default: gpu)" ,
223223 {' d' , " device-type" });
224- args::ValueFlag<int > gpu_id (parser, " gpu_id" , " GPU id if running on multi-GPU platform (defaults to 0)" , {" gpu-id" });
225- args::ValueFlag<int > dla_core (
224+ args::ValueFlag<uint64_t > gpu_id (parser, " gpu_id" , " GPU id if running on multi-GPU platform (defaults to 0)" , {" gpu-id" });
225+ args::ValueFlag<uint64_t > dla_core (
226226 parser, " dla_core" , " DLACore id if running on available DLA (defaults to 0)" , {" dla-core" });
227227
228228 args::ValueFlag<std::string> engine_capability (
@@ -243,13 +243,13 @@ int main(int argc, char** argv) {
243243 " Whether to treat input file as a serialized TensorRT engine and embed it into a TorchScript module (device spec must be provided)" ,
244244 {" embed-engine" });
245245
246- args::ValueFlag<int > num_min_timing_iters (
246+ args::ValueFlag<uint64_t > num_min_timing_iters (
247247 parser, " num_iters" , " Number of minimization timing iterations used to select kernels" , {" num-min-timing-iter" });
248- args::ValueFlag<int > num_avg_timing_iters (
248+ args::ValueFlag<uint64_t > num_avg_timing_iters (
249249 parser, " num_iters" , " Number of averaging timing iterations used to select kernels" , {" num-avg-timing-iters" });
250- args::ValueFlag<int > workspace_size (
250+ args::ValueFlag<uint64_t > workspace_size (
251251 parser, " workspace_size" , " Maximum size of workspace given to TensorRT" , {" workspace-size" });
252- args::ValueFlag<int > max_batch_size (
252+ args::ValueFlag<uint64_t > max_batch_size (
253253 parser, " max_batch_size" , " Maximum batch size (must be >= 1 to be set, 0 means not set)" , {" max-batch-size" });
254254 args::ValueFlag<double > threshold (
255255 parser,
@@ -276,8 +276,8 @@ int main(int argc, char** argv) {
276276 std::cout << parser;
277277 return 0 ;
278278 } catch (args::ParseError e) {
279- std::cerr << e.what () << std::endl ;
280- std::cerr << parser;
279+ trtorch::logging::log (trtorch::logging::Level:: kERROR , e.what ()) ;
280+ std::cerr << std::endl << parser;
281281 return 1 ;
282282 }
283283
@@ -309,13 +309,13 @@ int main(int argc, char** argv) {
309309 auto parsed_dtype = parseDataType (dtype);
310310 if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown ) {
311311 trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid datatype for input specification " + spec);
312- std::cerr << parser;
312+ std::cerr << std::endl << parser;
313313 exit (1 );
314314 }
315315 auto parsed_format = parseTensorFormat (format);
316316 if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown ) {
317317 trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid format for input specification " + spec);
318- std::cerr << parser;
318+ std::cerr << std::endl << parser;
319319 exit (1 );
320320 }
321321 if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -326,7 +326,7 @@ int main(int argc, char** argv) {
326326 trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_dtype, parsed_format));
327327 } else {
328328 trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
329- std::cerr << parser;
329+ std::cerr << std::endl << parser;
330330 exit (1 );
331331 }
332332 // THERE IS NO SPEC FOR FORMAT
@@ -337,7 +337,7 @@ int main(int argc, char** argv) {
337337 auto parsed_dtype = parseDataType (dtype);
338338 if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown ) {
339339 trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid datatype for input specification " + spec);
340- std::cerr << parser;
340+ std::cerr << std::endl << parser;
341341 exit (1 );
342342 }
343343 if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -347,7 +347,7 @@ int main(int argc, char** argv) {
347347 ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_dtype));
348348 } else {
349349 trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
350- std::cerr << parser;
350+ std::cerr << std::endl << parser;
351351 exit (1 );
352352 }
353353 }
@@ -359,7 +359,7 @@ int main(int argc, char** argv) {
359359 auto parsed_format = parseTensorFormat (format);
360360 if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown ) {
361361 trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid format for input specification " + spec);
362- std::cerr << parser;
362+ std::cerr << std::endl << parser;
363363 exit (1 );
364364 }
365365 if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -369,7 +369,7 @@ int main(int argc, char** argv) {
369369 ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_format));
370370 } else {
371371 trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
372- std::cerr << parser;
372+ std::cerr << std::endl << parser;
373373 exit (1 );
374374 }
375375 // JUST SHAPE USE DEFAULT DTYPE
@@ -381,7 +381,7 @@ int main(int argc, char** argv) {
381381 ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ]));
382382 } else {
383383 trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
384- std::cerr << parser;
384+ std::cerr << std::endl << parser;
385385 exit (1 );
386386 }
387387 }
@@ -430,14 +430,15 @@ int main(int argc, char** argv) {
430430 trtorch::logging::log (
431431 trtorch::logging::Level::kERROR ,
432432 " If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided" );
433- std::cerr << parser;
434433 return 1 ;
435434 }
436435 } else {
436+ std::stringstream ss;
437+ ss << " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ], found: " ;
438+ ss << dtype;
437439 trtorch::logging::log (
438- trtorch::logging::Level::kERROR ,
439- " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ]" );
440- std::cerr << parser;
440+ trtorch::logging::Level::kERROR , ss.str ());
441+ std::cerr << std::endl << parser;
441442 return 1 ;
442443 }
443444 }
@@ -460,8 +461,8 @@ int main(int argc, char** argv) {
460461 compile_settings.device .dla_core = args::get (dla_core);
461462 }
462463 } else {
463- trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid device type, options are [ gpu | dla ]" );
464- std::cerr << parser;
464+ trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid device type, options are [ gpu | dla ] found: " + device );
465+ std::cerr << std::endl << parser;
465466 return 1 ;
466467 }
467468 }
@@ -479,7 +480,7 @@ int main(int argc, char** argv) {
479480 } else {
480481 trtorch::logging::log (
481482 trtorch::logging::Level::kERROR , " Invalid engine capability, options are [ default | safe_gpu | safe_dla ]" );
482- std::cerr << parser;
483+ std::cerr << std::endl << parser;
483484 return 1 ;
484485 }
485486 }
@@ -517,7 +518,6 @@ int main(int argc, char** argv) {
517518 mod = torch::jit::load (real_input_path);
518519 } catch (const c10::Error& e) {
519520 trtorch::logging::log (trtorch::logging::Level::kERROR , " Error loading the model (path may be incorrect)" );
520- std::cerr << parser;
521521 return 1 ;
522522 }
523523
0 commit comments