From 7eb95d706fd962c10fa2810f634238f479431274 Mon Sep 17 00:00:00 2001
From: Chanchana Sornsoontorn
Date: Thu, 13 Apr 2023 00:34:36 +0700
Subject: [PATCH] ⚙️chore(transformer_2d) update function signature for encoder_hidden_states
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/diffusers/models/transformer_2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 23364bfa1d16..fde1014bd2e7 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -225,7 +225,7 @@ def forward(
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
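Not part of the patch: a minimal sketch illustrating the tensor shapes the corrected docstring describes. The conditioning passed to cross-attention (e.g. a text-encoder output) is a float embedding sequence of shape `(batch size, sequence len, embed dims)`, not integer token ids, which is why the docstring now says `torch.FloatTensor`. The model hyperparameters, the sequence length of 77, and the embedding width of 768 below are illustrative assumptions, not values taken from the patch.

```python
# Illustrative sketch only; hyperparameters are assumptions, not part of the patch.
import torch
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=16,    # inner dim = 8 * 16 = 128
    in_channels=128,          # continuous input: (batch, channel, height, width)
    cross_attention_dim=768,  # must match the last dim of encoder_hidden_states
)

# Continuous hidden states, as described for the non-discrete case in the docstring.
hidden_states = torch.randn(2, 128, 16, 16)       # (batch size, channel, height, width)

# Conditioning embeddings: float tensor of shape (batch size, sequence len, embed dims).
encoder_hidden_states = torch.randn(2, 77, 768)

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([2, 128, 16, 16]) — output keeps the input spatial shape
```

If `encoder_hidden_states` is omitted, the transformer blocks fall back to self-attention, as the docstring states.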