[Core] Introduce class variants for Transformer2DModel
#7647
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Is the plan here to eventually map the [...]? Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? E.g., [...]
Yeah, that's the plan.
Feasible, but I am not sure we have enough such transformer-based pipelines yet. Most of them vary across very few things (such as the norm type and a cross-attention layer). I think there is a fair trade-off to be had when deciding which variant to use. If there are too many arguments that change, it's better to use a dedicated class (like we did for the private model). If not, rely on an existing variant that depends on the input type.
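To make that trade-off concrete, here is a minimal sketch (not taken from the PR itself; the small argument values are purely illustrative) showing how little the two dedicated classes differ in configuration, mostly the norm type and the text-conditioning path:

```python
from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel

# Tiny, illustrative configs: the variants mostly diverge on the norm type
# and on whether a cross-attention / caption-projection path exists.
dit = DiTTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,
    norm_type="ada_norm_zero",   # class-conditional DiT-style modulation
    num_embeds_ada_norm=1000,    # number of class labels (illustrative value)
)

pixart = PixArtTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,
    norm_type="ada_norm_single",  # PixArt-style single adaLN
    cross_attention_dim=16,       # text-conditioning path that the DiT variant lacks
    caption_channels=32,          # projects text-encoder features (illustrative value)
)
```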
@DN6 done. I think I have addressed all your comments. LMK.
@DN6 resolved your comment on the location of [...]
DN6
left a comment
Nice work 👍🏽
LGTM. cc: @yiyixuxu in case you want to take a look too.
yiyixuxu
left a comment
nice!!
I left a comment. Let me know if it is a concern, and feel free to merge if it isn't or once you've addressed it.
```python
shift, scale = (
    self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
```
Ohh, OK. So these tests should also fail on the current implementation? I don't think this refactor introduced any change that would cause them to fail, no?
```python
del module.proj_attn
```

```python
class LegacyModelMixin(ModelMixin):
```
beautiful!
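For readers skimming the diff, here is a rough, hedged sketch of the dispatch idea behind a legacy mixin like this. It is not the implementation in this PR; the class name, the remapping keys, and the dict contents below are assumptions on my part (the commit log only mentions a remapping dict and `_fetch_remapped_cls_from_config`):

```python
from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin

# Assumed mapping from the checkpoint config's norm type to a dedicated class.
_REMAP = {"ada_norm_zero": DiTTransformer2DModel, "ada_norm_single": PixArtTransformer2DModel}


class LegacyTransformerSketch(ModelMixin, ConfigMixin):
    """Hypothetical sketch: forward loading of legacy checkpoints to a remapped variant."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Peek at the config first to decide which dedicated class the checkpoint needs.
        config = cls.load_config(pretrained_model_name_or_path, subfolder=kwargs.get("subfolder"))
        remapped_cls = _REMAP.get(config.get("norm_type"))
        if remapped_cls is None:
            # No dedicated variant known for this config: fall back to the normal path.
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        # The remapped class is a plain ModelMixin subclass, so this call does not recurse.
        return remapped_cls.from_pretrained(pretrained_model_name_or_path, **kwargs)
```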
```python
return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids}
```

```python
@property
def input_shape(self):
```
Are these properties used at all? Maybe we can leverage them so we don't have to specify them in test_output?
Can be in a separate PR if it makes sense.
Yeah good idea. Can look into a little "input_shape" refactor in a future PR :)
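A rough sketch of what that future refactor could look like (the mixin name, shapes, and helper below are hypothetical, not part of this PR):

```python
import torch


class TransformerTesterSketch:
    """Hypothetical test mixin: derive expected shapes from the input_shape property."""

    @property
    def input_shape(self):
        # (channels, height, width) of the dummy input; individual model tests override this.
        return (4, 8, 8)

    @property
    def output_shape(self):
        # Default: same as the input. Tests for models that predict extra channels
        # (e.g., learned sigma doubling the channel dim) would override this.
        return self.input_shape

    def check_output(self, model, inputs_dict):
        # Compare against the property instead of hard-coding shapes in test_output.
        with torch.no_grad():
            sample = model(**inputs_dict).sample
        assert tuple(sample.shape[1:]) == self.output_shape
```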
* init for patches
* finish patched model.
* continuous transformer
* vectorized transformer2d.
* style.
* inits.
* fix-copies.
* introduce DiTTransformer2DModel.
* fixes
* use REMAPPING as suggested by @DN6
* better logging.
* add pixart transformer model.
* inits.
* caption_channels.
* attention masking.
* fix use_additional_conditions.
* remove print.
* debug
* flatten
* fix: assertion for sigma
* handle remapping for modeling_utils
* add tests for dit transformer2d
* quality
* placeholder for pixart tests
* pixart tests
* add _no_split_modules
* add docs.
* check
* check
* check
* check
* fix tests
* fix tests
* move Transformer output to modeling_output
* move errors better and bring back use_additional_conditions attribute.
* add unnecessary things from DiT.
* clean up pixart
* fix remapping
* fix device_map things in pixart2d.
* replace Transformer2DModel with appropriate classes in dit, pixart tests
* empty
* legacy mixin classes.
* use a remapping dict for fetching class names.
* change to specifc model types in the pipeline implementations.
* move _fetch_remapped_cls_from_config to modeling_loading_utils.py
* fix dependency problems.
* add deprecation note.
```python
def test_pixart_512_without_resolution_binning(self):
    generator = torch.manual_seed(0)

    transformer = Transformer2DModel.from_pretrained(
```
We should have kept this test. Can we add it back and name it test_pixart_512_without_resolution_binning_legacy_class or something like this?
And make sure to have a similar slow test for DiT.
In the future, I think we should always keep the test with the legacy class name, no? That way we can make sure that everything still works fine through the old API.
cc @DN6
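For reference, a hedged sketch of what such a legacy-class slow test could look like; the checkpoint id, prompt, and step count are assumptions for illustration:

```python
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel


def test_pixart_512_without_resolution_binning_legacy_class():
    # Load through the legacy class name to exercise the old API path end to end.
    transformer = Transformer2DModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", subfolder="transformer", torch_dtype=torch.float16
    )
    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", transformer=transformer, torch_dtype=torch.float16
    ).to("cuda")

    generator = torch.manual_seed(0)
    image = pipe(
        "a small cactus with a happy face",
        generator=generator,
        num_inference_steps=2,
        use_resolution_binning=False,
        output_type="np",
    ).images[0]
    # The 512 checkpoint should produce a 512x512 RGB image without binning.
    assert image.shape == (512, 512, 3)
```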
Hello, I need to use a deprecated model (vq-diffusion) now. Due to version changes, Transformer2DModel has been mapped to two variants, but these two variants are slightly different from the original vq-diffusion model (specifically, different types of norms are used). Directly loading the pretrained model causes from_pretrained of the LegacyModelMixin class to fall into a recursive call loop until the buffer overflows. If DiTTransformer2DModel uses ada_norm, an error is raised: [NotImplementedError: Forward pass is not implemented when ...]
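For context, a minimal sketch of the loading path being described; the checkpoint id and subfolder are my assumptions about the reporter's setup:

```python
from diffusers import Transformer2DModel

# Loading a vq-diffusion transformer through the legacy class reportedly ends up in a
# recursive from_pretrained call, since this checkpoint's norm configuration matches
# neither DiTTransformer2DModel nor PixArtTransformer2DModel.
transformer = Transformer2DModel.from_pretrained(
    "microsoft/vq-diffusion-ithq", subfolder="transformer"
)
```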
What does this PR do?
Introduces two variants of Transformer2DModel:
* DiTTransformer2DModel
* PixArtTransformer2DModel

For the other instances where Transformer2DModel is used, they should later be turned into blocks, as they shouldn't be inheriting from ModelMixin (this has been discussed internally).

TODO:
(Will be tackled after I get an initial review)
Some comments are in-line.
LMK.
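For illustration, a short, hedged sketch of loading checkpoints directly through the two new classes (the checkpoint ids are assumptions; the class names are the ones introduced here):

```python
import torch
from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel

# Class-conditional DiT checkpoint loaded with its dedicated class.
dit = DiTTransformer2DModel.from_pretrained(
    "facebook/DiT-XL-2-256", subfolder="transformer", torch_dtype=torch.float16
)

# PixArt-Alpha checkpoint loaded with its dedicated class.
pixart = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer", torch_dtype=torch.float16
)
```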