 import math
 import os
 import re
+from functools import partial
 from typing import Sequence

 import numpy as np
@@ -655,11 +656,13 @@ def test_resize_antialias(device, dt, size, interpolation):
 def test_assert_resize_antialias(interpolation):

     # Checks implementation on very large scales
-    # and catch TORCH_CHECK inside interpolate_aa_kernels.cu
+    # and catches TORCH_CHECK inside the PyTorch implementation
     torch.manual_seed(12)
-    tensor, pil_img = _create_data(1000, 1000, device="cuda")
+    tensor, _ = _create_data(1000, 1000, device="cuda")

-    with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
+    # Error message is not yet updated in PyTorch nightly
+    # with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
+    with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)

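For readers skimming the diff: the rewritten check follows the usual pytest.raises pattern, asserting that an extreme 1000x1000 -> 5x5 antialiased downscale on CUDA trips a TORCH_CHECK inside the PyTorch kernel. A minimal, self-contained sketch of that pattern follows; it assumes a CUDA device, substitutes a plain random uint8 tensor for the test suite's _create_data helper, and the matched error message may differ across PyTorch versions:

import pytest
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F


def test_resize_antialias_rejects_extreme_downscale():
    # Skip rather than fail on machines without a GPU.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for this check")

    # Stand-in for _create_data: a random uint8 image tensor on CUDA.
    tensor = torch.randint(0, 256, (3, 1000, 1000), dtype=torch.uint8, device="cuda")

    # The 1000x1000 -> 5x5 downscale exceeds what the CUDA antialias kernel
    # supports, so resize is expected to raise via TORCH_CHECK.
    with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
        F.resize(tensor, size=(5, 5), interpolation=InterpolationMode.BILINEAR, antialias=True)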
@@ -674,32 +677,12 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
         return

     torch.manual_seed(12)
-    if interpolation == BILINEAR:
-        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
-    elif interpolation == BICUBIC:
-        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward
-
-    class F(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, i):
-            result = forward_op(i, size, False)
-            ctx.save_for_backward(i, result)
-            return result
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            i, result = ctx.saved_tensors
-            ishape = i.shape
-            oshape = result.shape[2:]
-            return backward_op(grad_output, oshape, ishape, False)
-
     x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    resize = partial(F.resize, size=size, interpolation=interpolation, antialias=True)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

     x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


 def check_functional_vs_PIL_vs_scripted(