Conversation

@sayakpaul (Member) commented Apr 12, 2024

What does this PR do?

Introduces two variants of Transformer2DModel:

  • DiTTransformer2DModel
  • PixArtTransformer2DModel

The other places where Transformer2DModel is used should later be turned into blocks, since they shouldn't inherit from ModelMixin (this has been discussed internally).
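
For illustration, the new classes should load like any other diffusers model. A minimal usage sketch (the repo IDs and subfolders below are assumptions based on the public DiT and PixArt-alpha checkpoints, not part of this PR):

# Hedged usage sketch; checkpoint repo IDs / subfolders are assumptions.
from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel

dit_transformer = DiTTransformer2DModel.from_pretrained(
    "facebook/DiT-XL-2-256", subfolder="transformer"
)
pixart_transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer"
)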

TODO:

(Will be tackled after I get an initial review)

  • Tests for each individual variant
  • Documentation

Some comments are in-line.

LMK.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu April 12, 2024 05:44
@sayakpaul sayakpaul changed the title from "[Core]" to "[Core] Introduce class variants for Transformer2Model" Apr 12, 2024
@sayakpaul sayakpaul changed the title from "[Core] Introduce class variants for Transformer2Model" to "[Core] Introduce class variants for Transformer2DModel" Apr 12, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member Author)

@yiyixuxu @DN6 a gentle ping here.

@DN6 DN6 marked this pull request as ready for review April 29, 2024 05:35
@DN6 (Collaborator) commented Apr 29, 2024

Is the plan here to eventually map Transformer2DModel to the variant? E.g., a pipeline that uses Transformer2DModel with patched inference will now try to create PatchedTransformer2DModel under the hood?

Also, how feasible is it to break it up into model-specific variants rather than input-specific variants, e.g. PixArtTransformer2DModel?

@sayakpaul (Member Author)

> Is the plan here to eventually map Transformer2DModel to the variant? E.g., a pipeline that uses Transformer2DModel with patched inference will now try to create PatchedTransformer2DModel under the hood?

Yeah, that's the plan.

> Also, how feasible is it to break it up into model-specific variants rather than input-specific variants, e.g. PixArtTransformer2DModel?

Feasible, but I am not sure we have enough such transformer-based pipelines yet. Most of them differ in only a few things (such as the norm type and a cross-attention layer).

I think there is a fair trade-off to be had when deciding which variant to use. If too many arguments change, it's better to use a dedicated class (like we did for the private model). If not, rely on an existing variant that depends on the input type.

@sayakpaul sayakpaul changed the title from "[Core] Introduce class variants for Transformer2DModel" to "[WIP][Core] Introduce class variants for Transformer2DModel" May 1, 2024
@sayakpaul sayakpaul marked this pull request as draft May 1, 2024 04:11
@sayakpaul (Member Author)

@DN6 done. I think I have addressed all your comments. LMK.

@sayakpaul sayakpaul requested a review from DN6 May 28, 2024 14:44
@sayakpaul (Member Author)

@DN6 I resolved your comment on the location of _CLASS_REMAPPING_DICT. I have also moved _fetch_remapped_cls_from_config to model_loading_utils; I think this is better, as _fetch_remapped_cls_from_config has nothing to do with the Hub.
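
Roughly, the remapping boils down to a dict keyed on the legacy class name plus a helper that inspects the config. The sketch below is an approximation for illustration, not the exact diffusers implementation:

import importlib

# Assumed shape of the table: legacy class name -> config value -> new class name.
_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}


def _fetch_remapped_cls_from_config(config, old_class):
    previous_class_name = old_class.__name__
    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name, {}).get(
        config.get("norm_type"), None
    )
    if remapped_class_name is None:
        # Nothing to remap; keep using the class the caller asked for.
        return old_class
    # Resolve the new class lazily from the top-level diffusers namespace.
    diffusers_module = importlib.import_module("diffusers")
    return getattr(diffusers_module, remapped_class_name)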

@DN6 (Collaborator) left a comment

Nice work 👍🏽

@DN6 (Collaborator) commented May 30, 2024

LGTM. cc: @yiyixuxu in case you want to take a look too.

@yiyixuxu (Collaborator) left a comment

nice!!
I left a comment; let me know if it is a concern, and feel free to merge if it's not or once it has been addressed.

Comment on lines +315 to +320
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
Collaborator:

Ohh ok, so these tests should also fail on the current implementation? I don't think this refactor introduced any change that would cause them to fail, no?
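
For readers skimming the diff: the snippet on lines +315 to +320 is PixArt's output-layer modulation. A minimal, self-contained sketch of the broadcasting it relies on (the dimensions here are made up for illustration):

import torch

batch, seq_len, dim = 2, 16, 8
scale_shift_table = torch.randn(2, dim)      # learned table, one row each for shift and scale
embedded_timestep = torch.randn(batch, dim)  # per-sample timestep embedding
hidden_states = torch.randn(batch, seq_len, dim)

# (1, 2, dim) + (batch, 1, dim) broadcasts to (batch, 2, dim); chunking along
# dim=1 yields shift and scale of shape (batch, 1, dim), which then broadcast
# over the sequence dimension.
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
print(hidden_states.shape)  # torch.Size([2, 16, 8])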

del module.proj_attn


class LegacyModelMixin(ModelMixin):
Collaborator:

beautiful!
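
For context, the idea behind LegacyModelMixin is roughly the following. This is a hedged sketch for illustration only, not the exact code in the PR; the guard against remapping back to the same class is included so from_pretrained does not recurse.

from diffusers import ModelMixin
from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config


class LegacyModelMixinSketch(ModelMixin):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Load only the config first so we can decide which concrete class
        # the checkpoint should really be instantiated as.
        config, _ = cls.load_config(
            pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
        )
        remapped_class = _fetch_remapped_cls_from_config(config, cls)
        if remapped_class is cls:
            # Nothing to remap; delegate to the plain ModelMixin loader instead
            # of calling back into this override and recursing.
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Otherwise let the remapped (non-legacy) class do the actual loading.
        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)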

return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids}

@property
def input_shape(self):
Collaborator:

Are these properties used at all? Maybe we can leverage them so we don't have to specify it in test_output?
This can be in a separate PR if it makes sense.

Member Author:

Yeah, good idea. Can look into a little "input_shape" refactor in a future PR :)
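
As a possible direction for that follow-up, a toy sketch of deriving the expected output shape from the properties instead of hard-coding it in test_output (the class and shapes below are hypothetical, not the actual test suite):

import unittest

import torch


class ToyTransformerOutputTest(unittest.TestCase):
    @property
    def input_shape(self):
        # (channels, height, width) of the dummy latent input
        return (4, 8, 8)

    @property
    def output_shape(self):
        # shape the model is expected to produce for that input
        return (8, 8, 8)

    def test_output(self):
        batch_size = 2
        # Stand-in for a real forward pass; a real test would run the model here.
        sample = torch.randn(batch_size, *self.output_shape)
        # The expectation comes from the property rather than a hard-coded tuple.
        self.assertEqual(tuple(sample.shape), (batch_size, *self.output_shape))


if __name__ == "__main__":
    unittest.main()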

@sayakpaul sayakpaul merged commit 983dec3 into main May 31, 2024
@sayakpaul sayakpaul deleted the transormer2d-variants branch May 31, 2024 08:10
sayakpaul added a commit that referenced this pull request Jun 4, 2024
* init for patches

* finish patched model.

* continuous transformer

* vectorized transformer2d.

* style.

* inits.

* fix-copies.

* introduce DiTTransformer2DModel.

* fixes

* use REMAPPING as suggested by @DN6

* better logging.

* add pixart transformer model.

* inits.

* caption_channels.

* attention masking.

* fix use_additional_conditions.

* remove print.

* debug

* flatten

* fix: assertion for sigma

* handle remapping for modeling_utils

* add tests for dit transformer2d

* quality

* placeholder for pixart tests

* pixart tests

* add _no_split_modules

* add docs.

* check

* check

* check

* check

* fix tests

* fix tests

* move Transformer output to modeling_output

* move errors better and bring back use_additional_conditions attribute.

* add unnecessary things from DiT.

* clean up pixart

* fix remapping

* fix device_map things in pixart2d.

* replace Transformer2DModel with appropriate classes in dit, pixart tests

* empty

* legacy mixin classes./

* use a remapping dict for fetching class names.

* change to specifc model types in the pipeline implementations.

* move _fetch_remapped_cls_from_config to modeling_loading_utils.py

* fix dependency problems.

* add deprecation note.
def test_pixart_512_without_resolution_binning(self):
generator = torch.manual_seed(0)

transformer = Transformer2DModel.from_pretrained(
@yiyixuxu (Collaborator), Jun 8, 2024:

We should have kept this test. Can we add it back and name it test_pixart_512_without_resolution_binning_legacy_class or something like this? And make sure to have a similar slow test for DiT.

In the future, I think we should always keep the test with the legacy class name, no? So that we can make sure everything still works fine with the old API.
cc @DN6
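
A hedged sketch of what the suggested slow test could look like; the test name follows the comment above, the checkpoint ID is the public PixArt-alpha 512 repo, and the transparent remapping to PixArtTransformer2DModel is what this PR is expected to provide:

import torch

from diffusers import PixArtAlphaPipeline, Transformer2DModel


def test_pixart_512_without_resolution_binning_legacy_class():
    generator = torch.manual_seed(0)
    # Loading through the legacy class should still work and return the
    # remapped PixArtTransformer2DModel under the hood.
    transformer = Transformer2DModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", subfolder="transformer"
    )
    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", transformer=transformer
    )
    image = pipe(
        "a small cactus with a happy face",
        generator=generator,
        use_resolution_binning=False,
        num_inference_steps=2,
        output_type="np",
    ).images[0]
    assert image.shape == (512, 512, 3)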

sayakpaul added a commit that referenced this pull request Dec 23, 2024
@lengmo1996 (Contributor)

Hello, I need to use a deprecated model (vq-diffusion) now. Due to version issues, Transformer2DModel has been mapped to two variants, but these variants are slightly different from the original vq-diffusion (specifically, different types of norms are used). Directly loading the pretrained model causes from_pretrained of the LegacyModelMixin class to fall into a recursive call loop until the stack overflows. If DiTTransformer2DModel uses ada_norm, the error NotImplementedError: Forward pass is not implemented when patch_size is not None and norm_type is 'ada_norm' is raised. How can I adjust the code to make it run?
