35 | 35 | from diffusers import VQModel
36 | 36 | from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
37 | 37 | from diffusers.pipelines import VQDiffusionPipeline
| 38 | +from transformers import CLIPTextModel, CLIPTokenizer
38 | 39 | from yaml.loader import FullLoader
39 | 40 |
40 | 41 |
@@ -819,9 +820,37 @@ def read_config_file(filename):
819 | 820 |
820 | 821 | # done transformer_model
821 | 822 |
| 823 | + # text encoder
| 824 | +
| 825 | + print("loading CLIP text encoder")
| 826 | +
| 827 | + clip_name = "openai/clip-vit-base-patch32"
| 828 | +
| 829 | + # The original VQ-Diffusion specifies the pad value by the int used in the
| 830 | + # returned tokens. Each model uses `0` as the pad value. The transformers clip api
| 831 | + # specifies the pad value via the token before it has been tokenized. The `!` pad
| 832 | + # token is the same as padding with the `0` pad value.
| 833 | + pad_token = "!"
| 834 | +
| 835 | + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
| 836 | +
| 837 | + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
| 838 | +
| 839 | + text_encoder_model = CLIPTextModel.from_pretrained(
| 840 | + clip_name,
| 841 | + # `CLIPTextModel` does not support device_map="auto"
| 842 | + # device_map="auto"
| 843 | + )
| 844 | +
| 845 | + print("done loading CLIP text encoder")
| 846 | +
| 847 | + # done text encoder
| 848 | +
822 | 849 | print(f"saving VQ diffusion model, path: {args.dump_path}")
823 | 850 |
824 | | - pipe = VQDiffusionPipeline(vqvae=vqvae_model, transformer=transformer_model)
| 851 | + pipe = VQDiffusionPipeline(
| 852 | + vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
| 853 | + )
825 | 854 | pipe.save_pretrained(args.dump_path)
826 | 855 |
827 | 856 | print("done writing VQ diffusion model")
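The pad-token comment in the hunk above relies on `!` being id 0 in the CLIP vocabulary, which is what the in-script `assert` double-checks. A minimal standalone sketch of that check, assuming the `openai/clip-vit-base-patch32` checkpoint named in the diff (the prompt string is only an example):

```python
# Standalone check of the pad-token assumption used by the conversion script:
# in the openai/clip-vit-base-patch32 vocabulary, "!" maps to token id 0, so
# padding with pad_token="!" reproduces VQ-Diffusion's integer pad value of 0.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!")
assert tokenizer.convert_tokens_to_ids("!") == 0

# Padded positions in an encoded prompt therefore come out as 0.
ids = tokenizer("a painting of a cat", padding="max_length", max_length=77).input_ids
assert ids[-1] == 0
```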
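With the new `tokenizer` and `text_encoder` arguments registered on the pipeline, `pipe.save_pretrained(args.dump_path)` follows the standard diffusers layout and writes each component to its own subfolder next to `model_index.json`, so the converted checkpoint round-trips as a single pipeline. A hedged sketch of reloading it; the path is a hypothetical stand-in for whatever `args.dump_path` was:

```python
# Reload the converted checkpoint written by the script's save_pretrained call.
# "./vq-diffusion-converted" is a hypothetical path, not one from the script.
from diffusers.pipelines import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-converted")

# The components passed to the constructor in the diff come back as attributes.
assert pipe.tokenizer.pad_token == "!"
print(type(pipe.vqvae).__name__, type(pipe.transformer).__name__, type(pipe.text_encoder).__name__)
```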