Skip to content

Commit 36da9fc

Browse files
authored
fix: Add check to ensure einsum converter has no more than 2 tensor inputs (#1439)
* fix: Add check to ensure einsum converter has correct args - TRT einsum implementation currently supports 2 inputs, however the converter will accept any number of inputs and TRT throws an error at compilation - `aten::einsum` converter now checks that the tensor argument list does not exceed 2 elements, and throws an informative error otherwise * Add escaped quotes so user can copy-paste printed solution
1 parent a22d23c commit 36da9fc

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

core/conversion/converters/impl/einsum.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
1818
auto equation = args[0].unwrapToString();
1919
auto in = args[1].IValue()->toListRef();
2020

21+
TORCHTRT_CHECK(
22+
in.size() <= 2,
23+
"TensorRT currently supports up to 2 input tensors "
24+
<< "to einsum but operation had " << in.size()
25+
<< " input tensors, please specify torch_executed_ops=[\"aten::einsum\"] "
26+
<< "at compilation time to avoid this error.");
27+
2128
std::vector<nvinfer1::ITensor*> tensors;
2229

2330
// Populate vector of ITensor pointers

0 commit comments

Comments
 (0)