@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
99 """
1010 An upsampling layer with an optional convolution.
1111
12- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
13- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
14- upsampling occurs in the inner-two dimensions.
12+ Parameters:
13+ channels: channels in the inputs and outputs.
14+ use_conv: a bool determining if a convolution is applied.
15+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
1516 """
1617
1718 def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
6162 """
6263 A downsampling layer with an optional convolution.
6364
64- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
65- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
66- downsampling occurs in the inner-two dimensions.
65+ Parameters:
66+ channels: channels in the inputs and outputs.
67+ use_conv: a bool determining if a convolution is applied.
68+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
6769 """
6870
6971 def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -115,21 +117,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
115117 def _upsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
116118 """Fused `upsample_2d()` followed by `Conv2d()`.
117119
118- Args:
119120 Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
120- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
121- order.
122- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
123- C]`.
124- weight: Weight tensor of the shape `[filterH, filterW, inChannels,
125- outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
126- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
127- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
128- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
121+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
122+ arbitrary order.
123+
124+ Args:
125+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
126+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
127+ outChannels]`. Grouped convolution can be performed by `inChannels = hidden_states.shape[0] // numGroups`.
128+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
129+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
130+ factor: Integer upsampling factor (default: 2).
131+ gain: Scaling factor for signal magnitude (default: 1.0).
129132
130133 Returns:
131- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
132- `x `.
134+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
135+ datatype as `hidden_states`.
133136 """
134137
135138 assert isinstance (factor , int ) and factor >= 1
@@ -164,7 +167,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1
164167 output_shape [1 ] - (hidden_states .shape [3 ] - 1 ) * stride [1 ] - convW ,
165168 )
166169 assert output_padding [0 ] >= 0 and output_padding [1 ] >= 0
167- inC = weight .shape [1 ]
168170 num_groups = hidden_states .shape [1 ] // inC
169171
170172 # Transpose weights.
@@ -214,20 +216,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
214216
215217 def _downsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
216218 """Fused `Conv2d()` followed by `downsample_2d()`.
219+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
220+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
221+ arbitrary order.
217222
218223 Args:
219- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
220- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary :
221- order.
222- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
223- filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
224- numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
225- factor`, which corresponds to average pooling. factor : Integer downsampling factor (default: 2). gain:
226- Scaling factor for signal magnitude (default: 1.0).
224+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
225+ weight:
226+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
227+ performed by `inChannels = hidden_states.shape[0] // numGroups`.
228+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
229+ factor`, which corresponds to average pooling.
230+ factor: Integer downsampling factor (default: 2).
231+ gain: Scaling factor for signal magnitude (default: 1.0).
227232
228233 Returns:
229- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
230- datatype as `x`.
234+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
235+ same datatype as `hidden_states`.
231236 """
232237
233238 assert isinstance (factor , int ) and factor >= 1
@@ -251,17 +256,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain
251256 torch .tensor (kernel , device = hidden_states .device ),
252257 pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
253258 )
254- hidden_states = F .conv2d (upfirdn_input , weight , stride = stride_value , padding = 0 )
259+ output = F .conv2d (upfirdn_input , weight , stride = stride_value , padding = 0 )
255260 else :
256261 pad_value = kernel .shape [0 ] - factor
257- hidden_states = upfirdn2d_native (
262+ output = upfirdn2d_native (
258263 hidden_states ,
259264 torch .tensor (kernel , device = hidden_states .device ),
260265 down = factor ,
261266 pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
262267 )
263268
264- return hidden_states
269+ return output
265270
266271 def forward (self , hidden_states ):
267272 if self .use_conv :
@@ -393,20 +398,20 @@ def forward(self, hidden_states):
393398
394399def upsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
395400 r"""Upsample2D a batch of 2D images with the given filter.
396-
397- Args:
398401 Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
399402 filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
400- `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
401- multiple of the upsampling factor.
402- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
403- C]`.
404- k: FIR filter of the shape `[firH, firW]` or `[firN]`
403+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
404+ a multiple of the upsampling factor.
405+
406+ Args:
407+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
408+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
405409 (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
406- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
410+ factor: Integer upsampling factor (default: 2).
411+ gain: Scaling factor for signal magnitude (default: 1.0).
407412
408413 Returns:
409- Tensor of the shape `[N, C, H * factor, W * factor]`
414+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
410415 """
411416 assert isinstance (factor , int ) and factor >= 1
412417 if kernel is None :
@@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
419424
420425 kernel = kernel * (gain * (factor ** 2 ))
421426 pad_value = kernel .shape [0 ] - factor
422- return upfirdn2d_native (
427+ output = upfirdn2d_native (
423428 hidden_states ,
424429 kernel .to (device = hidden_states .device ),
425430 up = factor ,
426431 pad = ((pad_value + 1 ) // 2 + factor - 1 , pad_value // 2 ),
427432 )
433+ return output
428434
429435
430436def downsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
431437 r"""Downsample2D a batch of 2D images with the given filter.
432-
433- Args:
434438 Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
435439 given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
436440 specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
437441 shape is a multiple of the downsampling factor.
438- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
439- C]`.
442+
443+ Args:
444+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
440445 kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
441446 (separable). The default is `[1] * factor`, which corresponds to average pooling.
442- factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
447+ factor: Integer downsampling factor (default: 2).
448+ gain: Scaling factor for signal magnitude (default: 1.0).
443449
444450 Returns:
445- Tensor of the shape `[N, C, H // factor, W // factor]`
451+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
446452 """
447453
448454 assert isinstance (factor , int ) and factor >= 1
@@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
456462
457463 kernel = kernel * gain
458464 pad_value = kernel .shape [0 ] - factor
459- return upfirdn2d_native (
465+ output = upfirdn2d_native (
460466 hidden_states , kernel .to (device = hidden_states .device ), down = factor , pad = ((pad_value + 1 ) // 2 , pad_value // 2 )
461467 )
468+ return output
462469
463470
464- def upfirdn2d_native (input , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
471+ def upfirdn2d_native (tensor , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
465472 up_x = up_y = up
466473 down_x = down_y = down
467474 pad_x0 = pad_y0 = pad [0 ]
468475 pad_x1 = pad_y1 = pad [1 ]
469476
470- _ , channel , in_h , in_w = input .shape
471- input = input .reshape (- 1 , in_h , in_w , 1 )
472- # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)
477+ _ , channel , in_h , in_w = tensor .shape
478+ tensor = tensor .reshape (- 1 , in_h , in_w , 1 )
473479
474- _ , in_h , in_w , minor = input .shape
480+ _ , in_h , in_w , minor = tensor .shape
475481 kernel_h , kernel_w = kernel .shape
476482
477- out = input .view (- 1 , in_h , 1 , in_w , 1 , minor )
483+ out = tensor .view (- 1 , in_h , 1 , in_w , 1 , minor )
478484
479485 # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
480- if input .device .type == "mps" :
486+ if tensor .device .type == "mps" :
481487 out = out .to ("cpu" )
482488 out = F .pad (out , [0 , 0 , 0 , up_x - 1 , 0 , 0 , 0 , up_y - 1 ])
483489 out = out .view (- 1 , in_h * up_y , in_w * up_x , minor )
484490
485491 out = F .pad (out , [0 , 0 , max (pad_x0 , 0 ), max (pad_x1 , 0 ), max (pad_y0 , 0 ), max (pad_y1 , 0 )])
486- out = out .to (input .device ) # Move back to mps if necessary
492+ out = out .to (tensor .device ) # Move back to mps if necessary
487493 out = out [
488494 :,
489495 max (- pad_y0 , 0 ) : out .shape [1 ] - max (- pad_y1 , 0 ),
0 commit comments