
Deprecate attention block #2697


Closed

Conversation


@williamberman williamberman commented Mar 15, 2023

re: #1880

PR description

We are in the process of deprecating the AttentionBlock class. All attention will be done through attention_processor.Attention.

The goal of this PR is to remove the body of AttentionBlock and replace it with SpatialAttnProcessor and XFormersSpatialAttnProcessor. This allows checkpoints whose weights are in the AttentionBlock format to use attention_processor.Attention.

After the model is loaded, we directly replace all instances of AttentionBlock with attention_processor.Attention. The forward method of AttentionBlock itself now throws an error.

The next step will be to re-upload the AttentionBlock weights on the hub in the attention_processor.Attention format, along with a configuration flag for using attention_processor.Attention.
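
For intuition, here is a minimal sketch of the kind of weight copying such a conversion involves. This is illustrative only, not the PR's actual code; the parameter names are based on the two classes' public attributes (query/key/value/proj_attn on the old block, to_q/to_k/to_v/to_out on Attention):

import torch

@torch.no_grad()
def convert_attention_block_weights(old_block, new_attention):
    # Both modules normalize the input with a GroupNorm.
    new_attention.group_norm.load_state_dict(old_block.group_norm.state_dict())

    # AttentionBlock names its projections query/key/value/proj_attn;
    # attention_processor.Attention calls them to_q/to_k/to_v/to_out[0].
    new_attention.to_q.load_state_dict(old_block.query.state_dict())
    new_attention.to_k.load_state_dict(old_block.key.state_dict())
    new_attention.to_v.load_state_dict(old_block.value.state_dict())
    new_attention.to_out[0].load_state_dict(old_block.proj_attn.state_dict())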

Hub stats for attention blocks in diffusers

from https://github.com/williamberman/attention-block-refactor

UNet2DModel, VQModel, and AutoencoderKL are top-level models that all have the deprecated attention block as a submodule. The conditional UNet can be configured to use the deprecated attention block, but there are no hub uploads configured that way.

The most common model is AutoencoderKL, because all Stable Diffusion checkpoints use it.

UNet2DModel

1039 hub uploads
Blocks using deprecated attention:

  • AttnDownBlock2D: 1033
  • AttnUpBlock2D: 1033
  • AttnSkipDownBlock2D: 6
  • AttnSkipUpBlock2D: 6

VQModel

13 hub uploads
Blocks using deprecated attention:

  • UNetMidBlock2D: 13
  • AttnDownEncoderBlock2D: 7
  • AttnUpDecoderBlock2D: 7

AutoencoderKL

2927 hub uploads
Blocks using deprecated attention:

  • UNetMidBlock2D: 2927
  • AttnDownEncoderBlock2D: 1
  • AttnUpDecoderBlock2D: 1

UNet2DModel, VQModel, and AutoencoderKL all have a new attention_block_type flag. This flag is passed down to AttnDownBlock2D, AttnUpBlock2D, AttnSkipDownBlock2D, AttnSkipUpBlock2D, UNetMidBlock2D, AttnDownEncoderBlock2D, and AttnUpDecoderBlock2D to switch to the non-deprecated attention block.
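
For example, opting into the new attention block on a freshly constructed model might look like this (hypothetical usage; the accepted values of the flag are an assumption, the PR only states that the flag exists and is threaded through the blocks):

from diffusers import AutoencoderKL

# "Attention" is an assumed value for attention_block_type.
vae = AutoencoderKL(attention_block_type="Attention")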

When do we convert the attention blocks

_convert_deprecated_attention_blocks is the method added to ModelMixin that converts all deprecated attention blocks.

We cannot call _convert_deprecated_attention_blocks in the constructor of the model because weights in the old format can be loaded after the constructor is called. There are two cases:

  1. Subclasses of ModelMixin are created by ModelMixin#from_pretrained. After the weights have been loaded, convert the attention blocks.
  2. Subclasses of ModelMixin are created only with their constructor and passed to the pipeline's constructor. Convert the attention blocks in DiffusionPipeline#register_modules.
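
A sketch of those two hook points (only the _convert_deprecated_attention_blocks name comes from the PR; the surrounding method bodies are simplified assumptions):

class ModelMixin:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Case 1: build the model and load its (possibly old-format) weights,
        # then convert, since all weights are guaranteed to be present here.
        model = cls._build_and_load(pretrained_model_name_or_path, **kwargs)  # assumed helper
        model._convert_deprecated_attention_blocks()
        return model

class DiffusionPipeline:
    def register_modules(self, **modules):
        for name, module in modules.items():
            # Case 2: the module was constructed directly and handed to the
            # pipeline, so this is the last safe place to convert it.
            if hasattr(module, "_convert_deprecated_attention_blocks"):
                module._convert_deprecated_attention_blocks()
            setattr(self, name, module)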

@williamberman williamberman changed the title Depreate attention block Deprecate attention block Mar 15, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@williamberman williamberman force-pushed the depreate_attention_block branch 6 times, most recently from 219829e to 80e9ff8 Compare March 17, 2023 21:29
@williamberman williamberman force-pushed the depreate_attention_block branch from 80e9ff8 to fcd629d Compare March 17, 2023 22:06
" If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`."
)

deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True)
@williamberman (Contributor Author)

@patrickvonplaten is this the right deprecation version? iirc we talked about two minor versions

Contributor:

Yeah, I'd maybe even bump it up a bit more, to "0.20.0".

@williamberman williamberman force-pushed the depreate_attention_block branch 4 times, most recently from a821525 to 6c72fa9 Compare March 17, 2023 23:51
@@ -783,6 +785,71 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

def _convert_deprecated_attention_blocks(self):
@williamberman williamberman commented Mar 17, 2023

@patrickvonplaten Is it ok to put all of the conversion code for each class in this one method and determine which model it's in by checking self.__class__.__name__? I think it's a bit more straightforward than splitting it up across the different model definitions.
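
A sketch of the dispatch pattern being proposed (the per-model helpers are hypothetical names, not the PR's code):

def _convert_deprecated_attention_blocks(self):
    # One method on ModelMixin that branches on the concrete model class,
    # instead of overriding the method in every subclass.
    model_name = self.__class__.__name__
    if model_name == "UNet2DModel":
        self._convert_unet2d_blocks()        # assumed helper
    elif model_name == "VQModel":
        self._convert_vq_blocks()            # assumed helper
    elif model_name == "AutoencoderKL":
        self._convert_autoencoder_blocks()   # assumed helper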

@williamberman williamberman force-pushed the depreate_attention_block branch 6 times, most recently from 98cdd56 to 19178f7 Compare March 18, 2023 07:11
@williamberman williamberman marked this pull request as ready for review March 18, 2023 07:29
@williamberman williamberman force-pushed the depreate_attention_block branch 2 times, most recently from 88b463c to fd56be6 Compare March 19, 2023 01:56
@patrickvonplaten patrickvonplaten left a comment

Looks great, just wondering if we could maybe try to use existing attn processors by doing some pre- and post-reshaping around the attention layer (or is this not possible - haven't checked).

bias=True,
upcast_softmax=True,
norm_num_groups=resnet_groups,
processor=SpatialAttnProcessor(),

Hmm, is there really a difference between spatial attention and "normal" attention - isn't it just some reshaping before and after that's different?

Could we maybe try to use the normal attention processor and do the reshaping outside of the attention (e.g. in the VAE directly)? This would also allow us to use all the other attention processors out of the box (e.g. LoRA, xformers, torch 2.0).

Wdyt?
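
Concretely, the reshaping in question is just flattening the spatial grid into a sequence before attention and restoring it afterwards, e.g. (a sketch, assuming an attn module that consumes (batch, seq_len, channels) tensors):

import torch

def spatial_attention(attn, hidden_states):
    # (batch, channels, height, width) -> (batch, height * width, channels)
    batch, channels, height, width = hidden_states.shape
    hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

    # Any standard attention processor (LoRA, xformers, torch 2.0) can run here.
    hidden_states = attn(hidden_states)

    # (batch, height * width, channels) -> (batch, channels, height, width)
    return hidden_states.transpose(1, 2).reshape(batch, channels, height, width)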

@williamberman (Contributor Author)

wip commit up 5aa3446

@williamberman williamberman force-pushed the depreate_attention_block branch 3 times, most recently from 592fcfb to 938c44a Compare March 20, 2023 21:52
@williamberman williamberman force-pushed the depreate_attention_block branch 3 times, most recently from 5aa3446 to 14718c6 Compare March 21, 2023 01:08
@patrickvonplaten patrickvonplaten left a comment

Cool, I think the changes to the attention processor are nice! I think we can make the code a bit shorter though, with fewer if-else statements and arguably easier-to-read code (not a huge fan of booleans such as reshaped_input=True).

@williamberman williamberman force-pushed the depreate_attention_block branch from 14718c6 to fdffad4 Compare March 21, 2023 17:48
@williamberman

There are some failures when running the integration tests against this branch (#2759). I'll need to fix those before we merge.

@patrickvonplaten

Ah yeah, can you also update the testing branch? A bunch of test failures have recently been fixed on main.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 16, 2023
@github-actions github-actions bot closed this Apr 24, 2023