Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 66 additions & 37 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.fir_kernel = fir_kernel
self.out_channels = out_channels

def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.

Args:
Expand Down Expand Up @@ -151,34 +151,46 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
convW = weight.shape[3]
inC = weight.shape[1]

p = (kernel.shape[0] - factor) - (convW - 1)
pad_value = (kernel.shape[0] - factor) - (convW - 1)

stride = (factor, factor)
# Determine data dimensions.
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_shape = (
(hidden_states.shape[2] - 1) * factor + convH,
(hidden_states.shape[3] - 1) * factor + convW,
)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = weight.shape[1]
num_groups = x.shape[1] // inC
num_groups = hidden_states.shape[1] // inC

# Transpose weights.
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
inverse_conv = F.conv_transpose2d(
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
)

x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
output = upfirdn2d_native(
inverse_conv,
torch.tensor(kernel, device=inverse_conv.device),
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
)
else:
p = kernel.shape[0] - factor
x = upfirdn2d_native(
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)

return x
return output

def forward(self, hidden_states):
if self.use_conv:
Expand All @@ -200,7 +212,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.use_conv = use_conv
self.out_channels = out_channels

def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.

Args:
Expand Down Expand Up @@ -232,20 +244,29 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

if self.use_conv:
_, _, convH, convW = weight.shape
p = (kernel.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
x = F.conv2d(x, weight, stride=s, padding=0)
pad_value = (kernel.shape[0] - factor) + (convW - 1)
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
p = kernel.shape[0] - factor
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
pad_value = kernel.shape[0] - factor
hidden_states = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)

return x
return hidden_states

def forward(self, hidden_states):
if self.use_conv:
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

Expand Down Expand Up @@ -332,17 +353,17 @@ def __init__(
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x, temb):
hidden_states = x
def forward(self, input_tensor, temb):
hidden_states = input_tensor

hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)

if self.upsample is not None:
x = self.upsample(x)
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
x = self.downsample(x)
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)

hidden_states = self.conv1(hidden_states)
Expand All @@ -358,19 +379,19 @@ def forward(self, x, temb):
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
input_tensor = self.conv_shortcut(input_tensor)

out = (x + hidden_states) / self.output_scale_factor
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

return out
return output_tensor


class Mish(torch.nn.Module):
    """Mish activation function: ``x * tanh(softplus(x))``.

    Self-regularized, non-monotonic activation (Misra, 2019). Stateless:
    holds no parameters or buffers.
    """

    def forward(self, hidden_states):
        """Apply Mish element-wise.

        Args:
            hidden_states: input tensor of any shape.

        Returns:
            Tensor of the same shape with Mish applied element-wise.
        """
        # Equivalent to torch.nn.functional.mish; written out explicitly to
        # match the rest of this file, which avoids version-gated F helpers.
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


def upsample_2d(x, kernel=None, factor=2, gain=1):
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.

Args:
Expand All @@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
kernel /= torch.sum(kernel)

kernel = kernel * (gain * (factor**2))
p = kernel.shape[0] - factor
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)


def downsample_2d(x, kernel=None, factor=2, gain=1):
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.

Args:
Expand Down Expand Up @@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
kernel /= torch.sum(kernel)

kernel = kernel * gain
p = kernel.shape[0] - factor
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
)


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
Expand All @@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):

_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
# Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)

_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
Expand Down