Skip to content

Commit ff2d955

Browse files
authored
Merge pull request #1167 from pytorch/fix_renaming_itensor
fix: converter renaming already named tensors
2 parents 84bad88 + 248d8aa commit ff2d955

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ void AddInputs(
188188
ctx->input_is_dynamic = true;
189189
}
190190

191-
ctx->value_tensor_map[in] = trt_in;
191+
ctx->RecordNewITensor(in, trt_in);
192192
ctx->num_inputs += 1;
193193
}
194194

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ ConversionCtx::~ConversionCtx() {
143143
}
144144

145145
nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
146-
tensor->setName(value->debugName().c_str());
147-
this->value_tensor_map[value] = tensor;
146+
RecordNewITensor(value, tensor);
147+
148148
return tensor;
149149
}
150150

@@ -153,6 +153,15 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
153153
return &this->evaluated_value_map[value];
154154
}
155155

156+
void ConversionCtx::RecordNewITensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
157+
value_tensor_map[value] = tensor;
158+
auto ret = seen_itensors.insert(tensor);
159+
if (!ret.second) {
160+
LOG_WARNING(
161+
"Trying to record the value " << value->debugName() << " with the ITensor " << tensor->getName() << " again.");
162+
}
163+
}
164+
156165
std::string ConversionCtx::SerializeEngine() {
157166
#if NV_TENSORRT_MAJOR > 7
158167
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct ConversionCtx {
4848
ConversionCtx(BuilderSettings settings);
4949
std::string SerializeEngine();
5050
nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
51+
void RecordNewITensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
5152
torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue tensor);
5253
bool CheckLayerAddition(const torch::jit::Node* n);
5354

@@ -71,6 +72,9 @@ struct ConversionCtx {
7172

7273
std::unordered_map<const torch::jit::Value*, nvinfer1::ITensor*> value_tensor_map;
7374
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> evaluated_value_map;
75+
76+
// record already named ITensors to prevent rewriting another name to the same tensor
77+
std::unordered_set<nvinfer1::ITensor*> seen_itensors;
7478
};
7579

7680
} // namespace conversion

0 commit comments

Comments
 (0)