From 0b56a0e52eb29391b64127c342bba8804497af12 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 21 Jul 2024 00:12:27 +0700 Subject: [PATCH 1/5] add hunyuan model test --- .../test_models_transformer_hunyuan_dit.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/models/test_models_transformer_hunyuan_dit.py diff --git a/tests/models/test_models_transformer_hunyuan_dit.py b/tests/models/test_models_transformer_hunyuan_dit.py new file mode 100644 index 000000000000..9913d1f4dbb8 --- /dev/null +++ b/tests/models/test_models_transformer_hunyuan_dit.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanDiT2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanDiT2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = 16 + embedding_dim = 32 + sequence_length = 77 + sequence_length_t5 = 256 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device) + encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device) + text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device) + + original_size = [1024, 1024] + target_size = [16, 16] + crops_coords_top_left = [0, 0] + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device) + style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device) + image_rotary_emb = [ + torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype), + torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype) + ] + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "text_embedding_mask": text_embedding_mask, + "encoder_hidden_states_t5": encoder_hidden_states_t5, + "text_embedding_mask_t5": text_embedding_mask_t5, + "timestep": timestep, + "image_meta_size": add_time_ids, + "style": style, + "image_rotary_emb": image_rotary_emb, + } + + @property + def input_shape(self): + return (4, 16, 16) + + @property + def output_shape(self): + return (8, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 16, + "patch_size": 2, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 8, + 
"num_attention_heads": 3, + "cross_attention_dim": 32, + "cross_attention_dim_t5": 32, + "pooled_projection_dim": 16, + "hidden_size": 24, + "activation_fn": "gelu-approximate", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output( + expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + ) \ No newline at end of file From 794d98500e2b23485cd190e466732760c7c95204 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 21 Jul 2024 19:19:11 +0700 Subject: [PATCH 2/5] apply suggestions --- .../test_models_transformer_hunyuan_dit.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) rename tests/models/{ => transformers}/test_models_transformer_hunyuan_dit.py (91%) diff --git a/tests/models/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py similarity index 91% rename from tests/models/test_models_transformer_hunyuan_dit.py rename to tests/models/transformers/test_models_transformer_hunyuan_dit.py index 9913d1f4dbb8..3c7f5623a91c 100644 --- a/tests/models/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -85,7 +85,7 @@ def prepare_init_args_and_inputs_for_common(self): "sample_size": 16, "patch_size": 2, "in_channels": 4, - "num_layers": 2, + "num_layers": 1, "attention_head_dim": 8, "num_attention_heads": 3, "cross_attention_dim": 32, @@ -100,4 +100,12 @@ def prepare_init_args_and_inputs_for_common(self): def test_output(self): super().test_output( expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape - ) \ No newline at end of file + ) + + @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") + def test_set_xformers_attn_processor_for_determinism(self): + pass + + @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") + def test_set_attn_processor_for_determinism(self): + pass \ No newline at end of file From 63c53fa861f72d58c9b8fb467ec5eba513da0741 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 29 Jul 2024 21:52:03 +0700 Subject: [PATCH 3/5] reduce dims further --- .../test_models_transformer_hunyuan_dit.py | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index 3c7f5623a91c..c15858e570f4 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -22,6 +22,17 @@ enable_full_determinism, torch_device, ) +from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + HunyuanAttnProcessor2_0, + XFormersAttnProcessor, +) +from diffusers.utils import is_xformers_available +from diffusers.utils.testing_utils import ( + require_torch_gpu, + torch_device, +) from ..test_modeling_common import ModelTesterMixin @@ -37,10 +48,10 @@ class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): batch_size = 2 num_channels = 4 - height = width = 16 - embedding_dim = 32 - sequence_length = 77 - sequence_length_t5 = 256 + height = width = 8 + embedding_dim = 8 + sequence_length = 4 + sequence_length_t5 = 4 hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, 
@@ -74,24 +85,26 @@ def dummy_input(self):
 
     @property
     def input_shape(self):
-        return (4, 16, 16)
+        return (4, 8, 8)
 
     @property
     def output_shape(self):
-        return (8, 16, 16)
+        return (8, 8, 8)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "sample_size": 16,
+            "sample_size": 8,
             "patch_size": 2,
             "in_channels": 4,
-            "num_layers": 1,
+            "num_layers": 2,
             "attention_head_dim": 8,
-            "num_attention_heads": 3,
-            "cross_attention_dim": 32,
-            "cross_attention_dim_t5": 32,
-            "pooled_projection_dim": 16,
-            "hidden_size": 24,
+            "num_attention_heads": 2,
+            "cross_attention_dim": 8,
+            "cross_attention_dim_t5": 8,
+            "pooled_projection_dim": 4,
+            "hidden_size": 16,
+            "text_len": 4,
+            "text_len_t5": 4,
             "activation_fn": "gelu-approximate",
         }
         inputs_dict = self.dummy_input

From fbc52662d8525ce6fcd3c62d6c3bca0017b6c471 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Mon, 29 Jul 2024 22:01:35 +0700
Subject: [PATCH 4/5] reduce dims further

---
 .../models/transformers/test_models_transformer_hunyuan_dit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
index c15858e570f4..09ec047aa83b 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
@@ -96,7 +96,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "sample_size": 8,
             "patch_size": 2,
             "in_channels": 4,
-            "num_layers": 2,
+            "num_layers": 1,
             "attention_head_dim": 8,
             "num_attention_heads": 2,
             "cross_attention_dim": 8,

From 7b0ceade80888b445fefc64175c1e8498ca8c6c7 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Tue, 30 Jul 2024 09:54:04 +0700
Subject: [PATCH 5/5] run make style

---
 .../test_models_transformer_hunyuan_dit.py | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
index 09ec047aa83b..ea05abed38d9 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
@@ -22,17 +22,6 @@
     enable_full_determinism,
     torch_device,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor,
-    AttnProcessor2_0,
-    HunyuanAttnProcessor2_0,
-    XFormersAttnProcessor,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
-    require_torch_gpu,
-    torch_device,
-)
 
 from ..test_modeling_common import ModelTesterMixin
 
@@ -68,7 +57,7 @@ def dummy_input(self):
         style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
         image_rotary_emb = [
             torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype),
-            torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype)
+            torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype),
         ]
 
         return {
@@ -121,4 +110,4 @@ def test_set_xformers_attn_processor_for_determinism(self):
 
     @unittest.skip("HunyuanDiT uses a custom processor HunyuanAttnProcessor2_0")
     def test_set_attn_processor_for_determinism(self):
-        pass
\ No newline at end of file
+        pass
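
Note on the end state of the series, with a minimal usage sketch. The snippet below is not part of the patches; it simply mirrors what the final prepare_init_args_and_inputs_for_common() and dummy_input build, and runs one forward pass on CPU. It assumes a diffusers release that ships HunyuanDiT2DModel with these constructor arguments (the same ones the test passes) and whose forward call returns a Transformer2DModelOutput carrying the prediction in .sample.

import torch

from diffusers import HunyuanDiT2DModel

# Tiny config, identical to the test's final init_dict.
model = HunyuanDiT2DModel(
    sample_size=8,
    patch_size=2,
    in_channels=4,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    cross_attention_dim=8,
    cross_attention_dim_t5=8,
    pooled_projection_dim=4,
    hidden_size=16,
    text_len=4,
    text_len_t5=4,
    activation_fn="gelu-approximate",
).eval()

batch_size = 2
dtype = torch.float32
inputs = {
    # Latents: (batch, in_channels, sample_size, sample_size).
    "hidden_states": torch.randn(batch_size, 4, 8, 8, dtype=dtype),
    # CLIP-branch text states and mask; sequence length matches text_len=4.
    "encoder_hidden_states": torch.randn(batch_size, 4, 8, dtype=dtype),
    "text_embedding_mask": torch.ones(batch_size, 4, dtype=dtype),
    # T5-branch text states and mask; sequence length matches text_len_t5=4.
    "encoder_hidden_states_t5": torch.randn(batch_size, 4, 8, dtype=dtype),
    "text_embedding_mask_t5": torch.ones(batch_size, 4, dtype=dtype),
    "timestep": torch.randint(0, 1000, (batch_size,), dtype=dtype),
    # original_size + target_size + crops_coords_top_left, one row per sample.
    "image_meta_size": torch.tensor([[1024, 1024, 16, 16, 0, 0]] * batch_size, dtype=dtype),
    "style": torch.zeros(batch_size, dtype=torch.long),
    # (cos, sin) rotary tables sized to attention_head_dim=8.
    "image_rotary_emb": [torch.ones(1, 8, dtype=dtype), torch.zeros(1, 8, dtype=dtype)],
}

with torch.no_grad():
    out = model(**inputs)

# learn_sigma defaults to True, so the model predicts 2 * in_channels maps,
# which is why the test's output_shape is (8, 8, 8) while input_shape is (4, 8, 8).
print(out.sample.shape)  # torch.Size([2, 8, 8, 8])

The test module itself runs with: pytest tests/models/transformers/test_models_transformer_hunyuan_dit.py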