Skip to content

Commit 3967fb0

Browse files
committed
Merge branch 'inocsin-fix_fallback_inputs' into 'release/1.0'
feat: support setting input types of subgraph in fallback, handle Tensor type... See merge request adlsa/TRTorch!10
2 parents a8a407f + 77bf9da commit 3967fb0

File tree

7 files changed

+51
-8
lines changed

7 files changed

+51
-8
lines changed

core/conversion/conversion.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,9 @@
88
#include "core/util/prelude.h"
99

1010
#include "c10/util/intrusive_ptr.h"
11+
#include "core/conversion/converters/converter_util.h"
1112
#include "core/conversion/tensorcontainer/TensorContainer.h"
13+
#include "core/util/trt_util.h"
1214

1315
namespace torch_tensorrt {
1416
namespace core {
@@ -212,6 +214,21 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
212214
LOG_INFO(
213215
ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
214216
ctx->num_outputs += 1;
217+
} else if (out_ivalue.isTuple()) {
218+
TORCHTRT_THROW_ERROR("Tuple type. Only a single tensor or a TensorList type is supported.");
219+
} else if (out_ivalue.isList()) {
220+
TORCHTRT_THROW_ERROR("List type. Only a single tensor or a TensorList type is supported.");
221+
} else if (out_ivalue.isScalar()) {
222+
TORCHTRT_THROW_ERROR("Scalar type. Only a single tensor or a TensorList type is supported.");
223+
} else if (out_ivalue.isTensor()) {
224+
// prim::NumToTensor will go to here
225+
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
226+
auto out_tensor = converters::tensor_to_const(ctx, out_ivalue.toTensor(), "");
227+
out_tensor->setName(name.c_str());
228+
ctx->net->markOutput(*out_tensor);
229+
LOG_INFO(
230+
ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
231+
ctx->num_outputs += 1;
215232
} else {
216233
TORCHTRT_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported.");
217234
}
@@ -364,6 +381,7 @@ void ConvertBlockToNetDef(
364381
ConversionInfo& build_info,
365382
ir::StaticParams& static_params) {
366383
LOG_INFO(ctx->logger, "Converting Block");
384+
LOG_DEBUG(ctx->logger, *b->owningGraph());
367385

368386
auto inputs = b->inputs();
369387
AddParamsToCtxValueMap(ctx, static_params);

core/partitioning/PartitionInfo.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ struct PartitionInfo {
1212
bool enabled = false;
1313
uint64_t min_block_size = 1;
1414
std::vector<std::string> forced_fallback_operators;
15+
bool truncate_long_and_double;
1516
};
1617

1718
std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);

core/partitioning/partitioning.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -404,7 +404,7 @@ PartitionedGraph Partition(
404404
registerSegmentsOutputs(segmented_blocks, block);
405405

406406
// run shape analysis on each segmented block
407-
runShapeAnalysis(segmented_blocks, example_tensor_map);
407+
runShapeAnalysis(segmented_blocks, example_tensor_map, partition_info);
408408

409409
LOG_INFO(segmented_blocks);
410410

core/partitioning/shape_analysis.cpp

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
3434

3535
void getSegmentsOutputByRunning(
3636
SegmentedBlock& seg_block,
37-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
37+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
38+
const PartitionInfo& partition_info) {
3839
// create a module to run the graph
3940
auto g = seg_block.g();
4041
auto copy_g = g->copy();
@@ -108,7 +109,28 @@ void getSegmentsOutputByRunning(
108109
std::vector<at::ScalarType> input_types;
109110
for (auto& i : seg_block.raw_inputs()) {
110111
if (ivalues_maps[i].isTensor()) {
111-
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
112+
// set the input_shape and data_type
113+
at::ScalarType t = ivalues_maps[i].toTensor().scalar_type();
114+
if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
115+
TORCHTRT_THROW_ERROR(
116+
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
117+
} else if (partition_info.truncate_long_and_double && t == at::kLong) {
118+
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt);
119+
LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
120+
} else if (partition_info.truncate_long_and_double && t == at::kDouble) {
121+
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat);
122+
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
123+
}
124+
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
125+
if (dtype == c10::nullopt) {
126+
TORCHTRT_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
127+
}
128+
if (ivalues_maps[i].toTensor().sizes().size() == 0) {
129+
// handle Scalar types, which has sizes of []
130+
input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
131+
} else {
132+
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
133+
}
112134
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
113135
}
114136
}
@@ -119,11 +141,12 @@ void getSegmentsOutputByRunning(
119141

120142
void runShapeAnalysis(
121143
std::vector<SegmentedBlock>& segmented_blocks,
122-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map) {
144+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
145+
const PartitionInfo& partition_info) {
123146
// register every segment's input shape, and it's running output IValues
124147
for (auto& seg_block : segmented_blocks) {
125148
torch::jit::ConstantPooling(seg_block.g());
126-
getSegmentsOutputByRunning(seg_block, example_tensor_map);
149+
getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info);
127150
}
128151
return;
129152
}

core/partitioning/shape_analysis.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
1212

1313
void runShapeAnalysis(
1414
std::vector<SegmentedBlock>& segmented_blocks,
15-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps);
15+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
16+
const PartitionInfo& partition_info);
1617

1718
} // namespace partitioning
1819
} // namespace core

core/util/trt_util.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,8 +239,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
239239
{at::kHalf, nvinfer1::DataType::kHALF},
240240
{at::kInt, nvinfer1::DataType::kINT32},
241241
{at::kChar, nvinfer1::DataType::kINT8},
242-
{at::kBool, nvinfer1::DataType::kBOOL},
243-
};
242+
{at::kBool, nvinfer1::DataType::kBOOL}};
244243
return at_trt_type_map;
245244
}
246245

cpp/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
5757
internal.partition_info.enabled = !external.require_full_compilation;
5858
internal.partition_info.min_block_size = external.min_block_size;
5959
internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
60+
internal.partition_info.truncate_long_and_double = external.truncate_long_and_double;
6061
internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);
6162

6263
switch (external.device.device_type) {

0 commit comments

Comments (0)