diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index dce30c6a4aa6..bf04c3e6a3ca 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -376,6 +376,12 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
     def forward(self, hidden_states):
         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * F.gelu(gate)
+        return hidden_states * self.gelu(gate)
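
For context, here is a minimal self-contained sketch of the module as it would look after this patch. The class name `GEGLU` and the imports are assumed from the surrounding file, since the hunk only shows `__init__` and `forward`; this is an illustration, not the full diffusers source.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    """Gated GELU projection, reconstructed as it would look after the diff above.

    The mps branch upcasts the gate to float32 before F.gelu because the op is
    not implemented for float16 on that backend, then casts the result back so
    the surrounding module keeps its original dtype.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


# Quick sanity check on CPU (the mps branch only triggers on Apple silicon):
if __name__ == "__main__":
    layer = GEGLU(dim_in=8, dim_out=16)
    out = layer(torch.randn(2, 4, 8))
    print(out.shape)  # torch.Size([2, 4, 16])
```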