Deprecate attention block #2697
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Force-pushed 219829e to 80e9ff8
Force-pushed 80e9ff8 to fcd629d
" If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`." | ||
) | ||
|
||
deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True) |
@patrickvonplaten is this the right deprecation version? iirc we talked about two minor versions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I'd maybe even bump it up a bit more to "0.20.0" maybe.
Force-pushed a821525 to 6c72fa9
@@ -783,6 +785,71 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def _convert_deprecated_attention_blocks(self):
@patrickvonplaten Is it ok to put all of the conversion code for each class in this one method and determine which model it's in by checking `self.__class__.__name__`? I think it's a bit more straightforward than splitting it up in the different model definitions.
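For concreteness, a rough sketch of the dispatch pattern being discussed; the per-class helper names below are hypothetical, not the actual implementation:

```python
# Sketch only: dispatch the conversion on the concrete model class inside a
# single ModelMixin method. The _convert_* helper names are hypothetical.
def _convert_deprecated_attention_blocks(self):
    class_name = self.__class__.__name__

    if class_name == "UNet2DModel":
        self._convert_unet2d_attention_blocks()
    elif class_name == "VQModel":
        self._convert_vq_model_attention_blocks()
    elif class_name == "AutoencoderKL":
        self._convert_autoencoder_kl_attention_blocks()
    # Models without deprecated attention blocks need no conversion.
```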
Force-pushed 98cdd56 to 19178f7
Force-pushed 88b463c to fd56be6
Looks great, just wondering if we could maybe try to use the existing attn processors by doing some pre- and post-reshaping around the attention layer (or is this not possible - haven't checked).
bias=True,
upcast_softmax=True,
norm_num_groups=resnet_groups,
processor=SpatialAttnProcessor(),
Hmm, is there really a difference between Spatial Attention and "normal" attention - isn't it just some reshaping before and after that's different?
Could we maybe try to use the normal attention processor and do the reshaping outside of the attention (e.g. in the VAE directly)? This would also allow us to use all the other attention processors out of the box (e.g. LoRA, xformers, torch 2.0).
Wdyt?
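A minimal sketch of that reshaping idea, assuming a generic self-attention callable (this is just the shape bookkeeping, not the diffusers code):

```python
import torch

def spatial_self_attention(hidden_states: torch.Tensor, attn) -> torch.Tensor:
    # hidden_states: (batch, channels, height, width), as produced inside the VAE.
    batch, channels, height, width = hidden_states.shape

    # Pre-reshape: (B, C, H, W) -> (B, H*W, C) so a "normal" attention
    # processor (default, xformers, LoRA, torch 2.0 SDPA, ...) can be reused.
    hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

    # attn is assumed to be a standard self-attention module operating on
    # (batch, sequence_length, channels) tensors.
    hidden_states = attn(hidden_states)

    # Post-reshape: (B, H*W, C) -> (B, C, H, W).
    hidden_states = hidden_states.transpose(1, 2).reshape(batch, channels, height, width)
    return hidden_states
```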
wip commit up 5aa3446
Force-pushed 592fcfb to 938c44a
Force-pushed 5aa3446 to 14718c6
Cool, I think the changes to the attention processor are nice! Think we can make the code a bit shorter though, with fewer if-else statements and arguably easier-to-read code (not a huge fan of booleans such as `reshaped_input=True`).
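One possible way to avoid threading a boolean through the processor (a sketch, assuming the spatial case can be recognized from the tensor rank instead):

```python
import torch

def flatten_if_spatial(hidden_states: torch.Tensor):
    """Flatten (B, C, H, W) inputs to (B, H*W, C); pass (B, L, C) inputs through."""
    if hidden_states.ndim == 4:
        batch, channels, height, width = hidden_states.shape
        flat = hidden_states.view(batch, channels, height * width).transpose(1, 2)
        return flat, (batch, channels, height, width)
    return hidden_states, None

def unflatten_if_spatial(hidden_states: torch.Tensor, spatial_shape):
    """Undo flatten_if_spatial when a spatial shape was recorded."""
    if spatial_shape is None:
        return hidden_states
    batch, channels, height, width = spatial_shape
    return hidden_states.transpose(1, 2).reshape(batch, channels, height, width)
```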
Force-pushed 14718c6 to fdffad4
There are some failures when running the integration tests against this branch (#2759). I'll need to fix those before we merge.
Ah yeah, can you also update the testing branch? A bunch of test failures have recently been fixed on main.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
re: #1880
PR description
We are in the process of deprecating the `AttentionBlock` class. All attention will be done through `attention_processor.Attention`.

The goal of this PR is to remove the body of `AttentionBlock` and replace it with `SpatialAttnProcessor` and `XFormersSpatialAttnProcessor`. This will allow checkpoints which have weights in the format of `AttentionBlock` to use `attention_processor.Attention`.

After model load, we directly replace all instances of `AttentionBlock` with `attention_processor.Attention`. The forward method of `AttentionBlock` throws an error.

The next step will be to re-upload `AttentionBlock` weights on the hub to be in the same format as `attention_processor.Attention`, along with a configuration flag to use `attention_processor.Attention`.

Hub stats for attention blocks in diffusers
from https://github.com/williamberman/attention-block-refactor
UNet2DModel, VQModel, and AutoencoderKL are top-level models which all have the deprecated attention block as a submodule. The conditional unet can be configured to use the deprecated attention block, but there are no hub uploads which are configured as such.
The most common model is AutoencoderKL because all stable diffusion checkpoints use it.
- UNet2DModel: 1039 hub uploads, blocks using deprecated attention
- VQModel: 13 hub uploads, blocks using deprecated attention
- AutoencoderKL: 2927 hub uploads, blocks using deprecated attention
UNet2DModel, VQModel, and AutoencoderKL all have the flag `attention_block_type` added. This flag is passed down to AttnDownBlock2D, AttnUpBlock2D, AttnSkipDownBlock2D, AttnSkipUpBlock2D, UNetMidBlock2D, AttnDownEncoderBlock2D, and AttnUpDecoderBlock2D for switching to the non-deprecated attention block.
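As a purely illustrative sketch of how such a flag could be consumed in a block constructor (the flag values "AttentionBlock" / "Attention", the helper name, and the import locations are assumptions rather than code from this PR):

```python
# Assumed import locations on this PR's branch; SpatialAttnProcessor is the
# processor this PR introduces and is not part of a released diffusers version.
from diffusers.models.attention import AttentionBlock
from diffusers.models.attention_processor import Attention, SpatialAttnProcessor

def build_attention(attention_block_type: str, channels: int, resnet_groups: int):
    # Hypothetical helper showing how a block constructor might branch on the flag.
    if attention_block_type == "AttentionBlock":
        # Deprecated path.
        return AttentionBlock(channels, norm_num_groups=resnet_groups)
    if attention_block_type == "Attention":
        # Non-deprecated path, mirroring the kwargs visible in the diff above.
        return Attention(
            channels,
            bias=True,
            upcast_softmax=True,
            norm_num_groups=resnet_groups,
            processor=SpatialAttnProcessor(),
        )
    raise ValueError(f"Unknown attention_block_type: {attention_block_type}")
```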
When do we convert the attention blocks

`_convert_deprecated_attention_blocks` is the method added to `ModelMixin` which converts all deprecated attention blocks. We cannot call `_convert_deprecated_attention_blocks` in the constructor of the model, because weights in the old format can be loaded after the constructor is called.

- Instances of `ModelMixin` created by `ModelMixin#from_pretrained`: the attention blocks are converted after the weights have been loaded.
- Instances of `ModelMixin` created only with their constructor and passed to the pipeline's constructor: the attention blocks are converted in `DiffusionPipeline#register_modules`.
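To make the conversion step concrete, here is a conceptual sketch of copying one deprecated block's weights onto an `attention_processor.Attention` module. The parameter-name mapping (`query`/`key`/`value`/`proj_attn` onto `to_q`/`to_k`/`to_v`/`to_out[0]`) is an assumption about the two layouts, not code taken from this PR.

```python
import torch.nn as nn

def copy_attention_block_weights(old_block: nn.Module, new_attention: nn.Module) -> None:
    # Assumed parameter layout: AttentionBlock exposes group_norm, query, key,
    # value, proj_attn; Attention exposes group_norm, to_q, to_k, to_v, to_out[0].
    new_attention.group_norm.load_state_dict(old_block.group_norm.state_dict())
    new_attention.to_q.load_state_dict(old_block.query.state_dict())
    new_attention.to_k.load_state_dict(old_block.key.state_dict())
    new_attention.to_v.load_state_dict(old_block.value.state_dict())
    new_attention.to_out[0].load_state_dict(old_block.proj_attn.state_dict())
```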