Commit ef62f6b

Merge pull request #892 from NVIDIA/support_aten_extend
Support aten extend
2 parents 0230cc6 + 97eb4eb

File tree: 4 files changed, 61 additions & 5 deletions

core/conversion/evaluators/aten.cpp

Lines changed: 25 additions & 0 deletions
@@ -285,6 +285,31 @@ auto aten_registrations TORCHTRT_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
                     })})
+        .evaluator({c10::Symbol::fromQualString("aten::extend"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
+                        c10::IValue* self_ptr = args.at(n->input(0)).IValueMut();
+                        auto self = self_ptr->to<c10::List<c10::IValue>>();
+                        auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
+                        const int64_t other_size = other.size();
+
+                        // Modify value in place
+                        for (int64_t i = 0; i < other_size; i++) {
+                          self.push_back(other.get(i));
+                        }
+
+                        *self_ptr = c10::IValue(self);
+                        return {};
+                      } else {
+                        TORCHTRT_THROW_ERROR(
+                            "Unimplemented data type for aten::extend.t evaluator: "
+                            << args.at(n->input(0)).IValue()->type()->str() << ", "
+                            << args.at(n->input(1)).IValue()->type()->str());
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::extend.t(t[](a!) self, t[] other) -> ()",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::neg"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto el = args.at(n->input(0)).unwrapToInt();
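Note: the evaluator mirrors Python's list.extend — it mutates the first list argument in place and returns nothing, which is what motivates the mutable IValueMut() accessor added to Var below. A minimal standalone sketch of the same in-place loop, assuming only ATen's c10::List (the extend_list helper is hypothetical, not part of this PR):

    #include <ATen/core/List.h>
    #include <iostream>

    // Hypothetical helper mirroring the evaluator's loop: append every
    // element of `other` onto `self`, mutating `self` in place.
    void extend_list(c10::List<int64_t>& self, const c10::List<int64_t>& other) {
      const int64_t other_size = other.size();
      for (int64_t i = 0; i < other_size; i++) {
        self.push_back(other.get(i));
      }
    }

    int main() {
      c10::List<int64_t> self({1, 2, 3});
      c10::List<int64_t> other({4, 5});
      extend_list(self, other);
      for (size_t i = 0; i < self.size(); i++) {
        std::cout << self.get(i) << " "; // prints: 1 2 3 4 5
      }
      std::cout << std::endl;
      return 0;
    }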

core/conversion/var/Var.cpp

Lines changed: 6 additions & 2 deletions
@@ -13,7 +13,7 @@ Var::Var() {
   type_ = Type::kNone;
 }

-Var::Var(const torch::jit::IValue* p) : type_(Type::kIValue) {
+Var::Var(torch::jit::IValue* p) : type_(Type::kIValue) {
   ptr_.ivalue = p;
 }

@@ -56,7 +56,7 @@ Var& Var::operator=(const Var& a) {
   return (*this);
 }

-Var& Var::operator=(const torch::jit::IValue* in) {
+Var& Var::operator=(torch::jit::IValue* in) {
   ptr_.ivalue = in;
   type_ = Type::kIValue;
   return (*this);
@@ -116,6 +116,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
 }

 const torch::jit::IValue* Var::IValue() const {
+  return IValueMut();
+}
+
+torch::jit::IValue* Var::IValueMut() const {
   TORCHTRT_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
   if (type_ == Type::kIValue) {
     return ptr_.ivalue;
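Note: IValue() keeps its const signature and simply delegates to IValueMut(). This is the usual const-delegates-to-mutable accessor pattern; it is legal because constness of the member function restricts mutation of the Var wrapper itself, not of the pointee stored inside it. A generic sketch of the pattern (hypothetical Holder class, not repo code):

    #include <cassert>

    class Holder {
     public:
      explicit Holder(int* p) : ptr_(p) {}

      // Const accessor delegates to the mutable one; the returned
      // pointer is immediately narrowed back to pointer-to-const.
      const int* Value() const {
        return ValueMut();
      }

      // `const` here only promises not to modify the Holder; the
      // pointee stays writable, exactly like Var::IValueMut().
      int* ValueMut() const {
        assert(ptr_ != nullptr);
        return ptr_;
      }

     private:
      int* ptr_;
    };

    int main() {
      int x = 41;
      const Holder h(&x);
      *h.ValueMut() += 1; // write through a const wrapper
      return *h.Value() == 42 ? 0 : 1;
    }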

core/conversion/var/Var.h

Lines changed: 4 additions & 3 deletions
@@ -17,13 +17,14 @@ class Var : torch::CustomClassHolder {
   enum Type { kITensor, kIValue, kNone };

   Var();
-  Var(const torch::jit::IValue* p);
+  Var(torch::jit::IValue* p);
   Var(nvinfer1::ITensor* p);
   Var(const Var& a);
   Var& operator=(const Var& a);
-  Var& operator=(const torch::jit::IValue* in);
+  Var& operator=(torch::jit::IValue* in);
   Var& operator=(nvinfer1::ITensor* in);
   const torch::jit::IValue* IValue() const;
+  torch::jit::IValue* IValueMut() const;
   nvinfer1::ITensor* ITensor() const;

   // TODO: Can we consolidate this in a way that prevents requesting invalid
@@ -63,7 +64,7 @@ class Var : torch::CustomClassHolder {

  private:
   union VarContainer {
-    const torch::jit::IValue* ivalue;
+    torch::jit::IValue* ivalue;
     nvinfer1::ITensor* tensor;
     void* none;
   };
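Note: dropping const from the union member is what makes IValueMut() possible — a stored const torch::jit::IValue* could not be returned as a writable pointer without a const_cast. A small sketch of the constraint (hypothetical Container type, not repo code):

    struct Container {
      const int* ro; // like the old `const torch::jit::IValue* ivalue`
      int* rw;       // like the new `torch::jit::IValue* ivalue`
    };

    int main() {
      int x = 1;
      Container c{&x, &x};
      // *c.ro = 2;  // error: assignment of read-only location
      *c.rw = 2;     // fine: the non-const pointer allows the write-back
      return x == 2 ? 0 : 1;
    }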

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 26 additions & 0 deletions
@@ -303,6 +303,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }

+TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=0]()
+        %3 : Tensor[] = prim::ListConstruct(%0)
+        %4 : Tensor[] = prim::ListConstruct(%1)
+        aten::extend(%3, %4)
+        %5 : Tensor = aten::cat(%3, %2)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = at::randint(1, 10, {5, 4}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):
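Note: the graph under test builds two single-element tensor lists, extends the first with the second, and concatenates along dim 0, so the TensorRT engine output must match JIT on an {8, 4} result. For reference, a plain libtorch sketch of the same eager-mode semantics (CPU tensors for brevity; standard torch:: API assumed):

    #include <torch/torch.h>
    #include <iostream>
    #include <vector>

    int main() {
      auto in0 = torch::randint(1, 10, {3, 4});
      auto in1 = torch::randint(1, 10, {5, 4});

      // prim::ListConstruct + aten::extend, written out in eager C++:
      std::vector<torch::Tensor> list = {in0};
      std::vector<torch::Tensor> other = {in1};
      list.insert(list.end(), other.begin(), other.end()); // extend

      auto out = torch::cat(list, /*dim=*/0); // aten::cat along dim 0
      std::cout << out.sizes() << std::endl;  // prints [8, 4]
      return 0;
    }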
