@@ -120,7 +120,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
-        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
@@ -148,16 +147,10 @@ def __init__(
         layers_per_block: int = 2,
         norm_num_groups: Optional[int] = 32,
         cross_attention_dim: int = 1024,
-        attention_head_dim: Union[int, Tuple[int]] = None,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
     ):
         super().__init__()
 
-        # We didn't define `attention_head_dim` when we first integrated this UNet. As a result,
-        # we had to use `num_attention_heads` in to pass values for arguments that actually denote
-        # attention head dimension. This is why we correct it here.
-        attention_head_dim = num_attention_heads or attention_head_dim
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -179,7 +172,7 @@ def __init__(
 
         self.transformer_in = TransformerTemporalModel(
             num_attention_heads=8,
-            attention_head_dim=attention_head_dim,
+            attention_head_dim=num_attention_heads,
             in_channels=block_out_channels[0],
             num_layers=1,
             norm_num_groups=norm_num_groups,
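
For reference, a minimal usage sketch after this change (not part of the diff; the argument values shown are assumptions): `attention_head_dim` is no longer accepted by the `I2VGenXLUNet` constructor, and callers configure the heads through `num_attention_heads`, which this UNet forwards internally as the attention head dimension.

from diffusers import I2VGenXLUNet

# Assumed values for illustration; all other arguments keep their defaults.
unet = I2VGenXLUNet(
    cross_attention_dim=1024,
    num_attention_heads=64,  # forwarded internally as the attention head dimension
)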