[Refactor] splitting ResnetBlock2D into multiple blocks
#6166
still wip, cc @DN6 for awareness
  # there is always at least one resnet
  resnets = [
-     ResnetBlock2D(
+     get_resnet_block(
I think it might be better to just append the appropriate Resnet block here than introduce the block fetching function. Is the purpose here to ensure backwards compatibility? Or do these blocks use a mix of Resnet classes?
I'm aware that some of the blocks use a mix, i.e. the ones Kandinsky used for MOVQ: DownEncoderBlock2D, AttnDownEncoderBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D.
We could use get_resnet_block for these blocks instead - I would prefer that way too, because it is more explicit and readable. I used the function everywhere for now because:
- Technically all these blocks can be configured to use both types, and I don't have the complete knowledge to know whether picking a single type would break existing models. I think that is pretty unlikely for "spatial norm", though, since it is very specific to MOVQ.
- I wanted to use this as an example to show how we can ensure backward compatibility. I only split ResnetBlock2D into 2 blocks here in the draft and, as far as I'm aware, ResnetBlockCondNorm2D is only used by the latent upscaler and Kandinsky; but maybe we want to have more blocks with more specific configurations. That can become unmanageable without a block-fetching function.
I agree with you it is better just to use the appropriate Resnet blocks, and very happy to refactor later. I will wait to do that after hearing more about:
- feedback on how many blocks we want to split into
- what's the best strategy to ensure backward compatibility, given that we might not have complete data on how these blocks are used
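To make the backward-compatibility idea in the comment above concrete, here is a minimal sketch of what a block-fetching function could look like. The two classes are plain stand-ins rather than the real diffusers classes, and the signature of `get_resnet_block`, the `"default"` value, and the `COND_NORM_TYPES` name are assumptions made for illustration; only the `ada_group` and `spatial` values appear in this PR.

```python
class ResnetBlock2D:
    """Stand-in for the plain resnet block (not the real diffusers class)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ResnetBlockCondNorm2D:
    """Stand-in for the resnet block with a conditional normalization layer."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Assumption: these two norm types are the ones routed to the new block.
COND_NORM_TYPES = {"ada_group", "spatial"}


def get_resnet_block(time_embedding_norm="default", **kwargs):
    """Hypothetical fetcher: choose the resnet class from the norm type."""
    if time_embedding_norm in COND_NORM_TYPES:
        return ResnetBlockCondNorm2D(time_embedding_norm=time_embedding_norm, **kwargs)
    return ResnetBlock2D(time_embedding_norm=time_embedding_norm, **kwargs)
```

With a fetcher like this, an old config that still passes `spatial` keeps working, at the cost of hiding which class each block actually builds.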
> feedback on how many blocks we want to split into
I think we'd want to split into the most common block types. I think this number is still quite unclear.
> We could use get_resnet_block for these blocks instead - I would prefer that way too, because it is more explicit and readable.
How about we use the fetcher function but also add a comment above to denote which type of ResNet block is being used?
patrickvonplaten
left a comment
Very nice first attempt here! I agree with @DN6 that creating a get_resnet_block(...) is not really making the code much more readable. Instead it would be better if we can directly add the correct Resnet class to the blocks.
I think most blocks only use one of the two ResNets - in this case can we add the correct ResNet right away? In the exceptional case where the block can use both ResNets, let's add an if-else statement?
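The if-else suggestion above could be sketched roughly as follows. This is an illustration only: the two resnet classes are stand-ins, and `MixedBlockSketch` and its `norm_type` parameter are hypothetical names, not part of the PR.

```python
class ResnetBlock2D:
    """Stand-in for the plain resnet block (not the real diffusers class)."""

    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels


class ResnetBlockCondNorm2D:
    """Stand-in for the resnet block with a conditional normalization layer."""

    def __init__(self, in_channels, out_channels, norm_type):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_type = norm_type


class MixedBlockSketch:
    """Hypothetical block for the exceptional case that can use both resnets."""

    def __init__(self, in_channels, out_channels, num_layers=1, norm_type="default"):
        self.resnets = []
        for i in range(num_layers):
            # Only the first layer sees the block's input channel count.
            layer_in = in_channels if i == 0 else out_channels
            if norm_type == "spatial":
                # The resnet class is named at the call site, no fetcher needed.
                self.resnets.append(ResnetBlockCondNorm2D(layer_in, out_channels, norm_type))
            else:
                self.resnets.append(ResnetBlock2D(layer_in, out_channels))
```

Blocks that only ever use one resnet type would skip the branch and instantiate that class directly.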
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Can we do these in future PRs? I would rather save these types of tasks for last, i.e. splitting one file into multiple, moving them around, etc.
@patrickvonplaten @DN6
patrickvonplaten
left a comment
Nice looking good - got only one comment regarding the deprecation warning
Co-authored-by: Patrick von Platen <[email protected]>
src/diffusers/models/resnet.py
Outdated
    deprecate(
        "ada_group",
        "1.0.0",
        "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
    )
if time_embedding_norm == "spatial":
    raise ValueError(
        "spatial",
        "1.0.0",
        "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
    )
Why use deprecate for one and ValueError for the other?
oops 😅 fixed it
DN6
left a comment
It's looking really good! 👍🏽
src/diffusers/models/resnet.py
Outdated
    "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
)
I am not sure if saying something is deprecated is a good idea in a "ValueError". Have we ever done that?
Yeah, let's indeed rename the message here to something like "This class cannot be used with "type==ada_group", please use XXX instead"
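One way the suggested wording might look in code is sketched below. The helper name `check_time_embedding_norm` and the `COND_NORM_TYPES` tuple are hypothetical, the exact message text is an assumption based on the comment above, and the check is shown as a standalone function rather than inside the real class.

```python
# Assumption: these are the norm types that moved to ResnetBlockCondNorm2D.
COND_NORM_TYPES = ("ada_group", "spatial")


def check_time_embedding_norm(time_embedding_norm):
    """Hypothetical guard: reject norm types this class no longer supports."""
    if time_embedding_norm in COND_NORM_TYPES:
        # A hard error states what is unsupported and what to use instead,
        # without calling the value "deprecated".
        raise ValueError(
            f"This class cannot be used with `time_embedding_norm=={time_embedding_norm}`, "
            "please use `ResnetBlockCondNorm2D` instead"
        )
```

Both values go through the same branch, so the two cases cannot drift apart into one warning and one error.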
updated!
sayakpaul
left a comment
Very nice. Just one comment.
patrickvonplaten
left a comment
Nice!
…6166) --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
This PR spins out ResnetBlockCondNorm2D from ResnetBlock2D: ResnetBlockCondNorm2D is a resnet block with a normalization layer that incorporates conditional information, e.g. AdaGroupNorm, SpatialNorm. This resnet block is mainly used by the Kandinsky decoder MOVQ and the latent upscaler.
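As a rough illustration of what "a normalization layer that incorporates conditional information" can mean, here is a pure-Python sketch of the general adaptive group norm idea: normalize per group, then apply a scale and shift derived from the conditioning. It is a simplification under stated assumptions, not a copy of the library's AdaGroupNorm: it works on one value per channel, and `scale` and `shift` are assumed to have already been projected from the embedding.

```python
import math


def group_norm(values, num_groups, eps=1e-5):
    """Normalize a flat list of channel values group by group (no learned affine)."""
    size = len(values) // num_groups
    out = []
    for g in range(num_groups):
        group = values[g * size:(g + 1) * size]
        mean = sum(group) / size
        var = sum((v - mean) ** 2 for v in group) / size
        out.extend((v - mean) / math.sqrt(var + eps) for v in group)
    return out


def ada_group_norm(values, scale, shift, num_groups):
    """Group-normalize, then modulate each channel with conditioning values."""
    normed = group_norm(values, num_groups)
    return [n * (1 + s) + b for n, s, b in zip(normed, scale, shift)]
```

With `scale` and `shift` set to zero the result reduces to plain group normalization, which is the sense in which the conditioning is layered on top of the norm.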