Skip to content

Commit a234335

Browse files
committed
feat!: Changing the default behavior for selecting the input type
BREAKING CHANGE: This commit changes the default behavior of the compiler where, if the user does not specify an input data type explicitly, instead of using the enabled precision the compiler will now inspect the model provided to infer the data type for the input that will not cause an error if the model was run in torch. In practice this means: - If the weights are in FP32 for the first tensor calculation then the default input type is FP32 - If the weights are in FP16 for the first tensor calculation then the default input type is FP16 - etc. If the data type cannot be determined the compiler will default to FP32. This calculation is done per input tensor, so if one input is inferred to use FP32 and another INT32 then the expected types will be (FP32, INT32). As before, if the user defines the data type explicitly or provides an example tensor, the data type specified there will be respected Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 19ecc64 commit a234335

File tree

14 files changed

+310
-71
lines changed

14 files changed

+310
-71
lines changed

core/compiler.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -287,22 +287,45 @@ GraphAndMapping ConstructFallbackGraph(
287287
return {new_g, old_to_new_g};
288288
}
289289

290+
291+
void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
292+
// Associate input specs with inputs
293+
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
294+
295+
for (auto& in : g->inputs()) {
296+
auto est_type_opt = first_use_type_map.find(in)->second;
297+
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
298+
if (est_type_opt && !spec.dtype_is_user_defined) {
299+
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated type
300+
LOG_INFO("Since input type is not explicitly defined, inferring using first tensor calculation\n Found input "
301+
<< in->debugName() << " has type " << est_type_opt.value() << ". If this is incorrect explicitly set dtype for input and file a bug");
302+
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
303+
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
304+
// If we cannot calculate the type and the user did not define the type, then default to FP32
305+
LOG_WARNING(
306+
"Cannot determine input type from calculations in graph for input "
307+
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
308+
spec.dtype = nvinfer1::DataType::kFLOAT;
309+
} else {
310+
// The user defined the type so no changes are necessary
311+
}
312+
}
313+
}
314+
290315
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
291316
// Go through Lowering to simplify graph and extract weight parameters
292317
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
293318

294-
auto convert_cfg = std::move(cfg.convert_info);
295319
auto g = graph_and_parameters.first;
296-
297320
auto params = graph_and_parameters.second;
298321
auto static_params = ir::get_static_params(g->inputs(), params);
322+
// Infer the type of an input from the weights of the calculation
323+
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
299324

300-
LOG_INFO(*g << "(CompileGraph)\n");
325+
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
301326

302-
// Move the user defined inputs to the convert_cfg since some might be static;
303-
convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
327+
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
304328

305-
auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params);
306329
return std::move(engine);
307330
}
308331

@@ -331,27 +354,12 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
331354
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
332355

333356
auto g = graph_and_parameters.first;
334-
LOG_INFO("Lowered Graph: " << *g);
335357
auto params = graph_and_parameters.second;
336358
auto static_params = ir::get_static_params(g->inputs(), params);
337-
338-
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
339-
340-
// If the user did not explicitly set the input type, then use the first
341-
// tensor calculation to infer type.
359+
// Infer the type of an input from the weights of the calculation
342360
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
343-
for (auto& in : g->inputs()) {
344-
auto est_type_opt = first_use_types[in];
345-
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
346-
if (est_type_opt && !spec.dtype_is_user_defined) {
347-
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
348-
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
349-
LOG_WARNING(
350-
"Cannot deterime input type from calcuations in graph for input "
351-
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
352-
spec.dtype = nvinfer1::DataType::kFLOAT;
353-
}
354-
}
361+
362+
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
355363

356364
if (cfg.partition_info.enabled) {
357365
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
9797
// Is this necessary?
9898
// lowering::LowerBlock(g->block());
9999

100+
LOG_INFO("Lowered Graph: " << *(graph_and_ivalues.first));
100101
return graph_and_ivalues;
101102
}
102103

core/util/jit_util.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block*
9696
return dtype;
9797
}
9898

99-
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
100-
torch::jit::Block* b) {
101-
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> types;
99+
InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
100+
InputTypeMap types;
102101

103102
for (auto i : b->inputs()) {
104103
if (i->type() == c10::TensorType::get()) {

core/util/jit_util.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ namespace trtorch {
99
namespace core {
1010
namespace util {
1111

12+
using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
13+
1214
inline std::string node_info(const torch::jit::Node* n) {
1315
std::stringstream ss;
1416
ss << *n;
@@ -61,8 +63,7 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
6163
}
6264

6365
c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
64-
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
65-
torch::jit::Block* b);
66+
InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
6667

6768
} // namespace util
6869
} // namespace core

core/util/logging/TRTorchLogger.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ namespace {
125125

126126
TRTorchLogger& get_global_logger() {
127127
#ifndef NDEBUG
128-
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
128+
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
129129
#else
130130
static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
131131
#endif

cpp/include/trtorch/trtorch.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ struct TRTORCH_API CompileSpec {
387387
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
388388
*
389389
* @param shape Input tensor shape
390-
* @param dtype Expected data type for the input (Defaults to Float32)
390+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
391391
* @param format Expected tensor format for the input (Defaults to contiguous)
392392
*/
393393
Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
@@ -398,7 +398,7 @@ struct TRTORCH_API CompileSpec {
398398
* tensor format
399399
*
400400
* @param shape Input tensor shape
401-
* @param dtype Expected data type for the input (Defaults to Float32)
401+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
402402
* @param format Expected tensor format for the input (Defaults to contiguous)
403403
*/
404404
Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
@@ -421,7 +421,7 @@ struct TRTORCH_API CompileSpec {
421421
* allow the user to configure expected input shape tensor format
422422
*
423423
* @param shape Input tensor shape
424-
* @param dtype Expected data type for the input (Defaults to Float32)
424+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
425425
* @param format Expected tensor format for the input (Defaults to contiguous)
426426
*/
427427
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
@@ -451,7 +451,7 @@ struct TRTORCH_API CompileSpec {
451451
* @param min_shape Minimum shape for input tensor
452452
* @param opt_shape Target optimization shape for input tensor
453453
* @param max_shape Maximum acceptible shape for input tensor
454-
* @param dtype Expected data type for the input (Defaults to Float32)
454+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
455455
* @param format Expected tensor format for the input (Defaults to contiguous)
456456
*/
457457
Input(
@@ -486,7 +486,7 @@ struct TRTORCH_API CompileSpec {
486486
* @param min_shape Minimum shape for input tensor
487487
* @param opt_shape Target optimization shape for input tensor
488488
* @param max_shape Maximum acceptible shape for input tensor
489-
* @param dtype Expected data type for the input (Defaults to Float32)
489+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
490490
* @param format Expected tensor format for the input (Defaults to contiguous)
491491
*/
492492
Input(
@@ -506,14 +506,9 @@ struct TRTORCH_API CompileSpec {
506506
*/
507507
Input(at::Tensor tensor);
508508

509-
bool get_explicit_set_dtype() {
510-
return explicit_set_dtype;
511-
}
512-
513509
private:
514510
friend std::ostream& operator<<(std::ostream& os, const Input& input);
515511
bool input_is_dynamic;
516-
bool explicit_set_dtype;
517512
};
518513

519514
/**

cpp/src/compile_spec.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
7373
}
7474

7575
nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
76-
TRTORCH_CHECK(!(value == CompileSpec::DataType::kUnknown), "Data type is unknown");
7776
switch (value) {
7877
case CompileSpec::DataType::kChar:
7978
return nvinfer1::DataType::kINT8;
@@ -162,8 +161,7 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, TensorFormat format) {
162161
this->min_shape = shape;
163162
this->max_shape = shape;
164163
this->shape = shape;
165-
this->dtype = dtype;
166-
this->explicit_set_dtype = false;
164+
this->dtype = CompileSpec::DataType::kUnknown;
167165
this->format = format;
168166
this->input_is_dynamic = false;
169167
}
@@ -174,7 +172,6 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, DataType dtype, TensorForm
174172
this->max_shape = shape;
175173
this->shape = shape;
176174
this->dtype = dtype;
177-
this->explicit_set_dtype = true;
178175
this->format = format;
179176
this->input_is_dynamic = false;
180177
}
@@ -184,8 +181,7 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, TensorFormat format) {
184181
this->min_shape = core::util::toVec(shape);
185182
this->max_shape = core::util::toVec(shape);
186183
this->shape = core::util::toVec(shape);
187-
this->dtype = DataType::kFloat;
188-
this->explicit_set_dtype = false;
184+
this->dtype = CompileSpec::DataType::kUnknown;
189185
this->format = format;
190186
this->input_is_dynamic = false;
191187
}
@@ -196,7 +192,6 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
196192
this->max_shape = core::util::toVec(shape);
197193
this->shape = core::util::toVec(shape);
198194
this->dtype = dtype;
199-
this->explicit_set_dtype = true;
200195
this->format = format;
201196
this->input_is_dynamic = false;
202197
}
@@ -210,8 +205,7 @@ CompileSpec::Input::Input(
210205
this->min_shape = min_shape;
211206
this->max_shape = max_shape;
212207
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
213-
this->dtype = dtype;
214-
this->explicit_set_dtype = false;
208+
this->dtype = CompileSpec::DataType::kUnknown;
215209
this->format = format;
216210
this->input_is_dynamic = true;
217211
}
@@ -227,7 +221,6 @@ CompileSpec::Input::Input(
227221
this->max_shape = max_shape;
228222
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
229223
this->dtype = dtype;
230-
this->explicit_set_dtype = true;
231224
this->format = format;
232225
this->input_is_dynamic = true;
233226
}
@@ -241,8 +234,7 @@ CompileSpec::Input::Input(
241234
this->min_shape = core::util::toVec(min_shape);
242235
this->max_shape = core::util::toVec(max_shape);
243236
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
244-
this->dtype = dtype;
245-
this->explicit_set_dtype = false;
237+
this->dtype = CompileSpec::DataType::kUnknown;
246238
this->format = format;
247239
this->input_is_dynamic = true;
248240
}
@@ -258,7 +250,6 @@ CompileSpec::Input::Input(
258250
this->max_shape = core::util::toVec(max_shape);
259251
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
260252
this->dtype = dtype;
261-
this->explicit_set_dtype = true;
262253
this->format = format;
263254
this->input_is_dynamic = true;
264255
}
@@ -269,7 +260,6 @@ CompileSpec::Input::Input(at::Tensor tensor) {
269260
this->max_shape = tensor.sizes().vec();
270261
this->shape = tensor.sizes().vec();
271262
this->dtype = tensor.scalar_type();
272-
this->explicit_set_dtype = true;
273263
TRTORCH_ASSERT(
274264
tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
275265
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
@@ -292,7 +282,7 @@ core::ir::Input to_internal_input(CompileSpec::Input& i) {
292282
i.max_shape,
293283
toTRTDataType(i.dtype),
294284
toTRTTensorFormat(i.format),
295-
i.get_explicit_set_dtype());
285+
!(i.dtype == CompileSpec::DataType::kUnknown));
296286
}
297287

298288
std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Input>& external) {

py/trtorch/Input.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class _ShapeMode(Enum):
3030

3131
shape_mode = None #: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
3232
shape = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
33-
dtype = _types.dtype.float32 #: The expected data type of the input tensor (default: trtorch.dtype.float32)
33+
dtype = _types.dtype.unknown #: The expected data type of the input tensor (default: trtorch.dtype.unknown, meaning the type is inferred from the model unless set explicitly)
3434
_explicit_set_dtype = False
3535
format = _types.TensorFormat.contiguous #: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
3636

@@ -133,16 +133,44 @@ def __str__(self) -> str:
133133
def _to_internal(self) -> trtorch._C.Input:
134134
internal_in = trtorch._C.Input()
135135
if self.shape_mode == Input._ShapeMode.DYNAMIC:
136-
internal_in.min = self.shape["min_shape"]
137-
internal_in.opt = self.shape["opt_shape"]
138-
internal_in.max = self.shape["max_shape"]
136+
if not Input._supported_input_size_type(self.shape["min_shape"]):
137+
raise TypeError(
138+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
139+
+ str(type(self.shape["min_shape"])) + " for min_shape")
140+
else:
141+
internal_in.min = self.shape["min_shape"]
142+
143+
if not Input._supported_input_size_type(self.shape["opt_shape"]):
144+
raise TypeError(
145+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
146+
+ str(type(self.shape["opt_shape"])) + " for opt_shape")
147+
else:
148+
internal_in.opt = self.shape["opt_shape"]
149+
150+
if not Input._supported_input_size_type(self.shape["max_shape"]):
151+
raise TypeError(
152+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
153+
+ str(type(self.shape["max_shape"])) + " for max_shape")
154+
else:
155+
internal_in.max = self.shape["max_shape"]
139156
internal_in.input_is_dynamic = True
140157
else:
141-
internal_in.opt = self.shape
158+
if not Input._supported_input_size_type(self.shape):
159+
raise TypeError(
160+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
161+
+ str(type(self.shape)) + " for shape")
162+
else:
163+
internal_in.opt = self.shape
142164
internal_in.input_is_dynamic = False
143-
internal_in.dtype = self.dtype
165+
166+
if self.dtype != _types.dtype.unknown:
167+
self._explicit_set_dtype = True
168+
else:
169+
self._explicit_set_dtype = False
170+
171+
internal_in.dtype = Input._parse_dtype(self.dtype)
144172
internal_in._explicit_set_dtype = self._explicit_set_dtype
145-
internal_in.format = self.format
173+
internal_in.format = Input._parse_format(self.format)
146174
return internal_in
147175

148176
@staticmethod
@@ -172,7 +200,7 @@ def _parse_dtype(dtype: Any) -> _types.dtype:
172200
"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
173201
+ str(dtype))
174202

175-
elif isinstance(dtype, _types.DataTypes):
203+
elif isinstance(dtype, _types.dtype):
176204
return dtype
177205

178206
else:

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
3333
return nvinfer1::DataType::kBOOL;
3434
case DataType::kFloat:
3535
return nvinfer1::DataType::kFLOAT;
36+
case DataType::kUnknown:
37+
return nvinfer1::DataType::kFLOAT;
3638
default:
3739
TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value));
3840
}

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace pyapi {
2727
return static_cast<int64_t>(field_name); \
2828
}
2929

30-
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
30+
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
3131
std::string to_str(DataType value);
3232
nvinfer1::DataType toTRTDataType(DataType value);
3333

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ PYBIND11_MODULE(_C, m) {
186186
.value("int8", DataType::kChar, "8 bit integer number")
187187
.value("int32", DataType::kInt32, "32 bit integer number")
188188
.value("bool", DataType::kBool, "Boolean value")
189+
.value("unknown", DataType::kUnknown, "Unknown data type")
189190
.export_values();
190191

191192
py::enum_<DeviceType>(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for")

0 commit comments

Comments
 (0)