Commit dcc9856

Merge branch 'main' into fix_resize_small_edge

2 parents e54dc51 + 26fe8fa

File tree

17 files changed: +159 additions, -1634 deletions

android/ops/CMakeLists.txt

Lines changed: 0 additions & 7 deletions
@@ -14,13 +14,6 @@ file(GLOB VISION_SRCS
     ../../torchvision/csrc/ops/*.h
     ../../torchvision/csrc/ops/*.cpp)
 
-# Remove interpolate_aa sources as they are temporary code
-# see https://github.com/pytorch/vision/pull/3761
-# and IndexingUtils.h is unavailable on Android build
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
-
 add_library(${TARGET} SHARED
     ${VISION_SRCS}
 )

ios/CMakeLists.txt

Lines changed: 0 additions & 7 deletions
@@ -11,13 +11,6 @@ file(GLOB VISION_SRCS
     ../torchvision/csrc/ops/*.h
     ../torchvision/csrc/ops/*.cpp)
 
-# Remove interpolate_aa sources as they are temporary code
-# see https://github.com/pytorch/vision/pull/3761
-# and using TensorIterator unavailable with iOS
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
-
 add_library(${TARGET} STATIC
     ${VISION_SRCS}
 )

test/test_functional_tensor.py

Lines changed: 9 additions & 26 deletions
@@ -3,6 +3,7 @@
 import math
 import os
 import re
+from functools import partial
 from typing import Sequence
 
 import numpy as np
@@ -655,11 +656,13 @@ def test_resize_antialias(device, dt, size, interpolation):
 def test_assert_resize_antialias(interpolation):
 
     # Checks implementation on very large scales
-    # and catch TORCH_CHECK inside interpolate_aa_kernels.cu
+    # and catch TORCH_CHECK inside PyTorch implementation
     torch.manual_seed(12)
-    tensor, pil_img = _create_data(1000, 1000, device="cuda")
+    tensor, _ = _create_data(1000, 1000, device="cuda")
 
-    with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
+    # Error message is not yet updated in pytorch nightly
+    # with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
+    with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
 
 
@@ -674,32 +677,12 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
         return
 
     torch.manual_seed(12)
-    if interpolation == BILINEAR:
-        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
-    elif interpolation == BICUBIC:
-        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward
-
-    class F(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, i):
-            result = forward_op(i, size, False)
-            ctx.save_for_backward(i, result)
-            return result
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            i, result = ctx.saved_tensors
-            ishape = i.shape
-            oshape = result.shape[2:]
-            return backward_op(grad_output, oshape, ishape, False)
-
     x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    resize = partial(F.resize, size=size, interpolation=interpolation, antialias=True)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
 
     x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
 
 
 def check_functional_vs_PIL_vs_scripted(
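
The backward test above now checks gradients through the public F.resize API instead of the removed torchvision._interpolate_*_aa ops. A minimal standalone sketch of that pattern (not part of this commit; it assumes torch and torchvision are installed, and picks BILINEAR and an arbitrary output size purely for illustration):

    from functools import partial

    import torch
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms import functional as F

    # gradcheck needs double precision and an input flagged with requires_grad.
    x = (torch.rand(1, 3, 32, 29, dtype=torch.double, requires_grad=True),)

    # Bind the resize parameters so gradcheck sees a single-argument callable,
    # as the updated test does with functools.partial.
    resize = partial(F.resize, size=(12, 12), interpolation=InterpolationMode.BILINEAR, antialias=True)

    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

Routing the check through F.resize keeps the test valid regardless of which backend (the deleted torchvision kernels or PyTorch's own antialiased interpolate) implements the op.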
