Skip to content

Commit 4480603

Browse files
authored
Port roi_align to actually use dispatcher (#2366)
* Switch torchvision registrations to new operator registration API.

  This is still registering everything as catchalls, so we're really just moving deck chairs around, but payoff is coming soon.

  Signed-off-by: Edward Z. Yang <[email protected]>

* Port roi_align to actually use dispatcher

  Signed-off-by: Edward Z. Yang <[email protected]>
1 parent 7fd2491 commit 4480603

File tree

6 files changed, with 170 additions and 107 deletions.

torchvision/csrc/ROIAlign.h

Lines changed: 91 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
#include "hip/vision_cuda.h"
1010
#endif
1111

12-
// Interface for Python
13-
at::Tensor ROIAlign_forward(
12+
// TODO: put this stuff in torchvision namespace
13+
14+
at::Tensor roi_align(
1415
const at::Tensor& input, // Input feature map.
1516
const at::Tensor& rois, // List of ROIs to pool over.
1617
const double spatial_scale, // The scale of the image features. ROIs will be
@@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward(
2122
const bool aligned) // The flag for pixel shift
2223
// along each axis.
2324
{
24-
if (input.is_cuda()) {
25-
#if defined(WITH_CUDA) || defined(WITH_HIP)
26-
return ROIAlign_forward_cuda(
27-
input,
28-
rois,
29-
spatial_scale,
30-
pooled_height,
31-
pooled_width,
32-
sampling_ratio,
33-
aligned);
34-
#else
35-
AT_ERROR("Not compiled with GPU support");
36-
#endif
37-
}
38-
return ROIAlign_forward_cpu(
25+
static auto op = c10::Dispatcher::singleton()
26+
.findSchemaOrThrow("torchvision::roi_align", "")
27+
.typed<decltype(roi_align)>();
28+
return op.call(
3929
input,
4030
rois,
4131
spatial_scale,
@@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward(
4535
aligned);
4636
}
4737

48-
at::Tensor ROIAlign_backward(
38+
at::Tensor _roi_align_backward(
4939
const at::Tensor& grad,
5040
const at::Tensor& rois,
51-
const float spatial_scale,
52-
const int pooled_height,
53-
const int pooled_width,
54-
const int batch_size,
55-
const int channels,
56-
const int height,
57-
const int width,
58-
const int sampling_ratio,
41+
const double spatial_scale,
42+
const int64_t pooled_height,
43+
const int64_t pooled_width,
44+
const int64_t batch_size,
45+
const int64_t channels,
46+
const int64_t height,
47+
const int64_t width,
48+
const int64_t sampling_ratio,
5949
const bool aligned) {
60-
if (grad.is_cuda()) {
61-
#if defined(WITH_CUDA) || defined(WITH_HIP)
62-
return ROIAlign_backward_cuda(
63-
grad,
64-
rois,
65-
spatial_scale,
66-
pooled_height,
67-
pooled_width,
68-
batch_size,
69-
channels,
70-
height,
71-
width,
72-
sampling_ratio,
73-
aligned);
74-
#else
75-
AT_ERROR("Not compiled with GPU support");
76-
#endif
77-
}
78-
return ROIAlign_backward_cpu(
50+
static auto op =
51+
c10::Dispatcher::singleton()
52+
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
53+
.typed<decltype(_roi_align_backward)>();
54+
return op.call(
7955
grad,
8056
rois,
8157
spatial_scale,
@@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
10783
ctx->saved_data["aligned"] = aligned;
10884
ctx->saved_data["input_shape"] = input.sizes();
10985
ctx->save_for_backward({rois});
110-
auto result = ROIAlign_forward(
86+
at::AutoNonVariableTypeMode g;
87+
auto result = roi_align(
11188
input,
11289
rois,
11390
spatial_scale,
@@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
125102
auto saved = ctx->get_saved_variables();
126103
auto rois = saved[0];
127104
auto input_shape = ctx->saved_data["input_shape"].toIntList();
128-
auto grad_in = ROIAlign_backward(
105+
auto grad_in = _roi_align_backward(
129106
grad_output[0],
130107
rois,
131108
ctx->saved_data["spatial_scale"].toDouble(),
@@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
147124
}
148125
};
149126

150-
at::Tensor roi_align(
127+
// TODO: There should be an easier way to do this
128+
class ROIAlignBackwardFunction
129+
: public torch::autograd::Function<ROIAlignBackwardFunction> {
130+
public:
131+
static torch::autograd::variable_list forward(
132+
torch::autograd::AutogradContext* ctx,
133+
torch::autograd::Variable grad,
134+
torch::autograd::Variable rois,
135+
const double spatial_scale,
136+
const int64_t pooled_height,
137+
const int64_t pooled_width,
138+
const int64_t batch_size,
139+
const int64_t channels,
140+
const int64_t height,
141+
const int64_t width,
142+
const int64_t sampling_ratio,
143+
const bool aligned) {
144+
at::AutoNonVariableTypeMode g;
145+
auto result = _roi_align_backward(
146+
grad,
147+
rois,
148+
spatial_scale,
149+
pooled_height,
150+
pooled_width,
151+
batch_size,
152+
channels,
153+
height,
154+
width,
155+
sampling_ratio,
156+
aligned);
157+
return {result};
158+
}
159+
160+
static torch::autograd::variable_list backward(
161+
torch::autograd::AutogradContext* ctx,
162+
torch::autograd::variable_list grad_output) {
163+
TORCH_CHECK(0, "double backwards on roi_align not supported");
164+
}
165+
};
166+
167+
at::Tensor ROIAlign_autograd(
151168
const at::Tensor& input,
152169
const at::Tensor& rois,
153170
const double spatial_scale,
@@ -164,3 +181,29 @@ at::Tensor roi_align(
164181
sampling_ratio,
165182
aligned)[0];
166183
}
184+
185+
at::Tensor ROIAlign_backward_autograd(
186+
const at::Tensor& grad,
187+
const at::Tensor& rois,
188+
const double spatial_scale,
189+
const int64_t pooled_height,
190+
const int64_t pooled_width,
191+
const int64_t batch_size,
192+
const int64_t channels,
193+
const int64_t height,
194+
const int64_t width,
195+
const int64_t sampling_ratio,
196+
const bool aligned) {
197+
return ROIAlignBackwardFunction::apply(
198+
grad,
199+
rois,
200+
spatial_scale,
201+
pooled_height,
202+
pooled_width,
203+
batch_size,
204+
channels,
205+
height,
206+
width,
207+
sampling_ratio,
208+
aligned)[0];
209+
}

torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,10 @@ void ROIAlignBackward(
381381
at::Tensor ROIAlign_forward_cpu(
382382
const at::Tensor& input,
383383
const at::Tensor& rois,
384-
const float spatial_scale,
385-
const int pooled_height,
386-
const int pooled_width,
387-
const int sampling_ratio,
384+
const double spatial_scale,
385+
const int64_t pooled_height,
386+
const int64_t pooled_width,
387+
const int64_t sampling_ratio,
388388
const bool aligned) {
389389
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
390390
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
@@ -430,14 +430,14 @@ at::Tensor ROIAlign_forward_cpu(
430430
at::Tensor ROIAlign_backward_cpu(
431431
const at::Tensor& grad,
432432
const at::Tensor& rois,
433-
const float spatial_scale,
434-
const int pooled_height,
435-
const int pooled_width,
436-
const int batch_size,
437-
const int channels,
438-
const int height,
439-
const int width,
440-
const int sampling_ratio,
433+
const double spatial_scale,
434+
const int64_t pooled_height,
435+
const int64_t pooled_width,
436+
const int64_t batch_size,
437+
const int64_t channels,
438+
const int64_t height,
439+
const int64_t width,
440+
const int64_t sampling_ratio,
441441
const bool aligned) {
442442
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
443443
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,23 @@ at::Tensor ROIPool_backward_cpu(
2323
at::Tensor ROIAlign_forward_cpu(
2424
const at::Tensor& input,
2525
const at::Tensor& rois,
26-
const float spatial_scale,
27-
const int pooled_height,
28-
const int pooled_width,
29-
const int sampling_ratio,
26+
const double spatial_scale,
27+
const int64_t pooled_height,
28+
const int64_t pooled_width,
29+
const int64_t sampling_ratio,
3030
const bool aligned);
3131

3232
at::Tensor ROIAlign_backward_cpu(
3333
const at::Tensor& grad,
3434
const at::Tensor& rois,
35-
const float spatial_scale,
36-
const int pooled_height,
37-
const int pooled_width,
38-
const int batch_size,
39-
const int channels,
40-
const int height,
41-
const int width,
42-
const int sampling_ratio,
35+
const double spatial_scale,
36+
const int64_t pooled_height,
37+
const int64_t pooled_width,
38+
const int64_t batch_size,
39+
const int64_t channels,
40+
const int64_t height,
41+
const int64_t width,
42+
const int64_t sampling_ratio,
4343
const bool aligned);
4444

4545
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(

torchvision/csrc/cuda/ROIAlign_cuda.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,10 @@ __global__ void RoIAlignBackward(
307307
at::Tensor ROIAlign_forward_cuda(
308308
const at::Tensor& input,
309309
const at::Tensor& rois,
310-
const float spatial_scale,
311-
const int pooled_height,
312-
const int pooled_width,
313-
const int sampling_ratio,
310+
const double spatial_scale,
311+
const int64_t pooled_height,
312+
const int64_t pooled_width,
313+
const int64_t sampling_ratio,
314314
const bool aligned) {
315315
AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
316316
AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
@@ -368,14 +368,14 @@ at::Tensor ROIAlign_forward_cuda(
368368
at::Tensor ROIAlign_backward_cuda(
369369
const at::Tensor& grad,
370370
const at::Tensor& rois,
371-
const float spatial_scale,
372-
const int pooled_height,
373-
const int pooled_width,
374-
const int batch_size,
375-
const int channels,
376-
const int height,
377-
const int width,
378-
const int sampling_ratio,
371+
const double spatial_scale,
372+
const int64_t pooled_height,
373+
const int64_t pooled_width,
374+
const int64_t batch_size,
375+
const int64_t channels,
376+
const int64_t height,
377+
const int64_t width,
378+
const int64_t sampling_ratio,
379379
const bool aligned) {
380380
AT_ASSERTM(grad.is_cuda(), "grad must be a CUDA tensor");
381381
AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");

torchvision/csrc/cuda/vision_cuda.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@
99
at::Tensor ROIAlign_forward_cuda(
1010
const at::Tensor& input,
1111
const at::Tensor& rois,
12-
const float spatial_scale,
13-
const int pooled_height,
14-
const int pooled_width,
15-
const int sampling_ratio,
12+
const double spatial_scale,
13+
const int64_t pooled_height,
14+
const int64_t pooled_width,
15+
const int64_t sampling_ratio,
1616
const bool aligned);
1717

1818
at::Tensor ROIAlign_backward_cuda(
1919
const at::Tensor& grad,
2020
const at::Tensor& rois,
21-
const float spatial_scale,
22-
const int pooled_height,
23-
const int pooled_width,
24-
const int batch_size,
25-
const int channels,
26-
const int height,
27-
const int width,
28-
const int sampling_ratio,
21+
const double spatial_scale,
22+
const int64_t pooled_height,
23+
const int64_t pooled_width,
24+
const int64_t batch_size,
25+
const int64_t channels,
26+
const int64_t height,
27+
const int64_t width,
28+
const int64_t sampling_ratio,
2929
const bool aligned);
3030

3131
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(

torchvision/csrc/vision.cpp

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,34 @@ int64_t _cuda_version() {
4242
#endif
4343
}
4444

45-
static auto registry =
46-
torch::RegisterOperators()
47-
.op("torchvision::nms", &nms)
48-
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor",
49-
&roi_align)
50-
.op("torchvision::roi_pool", &roi_pool)
51-
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
52-
.op("torchvision::ps_roi_align", &ps_roi_align)
53-
.op("torchvision::ps_roi_pool", &ps_roi_pool)
54-
.op("torchvision::deform_conv2d", &deform_conv2d)
55-
.op("torchvision::_cuda_version", &_cuda_version);
45+
TORCH_LIBRARY(torchvision, m) {
46+
m.def("nms", &nms);
47+
m.def(
48+
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
49+
m.def(
50+
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
51+
m.def("roi_pool", &roi_pool);
52+
m.def("_new_empty_tensor_op", &new_empty_tensor);
53+
m.def("ps_roi_align", &ps_roi_align);
54+
m.def("ps_roi_pool", &ps_roi_pool);
55+
m.def("deform_conv2d", &deform_conv2d);
56+
m.def("_cuda_version", &_cuda_version);
57+
}
58+
59+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
60+
m.impl("roi_align", ROIAlign_forward_cpu);
61+
m.impl("_roi_align_backward", ROIAlign_backward_cpu);
62+
}
63+
64+
// TODO: Place this in a hypothetical separate torchvision_cuda library
65+
#if defined(WITH_CUDA) || defined(WITH_HIP)
66+
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
67+
m.impl("roi_align", ROIAlign_forward_cuda);
68+
m.impl("_roi_align_backward", ROIAlign_backward_cuda);
69+
}
70+
#endif
71+
72+
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
73+
m.impl("roi_align", ROIAlign_autograd);
74+
m.impl("_roi_align_backward", ROIAlign_backward_autograd);
75+
}

0 commit comments

Comments
 (0)