diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index acce7b574e56..50382bcab37d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -328,39 +328,39 @@ def __init__( if self.use_nin_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x, temb, hey=False): - h = x + def forward(self, x, temb): + hidden_states = x # make sure hidden states is in float32 # when running in half-precision - h = self.norm1(h.float()).type(h.dtype) - h = self.nonlinearity(h) + hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: x = self.upsample(x) - h = self.upsample(h) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: x = self.downsample(x) - h = self.downsample(h) + hidden_states = self.downsample(hidden_states) - h = self.conv1(h) + hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - h = h + temb + hidden_states = hidden_states + temb # make sure hidden states is in float32 # when running in half-precision - h = self.norm2(h.float()).type(h.dtype) - h = self.nonlinearity(h) + hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) - h = self.dropout(h) - h = self.conv2(h) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: x = self.conv_shortcut(x) - out = (x + h) / self.output_scale_factor + out = (x + hidden_states) / self.output_scale_factor return out