@@ -45,11 +45,14 @@ class _FeatureExtractorInceptionV3(Module): # type: ignore[no-redef]
4545class NoTrainInceptionV3 (_FeatureExtractorInceptionV3 ):
4646 """Module that never leaves evaluation mode."""
4747
48+ INPUT_IMAGE_SIZE : int
49+
4850 def __init__ (
4951 self ,
5052 name : str ,
5153 features_list : list [str ],
5254 feature_extractor_weights_path : Optional [str ] = None ,
55+ antialias : bool = True ,
5356 ) -> None :
5457 if not _TORCH_FIDELITY_AVAILABLE :
5558 raise ModuleNotFoundError (
@@ -58,6 +61,7 @@ def __init__(
5861 )
5962
6063 super ().__init__ (name , features_list , feature_extractor_weights_path )
64+ self .use_antialias = antialias
6165 # put into evaluation mode
6266 self .eval ()
6367
@@ -81,11 +85,21 @@ def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]:
8185 remaining_features = self .features_list .copy ()
8286
8387 x = x .to (self ._dtype ) if hasattr (self , "_dtype" ) else x .to (torch .float )
84- x = interpolate_bilinear_2d_like_tensorflow1x (
85- x ,
86- size = (self .INPUT_IMAGE_SIZE , self .INPUT_IMAGE_SIZE ),
87- align_corners = False ,
88- )
88+ if self .use_antialias :
89+ x = torch .nn .functional .interpolate (
90+ x ,
91+ size = (self .INPUT_IMAGE_SIZE , self .INPUT_IMAGE_SIZE ),
92+ mode = "bilinear" ,
93+ align_corners = False ,
94+ antialias = True ,
95+ )
96+ else :
97+ x = interpolate_bilinear_2d_like_tensorflow1x (
98+ x ,
99+ size = (self .INPUT_IMAGE_SIZE , self .INPUT_IMAGE_SIZE ),
100+ align_corners = False ,
101+ )
102+
89103 x = (x - 128 ) / 128
90104
91105 x = self .Conv2d_1a_3x3 (x )
@@ -250,6 +264,9 @@ class FrechetInceptionDistance(Metric):
250264 - True: if input imgs are expected to be in the data type of torch.float32.
251265 - False: if input imgs are expected to be in the data type of torch.int8.
252266 input_img_size: tuple of integers. Indicates input img size to the custom feature extractor network if provided.
267+ antialias: boolean flag indicating whether to use antialiasing when resizing images (default ``True``). When enabled,
268+ images are resized using bilinear interpolation with antialiasing, which is different from the original
269+ Inception v3 implementation. Does not apply to custom feature extractor networks.
253270 kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
254271
255272 Raises:
@@ -301,6 +318,7 @@ def __init__(
301318 normalize : bool = False ,
302319 input_img_size : tuple [int , int , int ] = (3 , 299 , 299 ),
303320 feature_extractor_weights_path : Optional [str ] = None ,
321+ antialias : bool = True ,
304322 ** kwargs : Any ,
305323 ) -> None :
306324 super ().__init__ (** kwargs )
@@ -309,6 +327,7 @@ def __init__(
309327 raise ValueError ("Argument `normalize` expected to be a bool" )
310328 self .normalize = normalize
311329 self .used_custom_model = False
330+ antialias = antialias
312331
313332 if isinstance (feature , int ):
314333 num_features = feature
@@ -327,6 +346,7 @@ def __init__(
327346 name = "inception-v3-compat" ,
328347 features_list = [str (feature )],
329348 feature_extractor_weights_path = feature_extractor_weights_path ,
349+ antialias = antialias ,
330350 )
331351
332352 elif isinstance (feature , Module ):
0 commit comments