Skip to content

Commit db2adb2

Browse files
committed
Move autograd implementations on separate files.
1 parent 7d831a2 commit db2adb2

18 files changed

+993
-838
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ include(GNUInstallDirs)
5252
include(CMakePackageConfigHelpers)
5353

5454
set(TVCPP torchvision/csrc)
55-
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops ${TVCPP}/ops/cpu)
55+
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
56+
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
5657
if(WITH_CUDA)
5758
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
5859
endif()

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def get_extensions():
136136

137137
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops',
138138
'*.cpp'))
139-
source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
139+
source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(
140+
os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
140141

141142
is_rocm_pytorch = False
142143
if torch.__version__ >= '1.5':
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
#include "../deform_conv2d.h"
2+
3+
#include <torch/autograd.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
class DeformConv2dFunction
12+
: public torch::autograd::Function<DeformConv2dFunction> {
13+
public:
14+
static torch::autograd::variable_list forward(
15+
torch::autograd::AutogradContext* ctx,
16+
const torch::autograd::Variable& input,
17+
const torch::autograd::Variable& weight,
18+
const torch::autograd::Variable& offset,
19+
const torch::autograd::Variable& mask,
20+
const torch::autograd::Variable& bias,
21+
int64_t stride_h,
22+
int64_t stride_w,
23+
int64_t pad_h,
24+
int64_t pad_w,
25+
int64_t dilation_h,
26+
int64_t dilation_w,
27+
int64_t groups,
28+
int64_t offset_groups,
29+
bool use_mask) {
30+
at::AutoNonVariableTypeMode g;
31+
auto output = deform_conv2d(
32+
input,
33+
weight,
34+
offset,
35+
mask,
36+
bias,
37+
stride_h,
38+
stride_w,
39+
pad_h,
40+
pad_w,
41+
dilation_h,
42+
dilation_w,
43+
groups,
44+
offset_groups,
45+
use_mask);
46+
47+
ctx->save_for_backward({input, weight, offset, mask, bias});
48+
ctx->saved_data["stride_h"] = stride_h;
49+
ctx->saved_data["stride_w"] = stride_w;
50+
ctx->saved_data["pad_h"] = pad_h;
51+
ctx->saved_data["pad_w"] = pad_w;
52+
ctx->saved_data["dilation_h"] = dilation_h;
53+
ctx->saved_data["dilation_w"] = dilation_w;
54+
ctx->saved_data["groups"] = groups;
55+
ctx->saved_data["offset_groups"] = offset_groups;
56+
ctx->saved_data["use_mask"] = use_mask;
57+
58+
return {
59+
output,
60+
};
61+
}
62+
63+
static torch::autograd::variable_list backward(
64+
torch::autograd::AutogradContext* ctx,
65+
const torch::autograd::variable_list& grad_output) {
66+
auto saved = ctx->get_saved_variables();
67+
auto input = saved[0];
68+
auto weight = saved[1];
69+
auto offset = saved[2];
70+
auto mask = saved[3];
71+
auto bias = saved[4];
72+
73+
auto stride_h = ctx->saved_data["stride_h"].toInt();
74+
auto stride_w = ctx->saved_data["stride_w"].toInt();
75+
auto pad_h = ctx->saved_data["pad_h"].toInt();
76+
auto pad_w = ctx->saved_data["pad_w"].toInt();
77+
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
78+
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
79+
auto groups = ctx->saved_data["groups"].toInt();
80+
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
81+
auto use_mask = ctx->saved_data["use_mask"].toBool();
82+
83+
auto grads = detail::_deform_conv2d_backward(
84+
grad_output[0],
85+
input,
86+
weight,
87+
offset,
88+
mask,
89+
bias,
90+
stride_h,
91+
stride_w,
92+
pad_h,
93+
pad_w,
94+
dilation_h,
95+
dilation_w,
96+
groups,
97+
offset_groups,
98+
use_mask);
99+
auto grad_input = std::get<0>(grads);
100+
auto grad_weight = std::get<1>(grads);
101+
auto grad_offset = std::get<2>(grads);
102+
auto grad_mask = std::get<3>(grads);
103+
auto grad_bias = std::get<4>(grads);
104+
105+
return {
106+
grad_input,
107+
grad_weight,
108+
grad_offset,
109+
grad_mask,
110+
grad_bias,
111+
torch::autograd::Variable(),
112+
torch::autograd::Variable(),
113+
torch::autograd::Variable(),
114+
torch::autograd::Variable(),
115+
torch::autograd::Variable(),
116+
torch::autograd::Variable(),
117+
torch::autograd::Variable(),
118+
torch::autograd::Variable(),
119+
torch::autograd::Variable(),
120+
};
121+
}
122+
};
123+
124+
// TODO: There should be an easier way to do this
125+
class DeformConv2dBackwardFunction
126+
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
127+
public:
128+
static torch::autograd::variable_list forward(
129+
torch::autograd::AutogradContext* ctx,
130+
const torch::autograd::Variable& grad,
131+
const torch::autograd::Variable& input,
132+
const torch::autograd::Variable& weight,
133+
const torch::autograd::Variable& offset,
134+
const torch::autograd::Variable& mask,
135+
const torch::autograd::Variable& bias,
136+
int64_t stride_h,
137+
int64_t stride_w,
138+
int64_t pad_h,
139+
int64_t pad_w,
140+
int64_t dilation_h,
141+
int64_t dilation_w,
142+
int64_t groups,
143+
int64_t offset_groups,
144+
bool use_mask) {
145+
at::AutoNonVariableTypeMode g;
146+
auto result = detail::_deform_conv2d_backward(
147+
grad,
148+
input,
149+
weight,
150+
offset,
151+
mask,
152+
bias,
153+
stride_h,
154+
stride_w,
155+
pad_h,
156+
pad_w,
157+
dilation_h,
158+
dilation_w,
159+
groups,
160+
offset_groups,
161+
use_mask);
162+
163+
auto grad_input = std::get<0>(result);
164+
auto grad_weight = std::get<1>(result);
165+
auto grad_offset = std::get<2>(result);
166+
auto grad_mask = std::get<3>(result);
167+
auto grad_bias = std::get<4>(result);
168+
169+
return {
170+
grad_input,
171+
grad_weight,
172+
grad_offset,
173+
grad_mask,
174+
grad_bias,
175+
};
176+
}
177+
178+
static torch::autograd::variable_list backward(
179+
torch::autograd::AutogradContext* ctx,
180+
const torch::autograd::variable_list& grad_output) {
181+
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
182+
}
183+
};
184+
185+
at::Tensor deform_conv2d_autograd(
186+
const at::Tensor& input,
187+
const at::Tensor& weight,
188+
const at::Tensor& offset,
189+
const at::Tensor& mask,
190+
const at::Tensor& bias,
191+
int64_t stride_h,
192+
int64_t stride_w,
193+
int64_t pad_h,
194+
int64_t pad_w,
195+
int64_t dilation_h,
196+
int64_t dilation_w,
197+
int64_t groups,
198+
int64_t offset_groups,
199+
bool use_mask) {
200+
return DeformConv2dFunction::apply(
201+
input,
202+
weight,
203+
offset,
204+
mask,
205+
bias,
206+
stride_h,
207+
stride_w,
208+
pad_h,
209+
pad_w,
210+
dilation_h,
211+
dilation_w,
212+
groups,
213+
offset_groups,
214+
use_mask)[0];
215+
}
216+
217+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
218+
deform_conv2d_backward_autograd(
219+
const at::Tensor& grad,
220+
const at::Tensor& input,
221+
const at::Tensor& weight,
222+
const at::Tensor& offset,
223+
const at::Tensor& mask,
224+
const at::Tensor& bias,
225+
int64_t stride_h,
226+
int64_t stride_w,
227+
int64_t pad_h,
228+
int64_t pad_w,
229+
int64_t dilation_h,
230+
int64_t dilation_w,
231+
int64_t groups,
232+
int64_t offset_groups,
233+
bool use_mask) {
234+
auto result = DeformConv2dBackwardFunction::apply(
235+
grad,
236+
input,
237+
weight,
238+
offset,
239+
mask,
240+
bias,
241+
stride_h,
242+
stride_w,
243+
pad_h,
244+
pad_w,
245+
dilation_h,
246+
dilation_w,
247+
groups,
248+
offset_groups,
249+
use_mask);
250+
251+
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
252+
}
253+
254+
} // namespace
255+
256+
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
257+
m.impl("deform_conv2d", deform_conv2d_autograd);
258+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
259+
}
260+
261+
} // namespace ops
262+
} // namespace vision

0 commit comments

Comments
 (0)