diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b9718e67f279..dc8a91164977 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -9,9 +9,10 @@ class Upsample2D(nn.Module): """ An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -61,9 +62,10 @@ class Downsample2D(nn.Module): """ A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -115,21 +117,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `Conv2d()`. - Args: Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: - order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as - `x`. + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 @@ -164,7 +167,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 - inC = weight.shape[1] num_groups = hidden_states.shape[1] // inC # Transpose weights. @@ -214,20 +216,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. Args: - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: - order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, - filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // - numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: - Scaling factor for signal magnitude (default: 1.0). + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same - datatype as `x`. + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and + same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 @@ -251,17 +256,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain torch.tensor(kernel, device=hidden_states.device), pad=((pad_value + 1) // 2, pad_value // 2), ) - hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) else: pad_value = kernel.shape[0] - factor - hidden_states = upfirdn2d_native( + output = upfirdn2d_native( hidden_states, torch.tensor(kernel, device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2), ) - return hidden_states + return output def forward(self, hidden_states): if self.use_conv: @@ -393,20 +398,20 @@ def forward(self, hidden_states): def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. - - Args: Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: - multiple of the upsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - k: FIR filter of the shape `[firH, firW]` or `[firN]` + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` + output: Tensor of the shape `[N, C, H * factor, W * factor]` """ assert isinstance(factor, int) and factor >= 1 if kernel is None: @@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor - return upfirdn2d_native( + output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) + return output def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. - - Args: Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the downsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` + output: Tensor of the shape `[N, C, H // factor, W // factor]` """ assert isinstance(factor, int) and factor >= 1 @@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor - return upfirdn2d_native( + output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) + return output -def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] pad_x1 = pad_y1 = pad[1] - _, channel, in_h, in_w = input.shape - input = input.reshape(-1, in_h, in_w, 1) - # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806) + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) - _, in_h, in_w, minor = input.shape + _, in_h, in_w, minor = tensor.shape kernel_h, kernel_w = kernel.shape - out = input.view(-1, in_h, 1, in_w, 1, minor) + out = tensor.view(-1, in_h, 1, in_w, 1, minor) # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 - if input.device.type == "mps": + if tensor.device.type == "mps": out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(input.device) # Move back to mps if necessary + out = out.to(tensor.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),