diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 23364bfa1d16..fde1014bd2e7 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -225,7 +225,7 @@ def forward( hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*):