@@ -634,44 +634,50 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_gradient_inputs(
634
634
input = input.reshape (
635
635
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
636
636
637
- grad_offset = grad_offset.reshape ({batch_sz / n_parallel_imgs,
638
- n_parallel_imgs,
639
- n_offset_grps * 2 * weight_h * weight_w,
640
- out_h,
641
- out_w});
642
- offset = offset.reshape ({batch_sz / n_parallel_imgs,
643
- n_parallel_imgs,
644
- n_offset_grps * 2 * weight_h * weight_w,
645
- out_h,
646
- out_w});
637
+ grad_offset = grad_offset.reshape (
638
+ {batch_sz / n_parallel_imgs,
639
+ n_parallel_imgs,
640
+ n_offset_grps * 2 * weight_h * weight_w,
641
+ out_h,
642
+ out_w});
643
+ offset = offset.reshape (
644
+ {batch_sz / n_parallel_imgs,
645
+ n_parallel_imgs,
646
+ n_offset_grps * 2 * weight_h * weight_w,
647
+ out_h,
648
+ out_w});
647
649
648
650
if (use_mask) {
649
- grad_mask = grad_mask.reshape ({batch_sz / n_parallel_imgs,
650
- n_parallel_imgs,
651
- n_offset_grps * weight_h * weight_w,
652
- out_h,
653
- out_w});
654
- mask = mask.reshape ({batch_sz / n_parallel_imgs,
655
- n_parallel_imgs,
656
- n_offset_grps * weight_h * weight_w,
657
- out_h,
658
- out_w});
651
+ grad_mask = grad_mask.reshape (
652
+ {batch_sz / n_parallel_imgs,
653
+ n_parallel_imgs,
654
+ n_offset_grps * weight_h * weight_w,
655
+ out_h,
656
+ out_w});
657
+ mask = mask.reshape (
658
+ {batch_sz / n_parallel_imgs,
659
+ n_parallel_imgs,
660
+ n_offset_grps * weight_h * weight_w,
661
+ out_h,
662
+ out_w});
659
663
}
660
664
661
665
grad_out = grad_out
662
- .reshape ({batch_sz / n_parallel_imgs,
663
- n_parallel_imgs,
664
- n_weight_grps,
665
- n_out_channels / n_weight_grps,
666
- out_h,
667
- out_w})
666
+ .reshape (
667
+ {batch_sz / n_parallel_imgs,
668
+ n_parallel_imgs,
669
+ n_weight_grps,
670
+ n_out_channels / n_weight_grps,
671
+ out_h,
672
+ out_w})
668
673
.permute ({0 , 2 , 3 , 1 , 4 , 5 });
669
674
670
- weight = weight.reshape ({n_weight_grps,
671
- weight.size (0 ) / n_weight_grps,
672
- weight.size (1 ),
673
- weight.size (2 ),
674
- weight.size (3 )});
675
+ weight = weight.reshape (
676
+ {n_weight_grps,
677
+ weight.size (0 ) / n_weight_grps,
678
+ weight.size (1 ),
679
+ weight.size (2 ),
680
+ weight.size (3 )});
675
681
676
682
columns = columns.view (
677
683
{n_weight_grps, columns.size (0 ) / n_weight_grps, columns.size (1 )});
@@ -775,37 +781,41 @@ at::Tensor backward_gradient_parameters(
775
781
}
776
782
777
783
at::Tensor grad_out_buf = grad_out
778
- .reshape ({batch_sz / n_parallel_imgs,
779
- n_parallel_imgs,
780
- n_weight_grps,
781
- n_out_channels / n_weight_grps,
782
- out_h,
783
- out_w})
784
+ .reshape (
785
+ {batch_sz / n_parallel_imgs,
786
+ n_parallel_imgs,
787
+ n_weight_grps,
788
+ n_out_channels / n_weight_grps,
789
+ out_h,
790
+ out_w})
784
791
.permute ({0 , 2 , 3 , 1 , 4 , 5 })
785
792
.contiguous ();
786
793
787
794
input = input.reshape (
788
795
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
789
796
790
- offset = offset.reshape ({batch_sz / n_parallel_imgs,
791
- n_parallel_imgs,
792
- n_offset_grps * 2 * weight_h * weight_w,
793
- out_h,
794
- out_w});
797
+ offset = offset.reshape (
798
+ {batch_sz / n_parallel_imgs,
799
+ n_parallel_imgs,
800
+ n_offset_grps * 2 * weight_h * weight_w,
801
+ out_h,
802
+ out_w});
795
803
796
804
if (use_mask) {
797
- mask = mask.reshape ({batch_sz / n_parallel_imgs,
798
- n_parallel_imgs,
799
- n_offset_grps * weight_h * weight_w,
800
- out_h,
801
- out_w});
805
+ mask = mask.reshape (
806
+ {batch_sz / n_parallel_imgs,
807
+ n_parallel_imgs,
808
+ n_offset_grps * weight_h * weight_w,
809
+ out_h,
810
+ out_w});
802
811
}
803
812
804
- grad_weight = grad_weight.view ({n_weight_grps,
805
- grad_weight.size (0 ) / n_weight_grps,
806
- grad_weight.size (1 ),
807
- grad_weight.size (2 ),
808
- grad_weight.size (3 )});
813
+ grad_weight = grad_weight.view (
814
+ {n_weight_grps,
815
+ grad_weight.size (0 ) / n_weight_grps,
816
+ grad_weight.size (1 ),
817
+ grad_weight.size (2 ),
818
+ grad_weight.size (3 )});
809
819
810
820
auto columns = at::empty (
811
821
{n_weight_grps,
@@ -846,10 +856,11 @@ at::Tensor backward_gradient_parameters(
846
856
}
847
857
}
848
858
849
- grad_weight = grad_weight.view ({grad_weight.size (0 ) * grad_weight.size (1 ),
850
- grad_weight.size (2 ),
851
- grad_weight.size (3 ),
852
- grad_weight.size (4 )});
859
+ grad_weight = grad_weight.view (
860
+ {grad_weight.size (0 ) * grad_weight.size (1 ),
861
+ grad_weight.size (2 ),
862
+ grad_weight.size (3 ),
863
+ grad_weight.size (4 )});
853
864
return grad_weight;
854
865
}
855
866
@@ -976,26 +987,29 @@ at::Tensor deform_conv2d_forward_kernel(
976
987
}
977
988
978
989
// Separate batches into blocks
979
- out = out.view ({batch_sz / n_parallel_imgs,
980
- n_parallel_imgs,
981
- out_channels,
982
- out_h,
983
- out_w});
990
+ out = out.view (
991
+ {batch_sz / n_parallel_imgs,
992
+ n_parallel_imgs,
993
+ out_channels,
994
+ out_h,
995
+ out_w});
984
996
input_c = input_c.view (
985
997
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
986
998
987
- offset_c = offset_c.view ({batch_sz / n_parallel_imgs,
988
- n_parallel_imgs,
989
- n_offset_grps * 2 * weight_h * weight_w,
990
- out_h,
991
- out_w});
999
+ offset_c = offset_c.view (
1000
+ {batch_sz / n_parallel_imgs,
1001
+ n_parallel_imgs,
1002
+ n_offset_grps * 2 * weight_h * weight_w,
1003
+ out_h,
1004
+ out_w});
992
1005
993
1006
if (use_mask) {
994
- mask_c = mask_c.view ({batch_sz / n_parallel_imgs,
995
- n_parallel_imgs,
996
- n_offset_grps * weight_h * weight_w,
997
- out_h,
998
- out_w});
1007
+ mask_c = mask_c.view (
1008
+ {batch_sz / n_parallel_imgs,
1009
+ n_parallel_imgs,
1010
+ n_offset_grps * weight_h * weight_w,
1011
+ out_h,
1012
+ out_w});
999
1013
}
1000
1014
1001
1015
at::Tensor out_buf = at::zeros (
@@ -1006,16 +1020,18 @@ at::Tensor deform_conv2d_forward_kernel(
1006
1020
out.options ());
1007
1021
1008
1022
// Separate channels into convolution groups
1009
- out_buf = out_buf.view ({out_buf.size (0 ),
1010
- n_weight_grps,
1011
- out_buf.size (1 ) / n_weight_grps,
1012
- out_buf.size (2 ),
1013
- out_buf.size (3 )});
1014
- weight_c = weight_c.view ({n_weight_grps,
1015
- weight_c.size (0 ) / n_weight_grps,
1016
- weight_c.size (1 ),
1017
- weight_c.size (2 ),
1018
- weight_c.size (3 )});
1023
+ out_buf = out_buf.view (
1024
+ {out_buf.size (0 ),
1025
+ n_weight_grps,
1026
+ out_buf.size (1 ) / n_weight_grps,
1027
+ out_buf.size (2 ),
1028
+ out_buf.size (3 )});
1029
+ weight_c = weight_c.view (
1030
+ {n_weight_grps,
1031
+ weight_c.size (0 ) / n_weight_grps,
1032
+ weight_c.size (1 ),
1033
+ weight_c.size (2 ),
1034
+ weight_c.size (3 )});
1019
1035
1020
1036
// Sample points and perform convolution
1021
1037
auto columns = at::zeros (
@@ -1056,11 +1072,12 @@ at::Tensor deform_conv2d_forward_kernel(
1056
1072
columns.view ({columns.size (0 ) * columns.size (1 ), columns.size (2 )});
1057
1073
}
1058
1074
1059
- out_buf = out_buf.view ({batch_sz / n_parallel_imgs,
1060
- out_channels,
1061
- n_parallel_imgs,
1062
- out_h,
1063
- out_w});
1075
+ out_buf = out_buf.view (
1076
+ {batch_sz / n_parallel_imgs,
1077
+ out_channels,
1078
+ n_parallel_imgs,
1079
+ out_h,
1080
+ out_w});
1064
1081
out_buf.transpose_ (1 , 2 );
1065
1082
out.copy_ (out_buf);
1066
1083
out = out.view ({batch_sz, out_channels, out_h, out_w});
0 commit comments