@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
514514 linear_factor = 1.0 ,
515515 ntk_factor = 1.0 ,
516516 repeat_interleave_real = True ,
517- freqs_dtype = torch .float32 , # torch.float32 (hunyuan, stable audio) , torch.float64 (flux)
517+ freqs_dtype = torch .float32 , # torch.float32, torch.float64 (flux)
518518):
519519 """
520520 Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
551551 t = torch .from_numpy (pos ).to (freqs .device ) # type: ignore # [S]
552552 freqs = torch .outer (t , freqs ) # type: ignore # [S, D/2]
553553 if use_real and repeat_interleave_real :
554+ # flux, hunyuan-dit, cogvideox
554555 freqs_cos = freqs .cos ().repeat_interleave (2 , dim = 1 ).float () # [S, D]
555556 freqs_sin = freqs .sin ().repeat_interleave (2 , dim = 1 ).float () # [S, D]
556557 return freqs_cos , freqs_sin
557558 elif use_real :
559+ # stable audio
558560 freqs_cos = torch .cat ([freqs .cos (), freqs .cos ()], dim = - 1 ).float () # [S, D]
559561 freqs_sin = torch .cat ([freqs .sin (), freqs .sin ()], dim = - 1 ).float () # [S, D]
560562 return freqs_cos , freqs_sin
561563 else :
562- freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ).float () # complex64 # [S, D/2]
564+ # lumina
565+ freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ) # complex64 # [S, D/2]
563566 return freqs_cis
564567
565568
@@ -590,11 +593,11 @@ def apply_rotary_emb(
590593 cos , sin = cos .to (x .device ), sin .to (x .device )
591594
592595 if use_real_unbind_dim == - 1 :
593- # Use for example in Lumina
596+ # Used for flux, cogvideox, hunyuan-dit
594597 x_real , x_imag = x .reshape (* x .shape [:- 1 ], - 1 , 2 ).unbind (- 1 ) # [B, S, H, D//2]
595598 x_rotated = torch .stack ([- x_imag , x_real ], dim = - 1 ).flatten (3 )
596599 elif use_real_unbind_dim == - 2 :
597- # Use for example in Stable Audio
600+ # Used for Stable Audio
598601 x_real , x_imag = x .reshape (* x .shape [:- 1 ], 2 , - 1 ).unbind (- 2 ) # [B, S, H, D//2]
599602 x_rotated = torch .cat ([- x_imag , x_real ], dim = - 1 )
600603 else :
@@ -604,6 +607,7 @@ def apply_rotary_emb(
604607
605608 return out
606609 else :
610+ # Used for Lumina
607611 x_rotated = torch .view_as_complex (x .float ().reshape (* x .shape [:- 1 ], - 1 , 2 ))
608612 freqs_cis = freqs_cis .unsqueeze (2 )
609613 x_out = torch .view_as_real (x_rotated * freqs_cis ).flatten (3 )
0 commit comments