diff --git a/test/test_layers.py b/test/test_layers.py index d508393c64a..c1fd68da74c 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -1,81 +1,81 @@ +import numpy as np import torch from torch.autograd import gradcheck from torchvision import layers - +from itertools import product import unittest class ROIPoolTester(unittest.TestCase): + @classmethod + def setup_class(cls): + cls.dtype = torch.float64 + + def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1, device=torch.device('cpu'), dtype=torch.float64): + y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, dtype=dtype, device=device) + + rois = torch.round(rois * spatial_scale) + + for n in range(0, y.size(0)): + for r, roi in enumerate(rois): + if roi[0] == n: + start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 + start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 + roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) / pool_h, roi_x.size(3) / pool_w + + for j in range(0, pool_h): + for i in range(0, pool_w): + y[r, :, j, i] = torch.max(y[r, :, j, i], + torch.max(roi_x[:, :, + int(np.floor(j * bin_h)):int(np.ceil((j + 1) * bin_h)), + int(np.floor(i * bin_w)):int(np.ceil((i + 1) * bin_w))]) + ) + return y def test_roi_pool_basic_cpu(self): - dtype = torch.float32 device = torch.device('cpu') - x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device) rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) - dtype=dtype, device=device) + dtype=self.dtype, device=device) pool_h, pool_w = (5, 5) roi_pool = layers.ROIPool((pool_h, pool_w), 1) y = roi_pool(x, rois) - gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) - - for n in range(0, gt_y.size(0)): - start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 - start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 - roi_x = x[:, :, start_h:end_h, start_w:end_w] - bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w - for j in range(0, pool_h): - for i in range(0, pool_w): - gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) + gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype) - assert torch.equal(gt_y, y), 'ROIPool layer incorrect' + assert torch.allclose(gt_y, y), 'ROIPool layer incorrect on CPU' def test_roi_pool_cpu(self): - dtype = torch.float32 device = torch.device('cpu') - x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device) rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], - dtype=dtype, device=device) + dtype=self.dtype, device=device) pool_h, pool_w = (5, 5) roi_pool = layers.ROIPool((pool_h, pool_w), 1) y = roi_pool(x, rois) - gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) - for n in range(0, gt_y.size(0)): - for r, roi in enumerate(rois): - if roi[0] == n: - start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 - start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 - roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] - bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w - for j in range(0, pool_h): - for i in range(0, pool_w): - gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], - torch.max(roi_x[:, :, - j * bin_h:(j + 1) * bin_h, - i * bin_w:(i + 1) * bin_w]) - ) + gt_y = 
self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype) - assert torch.equal(gt_y, y), 'ROIPool layer incorrect' + assert torch.allclose(gt_y, y), 'ROIPool layer incorrect on CPU for batch > 1' def test_roi_pool_gradient_cpu(self): - dtype = torch.float32 device = torch.device('cpu') - layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) - x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) - cx = torch.ones(1, 1, 10, 10, dtype=dtype, requires_grad=True).cuda() + x = torch.ones(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True) rois = torch.tensor([ [0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 0, 0, 4, 4]], - dtype=dtype, device=device) + dtype=self.dtype, device=device) + + layer = layers.ROIPool((5, 5), 1).to(dtype=self.dtype, device=device) y = layer(x, rois) s = y.sum() @@ -90,86 +90,71 @@ def test_roi_pool_gradient_cpu(self): [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], - [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], + device=device, dtype=self.dtype) + + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool' + + def test_roi_pool_gradcheck_cpu(self): + device = torch.device('cpu') + x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = layers.ROIPool((5, 5), 1).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) - assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CPU' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_basic_gpu(self): - dtype = torch.float32 + def test_roi_pool_basic_cuda(self): device = torch.device('cuda') - x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device) rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) - dtype=dtype, device=device) + dtype=self.dtype, device=device) pool_h, pool_w = (5, 5) roi_pool = layers.ROIPool((pool_h, pool_w), 1) y = roi_pool(x, rois) - gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) + gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype) - for n in range(0, gt_y.size(0)): - start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 - start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 - roi_x = x[:, :, start_h:end_h, start_w:end_w] - bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w - for j in range(0, pool_h): - for i in range(0, pool_w): - gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) - - assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' + assert torch.allclose(gt_y.cuda(), y), 'ROIPool layer incorrect' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_gpu(self): - dtype = torch.float32 + def test_roi_pool_cuda(self): device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') - x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device) rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], - dtype=dtype, 
device=device) + dtype=self.dtype, device=device) pool_h, pool_w = (5, 5) roi_pool = layers.ROIPool((pool_h, pool_w), 1) y = roi_pool(x, rois) - gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) - for n in range(0, gt_y.size(0)): - for r, roi in enumerate(rois): - if roi[0] == n: - start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 - start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 - roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] - bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w - for j in range(0, pool_h): - for i in range(0, pool_w): - gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], - torch.max(roi_x[:, :, - j * bin_h:(j + 1) * bin_h, - i * bin_w:(i + 1) * bin_w]) - ) + gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype) - assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' + assert torch.allclose(gt_y.cuda(), y), 'ROIPool layer incorrect' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_gradient_gpu(self): - dtype = torch.float32 + def test_roi_pool_gradient_cuda(self): device = torch.device('cuda') - layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) - x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + layer = layers.ROIPool((5, 5), 1).to(dtype=self.dtype, device=device) + x = torch.ones(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True) rois = torch.tensor([ [0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 0, 0, 4, 4]], - dtype=dtype, device=device) + dtype=self.dtype, device=device) - def func(input): - return layer(input, rois) - - x.requires_grad = True y = layer(x, rois) - # print(argmax, argmax.shape) s = y.sum() s.backward() gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], @@ -181,9 +166,198 @@ def func(input): [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], - [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], + device=device, dtype=self.dtype) + + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_gradcheck_cuda(self): + device = torch.device('cuda') + x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = layers.ROIPool((5, 5), 1).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CUDA' + + +class ROIAlignTester(unittest.TestCase): + @classmethod + def setup_class(cls): + torch.manual_seed(123) + cls.dtype = torch.float32 + cls.x = torch.rand(1, 1, 10, 10, dtype=cls.dtype) + cls.single_roi = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=cls.dtype) + cls.rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9]], + dtype=cls.dtype) + + cls.gt_y_single = torch.tensor([[[[0.41617328, 0.5040753, 0.25266218, 0.4296828, 0.29928464], + [0.5210769, 0.57222337, 0.2524979, 0.32063985, 0.32635176], + [0.73108256, 0.6114335, 0.62033176, 0.8188273, 0.5562218], + [0.83115816, 0.70803946, 0.7084047, 0.74928707, 0.7769296], + [0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]], dtype=cls.dtype) + + cls.gt_y_multiple 
= torch.tensor([[[[0.49311584, 0.35972416, 0.40843594, 0.3638034, 0.49751836], + [0.70881474, 0.75481665, 0.5826779, 0.34767765, 0.46865487], + [0.4740328, 0.69306874, 0.3617804, 0.47145438, 0.66130304], + [0.6861706, 0.17634538, 0.47194335, 0.42473823, 0.37930614], + [0.62666404, 0.49973848, 0.37911576, 0.5842756, 0.7176864]]], + [[[0.67499936, 0.6607055, 0.42656037, 0.46134934, 0.42144877], + [0.7471722, 0.7235433, 0.14512213, 0.13031253, 0.289369], + [0.8443615, 0.6659734, 0.23614208, 0.14719573, 0.4268827], + [0.69429564, 0.5621515, 0.5019923, 0.40678093, 0.34556213], + [0.51315194, 0.7177093, 0.6494485, 0.6775592, 0.43865064]]], + [[[0.24465509, 0.36108392, 0.64635646, 0.4051828, 0.33956185], + [0.49006107, 0.42982674, 0.34184104, 0.15493104, 0.49633422], + [0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667], + [0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909], + [0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]], + dtype=cls.dtype) + + cls.x_grad = torch.tensor([[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504, 0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249], + [0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012, + 0.31625003, 0.3025, 0.3025, 0.30250013, 0.1512498], + [0.15124999, 0.3025, 0.30249995, 0.3025, 0.31625006, + 0.31625, 0.30249995, 0.30249995, 0.30250007, 0.15124978], + [0.15125002, 0.30250007, 0.3025, 0.30250007, 0.31625012, + 0.3162501, 0.3025, 0.3025, 0.30250013, 0.15124981], + [0.15812504, 0.31625012, 0.31625006, 0.31625012, 0.33062524, + 0.3306251, 0.31625006, 0.31625006, 0.3162502, 0.15812483], + [0.5181251, 1.0962502, 1.0362502, 1.0962503, 0.69062525, 0.6906252, + 1.0962502, 1.0362502, 1.0962503, 0.5181248], + [0.93125, 1.9925, 1.8624997, 1.9925, 1.0962502, 1.0962502, + 1.9925, 1.8624998, 1.9925, 0.9312496], + [0.8712501, 1.8625, 1.7425002, 1.8625001, 1.0362502, 1.0362502, + 1.8625, 1.7425001, 1.8625002, 0.8712497], + [0.93125004, 1.9925, 1.8625002, 1.9925, 1.0962503, 1.0962503, + 1.9925001, 1.8625001, 1.9925001, 0.93124974], + [0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248, + 0.9312496, 0.8712497, 0.93124974, 0.43562466]]]], + dtype=cls.dtype) + + def test_roi_align_basic_cpu(self): + device = torch.device('cpu') + x = self.x.to(device) + single_roi = self.single_roi.to(device) + gt_y_single = self.gt_y_single.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, single_roi) + + assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' + + def test_roi_align_cpu(self): + device = torch.device('cpu') + x = self.x.to(device) + rois = self.rois.to(device) + gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, rois) + + assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_basic_cuda(self): + device = torch.device('cuda') + x = self.x.to(device) + single_roi = self.single_roi.to(device) + gt_y_single = self.gt_y_single.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, single_roi) + + assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA' + + @unittest.skipIf(not 
torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_cuda(self): + device = torch.device('cuda') + x = self.x.to(device) + rois = self.rois.to(device) + gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, rois) + + assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA' + + def test_roi_align_gradient_cpu(self): + """ + Compute gradients for ROIAlign with multiple bounding boxes on CPU + """ + device = torch.device('cpu') + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + + x = self.x.to(device).clone() + rois = self.rois.to(device) + gt_grad = self.x_grad.to(device) + + x.requires_grad = True + y = roi_align(x, rois) + s = y.sum() + s.backward() + + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CPU' + + def test_roi_align_gradcheck_cpu(self): + dtype = torch.float64 + device = torch.device('cpu') + m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + rois = self.rois.to(device=device, dtype=dtype) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_gradient_cuda(self): + """ + Compute gradients for ROIAlign with multiple bounding boxes on the GPU + """ + device = torch.device('cuda') + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + + x = self.x.to(device).clone() + rois = self.rois.to(device) + gt_grad = self.x_grad.to(device) + + x.requires_grad = True + y = roi_align(x, rois) + s = y.sum() + s.backward() + + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CUDA' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_gradcheck_cuda(self): + dtype = torch.float64 + device = torch.device('cuda') + m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + rois = self.rois.to(device=device, dtype=dtype) + + def func(input): + return m(input, rois) - assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CUDA' if __name__ == '__main__': diff --git a/torchvision/csrc/ROIAlign.h b/torchvision/csrc/ROIAlign.h index 94348abec09..c2e6090857f 100644 --- a/torchvision/csrc/ROIAlign.h +++ b/torchvision/csrc/ROIAlign.h @@ -7,12 +7,13 @@ #endif // Interface for Python -at::Tensor ROIAlign_forward(const at::Tensor& input, - const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { +at::Tensor ROIAlign_forward(const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + const float spatial_scale, // The scale of the image features. ROIs will be scaled to this. + const int pooled_height, // The height of the pooled feature map. + const int pooled_width, // The width of the pooled feature + const int sampling_ratio) // The number of points to sample in each bin along each axis. 
+{
   if (input.type().is_cuda()) {
 #ifdef WITH_CUDA
     return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -40,6 +41,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad,
     AT_ERROR("Not compiled with GPU support");
 #endif
   }
-  AT_ERROR("Not implemented on the CPU");
+  return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
 }
diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
index 3850b2833ab..295aa8415f2 100644
--- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -110,9 +110,9 @@ void pre_calc_for_bilinear_interpolate(
 }
 
 template <typename T>
-void ROIAlignForward_cpu_kernel(
+void ROIAlignForward(
     const int nthreads,
-    const T* bottom_data,
+    const T* input,
     const T& spatial_scale,
     const int channels,
     const int height,
@@ -120,12 +120,8 @@ void ROIAlignForward_cpu_kernel(
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
-    const T* bottom_rois,
-    //int roi_cols,
-    T* top_data) {
-  //AT_ASSERT(roi_cols == 4 || roi_cols == 5);
-  int roi_cols = 5;
-
+    const T* rois,
+    T* output) {
   int n_rois = nthreads / channels / pooled_width / pooled_height;
   // (n, c, ph, pw) is an element in the pooled output
   // can be parallelized using omp
   for (int n = 0; n < n_rois; n++) {
     int index_n = n * channels * pooled_width * pooled_height;
 
-    // roi could have 4 or 5 columns
-    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
-    int roi_batch_ind = 0;
-    if (roi_cols == 5) {
-      roi_batch_ind = offset_bottom_rois[0];
-      offset_bottom_rois++;
-    }
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_rois[0] * spatial_scale);
+    // T roi_start_h = round(offset_rois[1] * spatial_scale);
+    // T roi_end_w = round(offset_rois[2] * spatial_scale);
+    // T roi_end_h = round(offset_rois[3] * spatial_scale);
 
     // Force malformed ROIs to be 1x1
     T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
@@ -188,8 +179,8 @@
     for (int c = 0; c < channels; c++) {
       int index_n_c = index_n + c * pooled_width * pooled_height;
-      const T* offset_bottom_data =
-          bottom_data + (roi_batch_ind * channels + c) * height * width;
+      const T* offset_input =
+          input + (roi_batch_ind * channels + c) * height * width;
       int pre_calc_index = 0;
 
       for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,46 +191,186 @@
           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
             for (int ix = 0; ix < roi_bin_grid_w; ix++) {
               PreCalc<T> pc = pre_calc[pre_calc_index];
-              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
-                  pc.w2 * offset_bottom_data[pc.pos2] +
-                  pc.w3 * offset_bottom_data[pc.pos3] +
-                  pc.w4 * offset_bottom_data[pc.pos4];
+              output_val += pc.w1 * offset_input[pc.pos1] +
+                  pc.w2 * offset_input[pc.pos2] +
+                  pc.w3 * offset_input[pc.pos3] +
+                  pc.w4 * offset_input[pc.pos4];
 
               pre_calc_index += 1;
             }
           }
           output_val /= count;
 
-          top_data[index] = output_val;
+          output[index] = output_val;
         } // for pw
       } // for ph
     } // for c
   } // for n
 }
 
+template <typename T>
+void bilinear_interpolate_gradient(
+    const int height, const int width,
+    T y, T x,
+    T& w1, T& w2, T& w3, T& w4,
+    int& x_low, int& x_high, int& y_low, int& y_high,
+    const int index /* index for debug only*/) {
+
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  y_low = (int)y;
+  x_low = (int)x;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+
+template <typename T>
+inline void add(T* address, const T& val) {
+  *address += val;
+}
+
+template <typename T>
+void ROIAlignBackward(
+    const int nthreads,
+    const T* grad_output,
+    const T& spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    T* grad_input,
+    const T* rois,
+    const int n_stride, const int c_stride,
+    const int h_stride, const int w_stride) {
+  for (int index = 0; index < nthreads; index++) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    int output_offset = n*n_stride + c*c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin = offset_grad_output[ph*h_stride + pw*w_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++)
+    {
+      const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++)
+      {
+        const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(height, width, y, x,
+            w1, w2, w3, w4,
+            x_low, x_high, y_low, y_high,
+            index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          // atomic add is not needed for now since it is single threaded
+          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  } // for
+} // ROIAlignBackward
+
+
 at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width,
                                 const int sampling_ratio) {
-  AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
-  AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
+  AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
 
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
 
-  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
 
   if (output.numel() == 0)
     return output;
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
-    ROIAlignForward_cpu_kernel<scalar_t>(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
+    ROIAlignForward<scalar_t>(
         output_size,
         input.data<scalar_t>(),
         spatial_scale,
@@ -254,3 +385,52 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
   });
   return output;
 }
+
+
+at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
+                                 const at::Tensor& rois,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width,
+                                 const int sampling_ratio) {
+  AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0)
+  {
+    return grad_input;
+  }
+
+  // get stride values to ensure indexing into gradients is correct.
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
+    ROIAlignBackward<scalar_t>(
+        grad.numel(),
+        grad.data<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        grad_input.data<scalar_t>(),
+        rois.data<scalar_t>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
+  });
+  return grad_input;
+}
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
index 8ae35930533..2464ffb4076 100644
--- a/torchvision/csrc/cpu/ROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -2,91 +2,158 @@
 #include
 #include
+
+template <typename T>
+inline void add(T* address, const T& val) {
+  *address += val;
+}
+
+template <typename T>
+void RoIPoolForward(
+    const T *input,
+    const T spatial_scale, const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const T *rois, const int num_rois,
+    T *output, int *argmax_data)
+{
+    for (int n = 0; n < num_rois; ++n)
+    {
+        const T* offset_rois = rois + n * 5;
+        int roi_batch_ind = offset_rois[0];
+        int roi_start_w = round(offset_rois[1] * spatial_scale);
+        int roi_start_h = round(offset_rois[2] * spatial_scale);
+        int roi_end_w = round(offset_rois[3] * spatial_scale);
+        int roi_end_h = round(offset_rois[4] * spatial_scale);
+
+        // Force malformed ROIs to be 1x1
+        int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
+        int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
+        T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+        T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+        for (int ph = 0; ph < pooled_height; ++ph)
+        {
+            for (int pw = 0; pw < pooled_width; ++pw)
+            {
+                int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
+                int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
+                int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
+                int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
+
+                // Add roi offsets and clip to input boundaries
+                hstart = std::min(std::max(hstart + roi_start_h, 0), height);
+                hend = std::min(std::max(hend + roi_start_h, 0), height);
+                wstart = std::min(std::max(wstart + roi_start_w, 0), width);
+                wend = std::min(std::max(wend + roi_start_w, 0), width);
+                bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+                // Define an empty pooling region to be zero
+                T maxval = is_empty ? 0 : -FLT_MAX;
+                // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+                int maxidx = -1;
+
+                for (int c = 0; c < channels; ++c)
+                {
+                    const T* input_offset = input + (roi_batch_ind * channels + c) * height * width;
+
+                    for (int h = hstart; h < hend; ++h)
+                    {
+                        for (int w = wstart; w < wend; ++w)
+                        {
+                            int input_index = h * width + w;
+                            if (input_offset[input_index] > maxval)
+                            {
+                                maxval = input_offset[input_index];
+                                maxidx = input_index;
+                            }
+                        }
+                    }
+                    int index = ((n*channels + c) * pooled_height + ph) * pooled_width + pw;
+                    output[index] = maxval;
+                    argmax_data[index] = maxidx;
+                } // channels
+            } // pooled_width
+        } // pooled_height
+    } // num_rois
+}
+
+template <typename T>
+void RoIPoolBackward(
+    const T* grad_output,
+    const int* argmax_data, const int num_rois,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    T* grad_input, const T* rois,
+    const int n_stride, const int c_stride,
+    const int h_stride, const int w_stride
+)
+{
+    for (int n = 0; n < num_rois; ++n)
+    {
+        const T* offset_rois = rois + n*5;
+        int roi_batch_ind = offset_rois[0];
+
+        for (int c = 0; c < channels; ++c)
+        {
+            T* grad_input_offset = grad_input + ((roi_batch_ind * channels + c) * height * width);
+            const int* argmax_data_offset = argmax_data + (n*channels + c)*pooled_height*pooled_width;
+
+            for (int ph = 0; ph < pooled_height; ++ph)
+            {
+                for (int pw = 0; pw < pooled_width; ++pw)
+                {
+                    int output_offset = n*n_stride + c*c_stride;
+                    int argmax = argmax_data_offset[ph*pooled_width + pw];
+
+                    if (argmax != -1) {
+                        add(grad_input_offset + argmax,
+                            static_cast<T>(grad_output[output_offset + ph*h_stride + pw*w_stride]));
+                    }
+                } // pooled_width
+            } // pooled_height
+        } // channels
+    } // num_rois
+}
+
+
 std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
                                                        const at::Tensor &rois,
                                                        const float spatial_scale,
                                                        const int pooled_height,
                                                        const int pooled_width)
 {
-    AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
-    AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+  AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
 
-    int num_rois = rois.size(0);
-    int channels = input.size(1);
-    int input_height = input.size(2);
-    int input_width = input.size(3);
+  int num_rois = rois.size(0);
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
 
-    at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
-    at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
-
-    // define accessors for indexing
-    auto input_a = input.accessor<float, 4>();
-    auto rois_a = rois.accessor<float, 2>();
-    auto output_a = output.accessor<float, 4>();
-    auto argmax_a = argmax.accessor<int, 4>();
-
-    if (output.numel() == 0)
-    {
-        return std::make_tuple(output, argmax);
-    }
-
-    for (int n = 0; n < num_rois; ++n)
-    {
-        int roi_batch_ind = rois_a[n][0];
-        int roi_start_w = round(rois_a[n][1] * spatial_scale);
-        int roi_start_h = round(rois_a[n][2] * spatial_scale);
-        int roi_end_w = round(rois_a[n][3] * spatial_scale);
-        int roi_end_h = round(rois_a[n][4] * spatial_scale);
-
-        // Force malformed ROIs to be 1x1
-        int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
-        int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
-        float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height);
-        float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(pooled_width);
-
-        for (int ph = 0; ph < pooled_height; ++ph)
-        {
-            for (int pw = 0; pw < pooled_width; ++pw)
-            {
-                int hstart = static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
-                int wstart = static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
-                int hend = static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
-                int wend = static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
-
-                // Add roi offsets and clip to input boundaries
-                hstart = std::min(std::max(hstart + roi_start_h, 0), input_height);
-                hend = std::min(std::max(hend + roi_start_h, 0), input_height);
-                wstart = std::min(std::max(wstart + roi_start_w, 0), input_width);
-                wend = std::min(std::max(wend + roi_start_w, 0), input_width);
-                bool is_empty = (hend <= hstart) || (wend <= wstart);
-
-                // Define an empty pooling region to be zero
-                float maxval = is_empty ? 0 : -FLT_MAX;
-                // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
-                int maxidx = -1;
-
-                for (int c = 0; c < channels; ++c)
-                {
-                    for (int h = hstart; h < hend; ++h)
-                    {
-                        for (int w = wstart; w < wend; ++w)
-                        {
-                            int index = h * input_width + w;
-                            if (input_a[roi_batch_ind][c][h][w] > maxval)
-                            {
-                                maxval = input_a[roi_batch_ind][c][h][w];
-                                maxidx = index;
-                            }
-                        }
-                    }
-                    output_a[n][c][ph][pw] = maxval;
-                    argmax_a[n][c][ph][pw] = maxidx;
-                }
-            }
-        }
-    }
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
 
+  if (output.numel() == 0)
+  {
     return std::make_tuple(output, argmax);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] {
+    RoIPoolForward<scalar_t>(
+      input.data<scalar_t>(),
+      spatial_scale,
+      channels,
+      height,
+      width,
+      pooled_height,
+      pooled_width,
+      rois.data<scalar_t>(),
+      num_rois,
+      output.data<scalar_t>(),
+      argmax.data<int>());
+  });
+  return std::make_tuple(output, argmax);
 }
 
 at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,
@@ -100,53 +167,43 @@ at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,
                                 const int height,
                                 const int width)
 {
-    // Check if input tensors are CPU tensors
-    AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
-    AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
-    AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor");
+  // Check if input tensors are CPU tensors
+  AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+  AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor");
 
-    auto num_rois = rois.size(0);
+  auto num_rois = rois.size(0);
 
-    at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
-
-    // handle possibly empty gradients
-    if (grad.numel() == 0)
-    {
-        return grad_input;
-    }
-
-    // get stride values to ensure indexing into gradients is correct.
- int n_stride = grad.stride(0); - int c_stride = grad.stride(1); - int h_stride = grad.stride(2); - int w_stride = grad.stride(3); - - // define accessors for tensors - auto grad_input_a = grad_input.accessor(); - auto grad_a = grad.accessor(); - auto argmax_a = argmax.accessor(); - auto rois_a = rois.accessor(); - - for (int n = 0; n < num_rois; ++n) - { - int roi_batch_ind = rois_a[n][0]; - - for (int c = 0; c < channels; ++c) - { - for (int ph = 0; ph < pooled_height; ++ph) - { - for (int pw = 0; pw < pooled_width; ++pw) - { - int argmax_idx = argmax_a[n][c][ph][pw]; - // get height and width index from argmax index - int h = argmax_idx / height; - int w = argmax_idx % width; - - grad_input_a[roi_batch_ind][c][h][w] += grad_a[n * n_stride][c * c_stride][ph * h_stride][pw * w_stride]; - } - } - } - } + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + // handle possibly empty gradients + if (grad.numel() == 0) + { return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] { + RoIPoolBackward( + grad.data(), + argmax.data(), + num_rois, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data(), + rois.data(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; } \ No newline at end of file diff --git a/torchvision/csrc/cpu/vision.h b/torchvision/csrc/cpu/vision.h index eebf4b95ad5..e9b32559fda 100644 --- a/torchvision/csrc/cpu/vision.h +++ b/torchvision/csrc/cpu/vision.h @@ -25,6 +25,17 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor &input, const int pooled_width, const int sampling_ratio); +at::Tensor ROIAlign_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); + at::Tensor nms_cpu(const at::Tensor &dets, const at::Tensor &scores, const float threshold); diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 9cc5ae28934..c21e5538997 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -5,14 +5,11 @@ #include #include -// TODO make it in a common file -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) +#include "cuda_helpers.h" template -__device__ T bilinear_interpolate(const T* bottom_data, +__device__ T bilinear_interpolate(const T* input, const int height, const int width, T y, T x, const int index /* index for debug only*/) { @@ -48,11 +45,12 @@ __device__ T bilinear_interpolate(const T* bottom_data, T ly = y - y_low; T lx = x - x_low; T hy = 1. - ly, hx = 1. 
- lx; + // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); @@ -61,12 +59,12 @@ __device__ T bilinear_interpolate(const T* bottom_data, } template -__global__ void RoIAlignForward(const int nthreads, const T* bottom_data, +__global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, - const T* bottom_rois, T* top_data) { + const T* rois, T* output) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -74,18 +72,14 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; // Force malformed ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, (T)1.); @@ -93,7 +87,7 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + const T* offset_input = input + (roi_batch_ind * channels + c) * height * width; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 @@ -110,13 +104,13 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, { const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); output_val += val; } } output_val /= count; - top_data[index] = output_val; + output[index] = output_val; } } @@ -162,10 +156,10 @@ __device__ void bilinear_interpolate_gradient( T hy = 1. - ly, hx = 1. 
- lx; // reference in forward - // T v1 = bottom_data[y_low * width + x_low]; - // T v2 = bottom_data[y_low * width + x_high]; - // T v3 = bottom_data[y_high * width + x_low]; - // T v4 = bottom_data[y_high * width + x_high]; + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; @@ -174,13 +168,15 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, - const int num_rois, const T spatial_scale, +__global__ void RoIAlignBackward(const int nthreads, const T* grad_output, + const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, - T* bottom_diff, - const T* bottom_rois) { + T* grad_input, + const T* rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -188,30 +184,27 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); - + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + // Force malformed ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, (T)1.); T roi_height = max(roi_end_h - roi_start_h, (T)1.); T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width); - int top_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_top_diff = top_diff + top_offset; - const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + // We need to index the gradient using the tensor strides to access the correct values. + int output_offset = n*n_stride + c*c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = offset_grad_output[ph*h_stride + pw*w_stride]; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 @@ -235,17 +228,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, x_low, x_high, y_low, y_high, index); - T g1 = top_diff_this_bin * w1 / count; - T g2 = top_diff_this_bin * w2 / count; - T g3 = top_diff_this_bin * w3 / count; - T g4 = top_diff_this_bin * w4 / count; + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); - atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); - atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); - atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + atomicAdd(offset_grad_input + y_low * width + x_low, static_cast(g1)); + atomicAdd(offset_grad_input + y_low * width + x_high, static_cast(g2)); + atomicAdd(offset_grad_input + y_high * width + x_low, static_cast(g3)); + atomicAdd(offset_grad_input + y_high * width + x_high, static_cast(g4)); } // if } // ix } // iy @@ -259,8 +252,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const int pooled_height, const int pooled_width, const int sampling_ratio) { - AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); auto num_rois = rois.size(0); auto channels = input.size(1); @@ -280,7 +273,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, return output; } - AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] { RoIAlignForward<<>>( output_size, input.data(), @@ -298,7 +291,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, return output; } -// TODO remove the dependency on input and use instead its sizes -> save memory + at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, @@ -309,10 +302,9 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const int height, const int width, const int sampling_ratio) { - AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); - auto num_rois = rois.size(0); at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -326,11 +318,15 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, return grad_input; } - AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { - RoIAlignBackwardFeature<<>>( + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] { + RoIAlignBackward<<>>( grad.numel(), grad.data(), - num_rois, spatial_scale, channels, height, @@ -339,7 +335,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, pooled_width, sampling_ratio, grad_input.data(), - rois.data()); + rois.data(), 
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
   });
   THCudaCheck(cudaGetLastError());
   return grad_input;
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index 2ba8dc33e25..6090dd9da3a 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -6,7 +6,6 @@
 #include
 #include "cuda_helpers.h"
-#include
 
 template <typename T>
@@ -93,8 +92,8 @@ __global__ void RoIPoolBackward(const int nthreads, const T* grad_output,
     T* grad_input_offset = grad_input + ((roi_batch_ind * channels + c) * height * width);
 
     int output_offset = n*n_stride + c*c_stride;
-    const int* argmax_data_offset = argmax_data + n*channels*pooled_height*pooled_width;
-    int argmax = argmax_data_offset[c*pooled_height*pooled_width + ph*pooled_width + pw];
+    const int* argmax_data_offset = argmax_data + (n*channels + c)*pooled_height*pooled_width;
+    int argmax = argmax_data_offset[ph*pooled_width + pw];
 
     if (argmax != -1) {
       atomicAdd(grad_input_offset + argmax,
@@ -108,16 +107,16 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
                                                         const float spatial_scale,
                                                         const int pooled_height,
                                                         const int pooled_width) {
-  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
 
-  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
-  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +158,13 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
                                  const int height,
                                  const int width) {
   // Check if input tensors are CUDA tensors
-  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
-  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
-  AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor");
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(argmax.device().is_cuda(), "argmax must be a CUDA tensor");
 
   auto num_rois = rois.size(0);
 
-  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();