Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 63 additions & 57 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.

:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
Expand Down Expand Up @@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
"""
A downsampling layer with an optional convolution.

:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
Expand Down Expand Up @@ -115,21 +117,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.

Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.

Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
"""

assert isinstance(factor, int) and factor >= 1
Expand Down Expand Up @@ -164,7 +167,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = weight.shape[1]
num_groups = hidden_states.shape[1] // inC

# Transpose weights.
Expand Down Expand Up @@ -214,20 +216,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=

def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.

Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
Scaling factor for signal magnitude (default: 1.0).
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight:
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
same datatype as `x`.
"""

assert isinstance(factor, int) and factor >= 1
Expand All @@ -251,17 +256,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
hidden_states = upfirdn2d_native(
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)

return hidden_states
return output

def forward(self, hidden_states):
if self.use_conv:
Expand Down Expand Up @@ -393,20 +398,20 @@ def forward(self, hidden_states):

def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.

Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
a: multiple of the upsampling factor.

Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
output: Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
Expand All @@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):

kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.

Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.

Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
output: Tensor of the shape `[N, C, H // factor, W // factor]`
"""

assert isinstance(factor, int) and factor >= 1
Expand All @@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):

kernel = kernel * gain
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
)
return output


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]

_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
# Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)
_, channel, in_h, in_w = tensor.shape
tensor = tensor.reshape(-1, in_h, in_w, 1)

_, in_h, in_w, minor = input.shape
_, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape

out = input.view(-1, in_h, 1, in_w, 1, minor)
out = tensor.view(-1, in_h, 1, in_w, 1, minor)

# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)

out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out.to(tensor.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
Expand Down