diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 27fae24f71d8..507ca8632de3 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -333,7 +333,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -351,7 +351,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states)