
Commit 0fd0f50 (parent e35793a)

Added bicubic support for interpolation with AA (#3810)

* Added support for bicubic mode with AA
* Updated comment in the test

5 files changed: +192 additions, -75 deletions
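
Before the per-file diffs, here is a minimal usage sketch of what the commit adds, assuming a torchvision build that contains it. It calls the newly registered operator directly and puts it next to the existing non-antialiased bicubic path; the tensor shape and output size are arbitrary illustration values.

    import torch
    import torch.nn.functional as F
    import torchvision  # importing torchvision loads the C++ extension and registers the custom ops

    img = torch.rand(1, 3, 256, 256)  # the kernel expects a non-empty 4D float tensor

    # New in this commit: bicubic downsampling with antialiasing.
    out_aa = torch.ops.torchvision._interpolate_bicubic_aa(img, [64, 64], align_corners=False)

    # Existing path: plain bicubic interpolation, no antialias prefilter.
    out_plain = F.interpolate(img, size=[64, 64], mode="bicubic", align_corners=False)

    print(out_aa.shape, out_plain.shape)  # both torch.Size([1, 3, 64, 64])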

test/test_functional_tensor.py

Lines changed: 11 additions & 2 deletions
@@ -1021,7 +1021,7 @@ def test_perspective_interpolation_warning(tester):
 @pytest.mark.parametrize('device', ["cpu", ])
 @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
 @pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
-@pytest.mark.parametrize('interpolation', [BILINEAR, ])
+@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
 def test_resize_antialias(device, dt, size, interpolation, tester):

     if dt == torch.float16 and device == "cpu":
@@ -1051,8 +1051,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
     tester.approxEqualTensorToPIL(
         resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
     )
+
+    accepted_tol = 1.0 + 1e-5
+    if interpolation == BICUBIC:
+        # A high max-error tolerance is needed to make the tests pass. It is
+        # mostly required for test cases with downsampling and upsampling,
+        # where we can not exactly match the PIL implementation.
+        accepted_tol = 15.0
+
     tester.approxEqualTensorToPIL(
-        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
+        resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max",
         msg=f"{size}, {interpolation}, {dt}"
     )

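The mean tolerance of 0.5 and the max tolerance of 15.0 above come from comparing the tensor result against PIL. The snippet below is not the test's own approxEqualTensorToPIL helper; it is an assumed standalone equivalent on a random image, so the exact error values it prints will differ from what the test sees on its fixture images.

    import numpy as np
    import torch
    import torchvision  # registers torchvision::_interpolate_bicubic_aa
    from PIL import Image

    arr = np.random.RandomState(0).randint(0, 256, size=(96, 72, 3), dtype=np.uint8)
    pil_img = Image.fromarray(arr)
    tensor_img = torch.as_tensor(arr).permute(2, 0, 1).unsqueeze(0).float()

    out_h, out_w = 36, 48  # arbitrary target size for this sketch
    resized_pil = np.asarray(pil_img.resize((out_w, out_h), resample=Image.BICUBIC), dtype=np.float32)
    resized_tensor = torch.ops.torchvision._interpolate_bicubic_aa(
        tensor_img, [out_h, out_w], align_corners=False
    )

    diff = np.abs(resized_tensor.squeeze(0).permute(1, 2, 0).numpy() - resized_pil)
    print("mean abs diff:", diff.mean())  # the test bounds this by 0.5
    print("max abs diff:", diff.max())    # and this by 15.0 for bicubic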

torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp

Lines changed: 155 additions & 68 deletions
@@ -141,66 +141,7 @@ void ti_cpu_upsample_generic_aa(
 // Helper structs to use with ti_upsample_generic_Nd_kernel_impl
 template <typename index_t, typename scalar_t>
 struct HelperInterpBase {
-  static inline void init_indices_weights(
-      std::vector<Tensor>& output,
-      int64_t output_size,
-      int64_t ndims,
-      int64_t reshape_dim,
-      int interp_size) {
-    auto new_shape = std::vector<int64_t>(ndims, 1);
-    new_shape[reshape_dim] = output_size;
-
-    for (int j = 0; j < interp_size; j++) {
-      output.emplace_back(
-          empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
-      output.emplace_back(
-          empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
-    }
-  }
-};
-
-template <typename index_t, typename scalar_t>
-struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
-  static const int interp_size = 2;
-
-  static inline std::vector<Tensor> compute_indices_weights(
-      int64_t input_size,
-      int64_t output_size,
-      int64_t stride,
-      int64_t ndims,
-      int64_t reshape_dim,
-      bool align_corners,
-      const c10::optional<double> opt_scale,
-      bool antialias,
-      int& out_interp_size) {
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(
-        input_size, output_size, align_corners, opt_scale);
-    TORCH_INTERNAL_ASSERT(antialias);
-
-    return _compute_indices_weights_aa(
-        input_size,
-        output_size,
-        stride,
-        ndims,
-        reshape_dim,
-        align_corners,
-        scale,
-        out_interp_size);
-  }
-
-  // taken from
-  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-  // src/libImaging/Resample.c#L20-L29
-  static inline scalar_t _filter(scalar_t x) {
-    if (x < 0.0) {
-      x = -x;
-    }
-    if (x < 1.0) {
-      return 1.0 - x;
-    }
-    return 0.0;
-  }
-
+  template <typename filter_fn_t>
   static inline std::vector<Tensor> _compute_indices_weights_aa(
       int64_t input_size,
       int64_t output_size,
@@ -209,14 +150,15 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
       int64_t reshape_dim,
       bool align_corners,
       scalar_t scale,
-      int& out_interp_size) {
-    int interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
+      int& in_out_interp_size,
+      filter_fn_t filter_fn) {
+    int interp_size = in_out_interp_size;
     scalar_t support =
-        (scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
+        (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
     interp_size = (int)ceilf(support) * 2 + 1;

     // return interp_size
-    out_interp_size = interp_size;
+    in_out_interp_size = interp_size;

     std::vector<Tensor> output;
     auto new_shape = std::vector<int64_t>(ndims, 1);
@@ -269,7 +211,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {

       total_w = 0.0;
       for (j = 0; j < xmax; j++) {
-        scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
+        scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
         wt_ptr[i * interp_size + j] = w;
         total_w += w;
       }
@@ -287,6 +229,102 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
   }
 };

+template <typename index_t, typename scalar_t>
+struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
+  static const int interp_size = 2;
+
+  // taken from
+  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+  // src/libImaging/Resample.c#L20-L29
+  static inline scalar_t _filter(scalar_t x) {
+    if (x < 0.0) {
+      x = -x;
+    }
+    if (x < 1.0) {
+      return 1.0 - x;
+    }
+    return 0.0;
+  }
+
+  static inline std::vector<Tensor> compute_indices_weights(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      const c10::optional<double> opt_scale,
+      bool antialias,
+      int& out_interp_size) {
+    TORCH_INTERNAL_ASSERT(antialias);
+    scalar_t scale = area_pixel_compute_scale<scalar_t>(
+        input_size, output_size, align_corners, opt_scale);
+
+    out_interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
+    return HelperInterpLinear<index_t, scalar_t>::_compute_indices_weights_aa(
+        input_size,
+        output_size,
+        stride,
+        ndims,
+        reshape_dim,
+        align_corners,
+        scale,
+        out_interp_size,
+        _filter);
+  }
+};
+
+template <typename index_t, typename scalar_t>
+struct HelperInterpCubic : public HelperInterpBase<index_t, scalar_t> {
+  static const int interp_size = 4;
+
+  static inline std::vector<Tensor> compute_indices_weights(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      const c10::optional<double> opt_scale,
+      bool antialias,
+      int& out_interp_size) {
+    TORCH_INTERNAL_ASSERT(antialias);
+    scalar_t scale = area_pixel_compute_scale<scalar_t>(
+        input_size, output_size, align_corners, opt_scale);
+
+    out_interp_size = HelperInterpCubic<index_t, scalar_t>::interp_size;
+    return HelperInterpCubic<index_t, scalar_t>::_compute_indices_weights_aa(
+        input_size,
+        output_size,
+        stride,
+        ndims,
+        reshape_dim,
+        align_corners,
+        scale,
+        out_interp_size,
+        _filter);
+  }
+
+  // taken from
+  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+  // src/libImaging/Resample.c#L46-L62
+  static inline scalar_t _filter(scalar_t x) {
+    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+#define a -0.5
+    if (x < 0.0) {
+      x = -x;
+    }
+    if (x < 1.0) {
+      return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
+    }
+    if (x < 2.0) {
+      return (((x - 5) * x + 8) * x - 4) * a;
+    }
+    return 0.0;
+#undef a
+  }
+};
+
 template <
     typename index_t,
     int out_ndims,
@@ -396,16 +434,15 @@ void ti_separable_upsample_generic_Nd_kernel_impl(
         index_t,
         out_ndims,
         scale_t,
-        HelperInterpLinear>(
+        F>(
         temp_output, temp_input, interp_dim, align_corners, scales, antialias);
     temp_input = temp_output;
   }
   _ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
       index_t,
       out_ndims,
       scale_t,
-      HelperInterpLinear>(
-      output, temp_input, 2, align_corners, scales, antialias);
+      F>(output, temp_input, 2, align_corners, scales, antialias);
 }

 void _ti_upsample_bilinear2d_kernel_impl(
@@ -423,6 +460,21 @@ void _ti_upsample_bilinear2d_kernel_impl(
       output, input, align_corners, {scales_h, scales_w}, antialias);
 }

+void _ti_upsample_bicubic2d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    bool antialias) {
+  ti_separable_upsample_generic_Nd_kernel_impl<
+      int64_t,
+      2,
+      scale_t,
+      HelperInterpCubic>(
+      output, input, align_corners, {scales_h, scales_w}, antialias);
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
@@ -463,6 +515,37 @@ at::Tensor interpolate_linear_aa_forward_kernel(
   return output;
 }

+at::Tensor interpolate_bicubic_aa_forward_kernel(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners) {
+  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
+
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBilinear2d.cpp
+  auto output = at::empty({0}, input.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input.sizes(), output_size, scale_factors);
+  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
+  auto full_output_size =
+      at::native::upsample_2d_common_check(input.sizes(), osize);
+
+  // Allow for empty batch size but not other dimensions
+  TORCH_CHECK(
+      input.numel() != 0 ||
+          c10::multiply_integers(
+              input.sizes().begin() + 1, input.sizes().end()),
+      "Non-empty 4D data tensor expected but got a tensor with sizes ",
+      input.sizes());
+
+  output.resize_(full_output_size, input.suggest_memory_format());
+  at::native::internal_upsample::_ti_upsample_bicubic2d_kernel_impl(
+      output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
+  return output;
+}
+
 // TODO: Implement backward function
 // at::Tensor interpolate_linear_aa_backward_kernel(
 //     const at::Tensor& grad) {
@@ -475,6 +558,10 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
       TORCH_FN(interpolate_linear_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
+      TORCH_FN(interpolate_bicubic_aa_forward_kernel));
+
   // TODO: Implement backward function
   // m.impl(
   //     TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
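
The kernel math added above is compact enough to restate in plain Python: HelperInterpCubic::_filter is Pillow's bicubic convolution kernel with a = -0.5, and _compute_indices_weights_aa stretches the filter support by the scale factor when downscaling, which is where the extra taps, and hence the antialiasing, come from. The helper names below are illustrative only, not part of the library.

    import math

    def bicubic_filter(x, a=-0.5):
        # Same piecewise cubic as HelperInterpCubic::_filter
        # (Pillow's Resample.c, bicubic convolution algorithm).
        x = abs(x)
        if x < 1.0:
            return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
        if x < 2.0:
            return (((x - 5.0) * x + 8.0) * x - 4.0) * a
        return 0.0

    def num_taps(interp_size, scale):
        # Mirrors the support computation in _compute_indices_weights_aa:
        # when downscaling (scale >= 1) the support is widened by the scale,
        # so more input pixels contribute to each output pixel.
        support = (interp_size * 0.5) * scale if scale >= 1.0 else interp_size * 0.5
        return int(math.ceil(support)) * 2 + 1

    print(bicubic_filter(0.0), bicubic_filter(1.5))  # 1.0 at the center, -0.0625 in the negative lobe
    print(num_taps(4, 1.0))       # 5 taps for bicubic when not downscaling
    print(num_taps(4, 420 / 72))  # 25 taps when downscaling 420 -> 72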

torchvision/csrc/ops/interpolate_aa.cpp

Lines changed: 15 additions & 1 deletion
@@ -12,11 +12,23 @@ at::Tensor interpolate_linear_aa(
 {
   static auto op =
       c10::Dispatcher::singleton()
-          .findSchemaOrThrow("torchvision::interpolate_linear_aa", "")
+          .findSchemaOrThrow("torchvision::_interpolate_linear_aa", "")
          .typed<decltype(interpolate_linear_aa)>();
   return op.call(input, output_size, align_corners);
 }

+at::Tensor interpolate_bicubic_aa(
+    const at::Tensor& input, // Input image
+    at::IntArrayRef output_size, // Output image size
+    bool align_corners) // The flag to align corners
+{
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_interpolate_bicubic_aa", "")
+          .typed<decltype(_interpolate_bicubic_aa)>();
+  return op.call(input, output_size, align_corners);
+}
+
 namespace detail {

 // TODO: Implement backward function
@@ -33,6 +45,8 @@ namespace detail {
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
       "torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::_interpolate_bicubic_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
   // TODO: Implement backward function
   // m.def(TORCH_SELECTIVE_SCHEMA(
   //     "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
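
Once this schema is registered, the operator is reachable from Python under torch.ops with exactly the signature declared in TORCH_SELECTIVE_SCHEMA. A quick sketch, assuming a build that contains this commit:

    import torch
    import torchvision  # loading the extension runs the TORCH_LIBRARY_FRAGMENT block above

    x = torch.rand(1, 3, 64, 64)
    # _interpolate_bicubic_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor
    y = torch.ops.torchvision._interpolate_bicubic_aa(x, [32, 48], align_corners=False)
    print(y.shape)  # torch.Size([1, 3, 32, 48])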

torchvision/csrc/ops/interpolate_aa.h

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,11 @@ VISION_API at::Tensor _interpolate_linear_aa(
     at::IntArrayRef output_size,
     bool align_corners = false);

+VISION_API at::Tensor _interpolate_bicubic_aa(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners = false);
+
 namespace detail {

 // TODO: Implement backward function

torchvision/transforms/functional_tensor.py

Lines changed: 6 additions & 4 deletions
@@ -503,8 +503,8 @@ def resize(
     if antialias is None:
         antialias = False

-    if antialias and interpolation not in ["bilinear", ]:
-        raise ValueError("Antialias option is supported for bilinear interpolation mode only")
+    if antialias and interpolation not in ["bilinear", "bicubic"]:
+        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

     w, h = _get_image_size(img)

@@ -537,8 +537,10 @@ def resize(
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None

     if antialias:
-        # Apply antialias for donwsampling on both dims
-        img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+        if interpolation == "bilinear":
+            img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+        elif interpolation == "bicubic":
+            img = torch.ops.torchvision._interpolate_bicubic_aa(img, [new_h, new_w], align_corners=False)
     else:
         img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

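At the functional level, the antialias branch now dispatches on the interpolation string, and any other mode combined with antialias still raises the ValueError from the first hunk. A sketch against the private functional_tensor module as it stands at this commit; the exact signature of this internal helper is an assumption and may differ in other versions:

    import torch
    from torchvision.transforms import functional_tensor as F_t

    img = torch.rand(1, 3, 420, 72)

    # Dispatches to torchvision::_interpolate_bicubic_aa under the hood.
    out = F_t.resize(img, size=[96, 72], interpolation="bicubic", antialias=True)
    print(out.shape)  # torch.Size([1, 3, 96, 72])

    # Antialias with an unsupported mode still raises.
    try:
        F_t.resize(img, size=[96, 72], interpolation="nearest", antialias=True)
    except ValueError as err:
        print(err)  # Antialias option is supported for bilinear and bicubic interpolation modes only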