Skip to content

Commit 268a49b

Browse files
committed
feat(core//conversion): Implement converter for torch unbind
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b8b8fce commit 268a49b

File tree

3 files changed

+600
-564
lines changed

3 files changed

+600
-564
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,32 @@ namespace converters {
1515
namespace impl {
1616
namespace {
1717

18-
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
18+
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
1919
auto in = args[0].ITensor();
20-
auto axis = args[2].unwrapToInt();
21-
auto inDimSize = in->getDimensions().d[axis];
22-
auto numOutputs = 1, numRemainder = 0;
20+
auto numOutputs = 1, numRemainder = 0, axis = 0;
2321
std::vector<int64_t> sizes;
2422

25-
if (split_list) {
26-
sizes = args[1].unwrapToIntList().vec();
27-
numOutputs = sizes.size();
23+
if (unbind) {
24+
axis = args[1].unwrapToInt();
25+
numOutputs = in->getDimensions().d[axis];
26+
sizes.insert(sizes.end(), numOutputs, 1);
2827
} else {
29-
auto split_size = args[1].unwrapToInt();
30-
numOutputs = inDimSize / split_size;
31-
numRemainder = inDimSize % split_size;
32-
for (int64_t i = 0; i < numOutputs; i++) {
33-
sizes.push_back(split_size);
34-
}
35-
if (numRemainder) {
36-
numOutputs += 1;
37-
sizes.push_back(numRemainder);
28+
axis = args[2].unwrapToInt();
29+
auto inDimSize = in->getDimensions().d[axis];
30+
if (split_list) {
31+
sizes = args[1].unwrapToIntList().vec();
32+
numOutputs = sizes.size();
33+
} else {
34+
auto split_size = args[1].unwrapToInt();
35+
numOutputs = inDimSize / split_size;
36+
numRemainder = inDimSize % split_size;
37+
for (int64_t i = 0; i < numOutputs; i++) {
38+
sizes.push_back(split_size);
39+
}
40+
if (numRemainder) {
41+
numOutputs += 1;
42+
sizes.push_back(numRemainder);
43+
}
3844
}
3945
}
4046

@@ -340,19 +346,25 @@ auto select_registrations TORCHTRT_UNUSED =
340346
}})
341347
.pattern({"aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])",
342348
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
343-
add_split(ctx, n, args, true);
349+
add_split(ctx, n, args, true, false);
344350
LOG_DEBUG("Converted split op into a list of IValues");
345351
return true;
346352
}})
347353
.pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
348354
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
349-
add_split(ctx, n, args, false);
355+
add_split(ctx, n, args, false, false);
350356
LOG_DEBUG("Converted split op into a list of IValues");
351357
return true;
352358
}})
353359
.pattern({"aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])",
354360
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355-
add_split(ctx, n, args, true);
361+
add_split(ctx, n, args, true, false);
362+
LOG_DEBUG("Converted split op into a list of IValues");
363+
return true;
364+
}})
365+
.pattern({"aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])",
366+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
367+
add_split(ctx, n, args, false, true);
356368
LOG_DEBUG("Converted split op into a list of IValues");
357369
return true;
358370
}})

core/lowering/register_trt_placeholder_ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
1010
RegisterOperators trt_placeholder_ops_reg({
1111
/// Op marks a Tensor to be conveted from an Torch Tensor
1212
/// to a TRT constant Tensor
13-
Operator(
14-
"trt::const(Tensor val) -> Tensor",
15-
[](Stack& stack) { /*noop*/ },
16-
aliasAnalysisFromSchema()),
13+
Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()),
1714
});
1815

1916
} // namespace jit

0 commit comments

Comments
 (0)