@@ -15,26 +15,32 @@ namespace converters {
15
15
namespace impl {
16
16
namespace {
17
17
18
- bool add_split (ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
18
+ bool add_split (ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind ) {
19
19
auto in = args[0 ].ITensor ();
20
- auto axis = args[2 ].unwrapToInt ();
21
- auto inDimSize = in->getDimensions ().d [axis];
22
- auto numOutputs = 1 , numRemainder = 0 ;
20
+ auto numOutputs = 1 , numRemainder = 0 , axis = 0 ;
23
21
std::vector<int64_t > sizes;
24
22
25
- if (split_list) {
26
- sizes = args[1 ].unwrapToIntList ().vec ();
27
- numOutputs = sizes.size ();
23
+ if (unbind) {
24
+ axis = args[1 ].unwrapToInt ();
25
+ numOutputs = in->getDimensions ().d [axis];
26
+ sizes.insert (sizes.end (), numOutputs, 1 );
28
27
} else {
29
- auto split_size = args[1 ].unwrapToInt ();
30
- numOutputs = inDimSize / split_size;
31
- numRemainder = inDimSize % split_size;
32
- for (int64_t i = 0 ; i < numOutputs; i++) {
33
- sizes.push_back (split_size);
34
- }
35
- if (numRemainder) {
36
- numOutputs += 1 ;
37
- sizes.push_back (numRemainder);
28
+ axis = args[2 ].unwrapToInt ();
29
+ auto inDimSize = in->getDimensions ().d [axis];
30
+ if (split_list) {
31
+ sizes = args[1 ].unwrapToIntList ().vec ();
32
+ numOutputs = sizes.size ();
33
+ } else {
34
+ auto split_size = args[1 ].unwrapToInt ();
35
+ numOutputs = inDimSize / split_size;
36
+ numRemainder = inDimSize % split_size;
37
+ for (int64_t i = 0 ; i < numOutputs; i++) {
38
+ sizes.push_back (split_size);
39
+ }
40
+ if (numRemainder) {
41
+ numOutputs += 1 ;
42
+ sizes.push_back (numRemainder);
43
+ }
38
44
}
39
45
}
40
46
@@ -340,19 +346,25 @@ auto select_registrations TORCHTRT_UNUSED =
340
346
}})
341
347
.pattern({" aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])" ,
342
348
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
343
- add_split (ctx, n, args, true );
349
+ add_split (ctx, n, args, true , false );
344
350
LOG_DEBUG (" Converted split op into a list of IValues" );
345
351
return true ;
346
352
}})
347
353
.pattern({" aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])" ,
348
354
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
349
- add_split (ctx, n, args, false );
355
+ add_split (ctx, n, args, false , false );
350
356
LOG_DEBUG (" Converted split op into a list of IValues" );
351
357
return true ;
352
358
}})
353
359
.pattern({" aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])" ,
354
360
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355
- add_split (ctx, n, args, true );
361
+ add_split (ctx, n, args, true , false );
362
+ LOG_DEBUG (" Converted split op into a list of IValues" );
363
+ return true ;
364
+ }})
365
+ .pattern({" aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])" ,
366
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
367
+ add_split (ctx, n, args, false , true );
356
368
LOG_DEBUG (" Converted split op into a list of IValues" );
357
369
return true ;
358
370
}})
0 commit comments