Use accelerate save & loading hooks to have better checkpoint structure #2048

Merged: 17 commits merged into main from better_accelerated_saving on Feb 7, 2023

Conversation

@patrickvonplaten (Contributor) commented Jan 20, 2023

This PR is a showcase of how the accelerate saving & loading hooks (see PR here) could be used to give intermediate checkpoints in diffusers a better structure.

It can be adapted and merged once 991 in accelerate gets merged.

This is ready for review.

@patrickvonplaten patrickvonplaten changed the title better accelerated saving Use accelerate save & loading hooks to have better checkpoint structure Jan 20, 2023
@patrickvonplaten patrickvonplaten changed the title Use accelerate save & loading hooks to have better checkpoint structure [Don't merge] Use accelerate save & loading hooks to have better checkpoint structure Jan 20, 2023
@HuggingFaceDocBuilderDev commented Jan 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor, Author)

@pcuenca @patil-suraj could you take a look here and give some feedback? Doing the saving & loading this way would help quite a bit to make this easier:
https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint
(no need to use accelerate to load the checkpoint anymore). Instead:

from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

unet = UNet2DConditionModel.from_pretrained(...)
text_encoder = CLIPTextModel.from_pretrained(...)

pipe = DiffusionPipeline.from_pretrained(..., unet=unet, text_encoder=text_encoder)
...

@patil-suraj (Contributor)

Looks cool, will try it out !

@patrickvonplaten (Contributor, Author)

Can now be used with accelerate main :-)

@patrickvonplaten (Contributor, Author)

Will open a PR tomorrow to finish it.

@patil-suraj (Contributor)

Just tried it out, works like a charm!

@patrickvonplaten patrickvonplaten changed the title [Don't merge] Use accelerate save & loading hooks to have better checkpoint structure [WIP] Use accelerate save & loading hooks to have better checkpoint structure Jan 27, 2023
@patrickvonplaten patrickvonplaten changed the title [WIP] Use accelerate save & loading hooks to have better checkpoint structure Use accelerate save & loading hooks to have better checkpoint structure Jan 30, 2023
@patrickvonplaten (Contributor, Author)

@patil-suraj @pcuenca this is ready for a review :-)

@patil-suraj (Contributor) left a comment

That is very cool, I tried it out last week, and it works well. I just left a comment about EMA weight checkpoints.

Comment on lines 413 to 433
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    for i, model in enumerate(models):
        model.save_pretrained(os.path.join(output_dir, "unet"))

        # make sure to pop weight so that corresponding model is not saved again
        weights.pop()

def load_model_hook(models, input_dir):
    for i in range(len(models)):
        # pop models so that they are not loaded again
        model = models.pop()

        # load diffusers style into model
        load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
Contributor

Is it also possible to extend this for ema_unet here? It's registered for checkpointing using register_for_checkpointing and will be saved as custom_checkpoints0.bin. For text_to_image, we need to load the EMA weights for inference.
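
For reference, a minimal sketch of what "registered for checkpointing" refers to; the model id and the EMAModel constructor arguments are illustrative, not the exact script code:

from accelerate import Accelerator
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

accelerator = Accelerator()
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
ema_unet = EMAModel(unet.parameters())

# With only this registration, `accelerator.save_state(...)` serializes the EMA state
# as a generic custom-checkpoint file (the custom_checkpoints0.bin mentioned above)
# rather than as a diffusers-style model folder.
accelerator.register_for_checkpointing(ema_unet)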

Member

That's a very good point. I tested both train_text_to_image and train_unconditional and I saw that:

  • The ema model is not sent to the save_model_hook function despite being registered for checkpointing. I suppose the hook only deals with weights being trained.
  • In train_unconditional no custom_checkpoints0.bin is saved. I don't know the reason why.

If we need to save ema weights then we might need to save the model ourselves.

Contributor

I think it's not passed to save_model_hook because it's not registered as an nn.Module with accelerator.prepare. We could instead make the EMAModel an instance of nn.Module and pass it to prepare so it'll work with save_model_hook.

Member

Oh, you are right! Do you think it's better to make it an nn.Module or handle it as a special case in the training scripts?

Contributor (Author)

Nice catch @patil-suraj !

I would be in favor of making the EMAModel a torch.nn.Module so that it's recognized automatically by accelerate here: https://github.com/huggingface/accelerate/blob/d0df263b0917773935110411a347cc3628213342/src/accelerate/checkpointing.py#L138

The EMA model class already defines very nice save & load state-dict functions here:

def state_dict(self) -> dict:
which include all the important parameters.

Then I'd need to slightly adapt the hooks to not pop the EMAModel and let accelerate handle the saving and loading.

I think that's a nice solution because:

  • The EMA model has no config of its own.
  • IMO it's not strictly necessary to have a nice folder structure for the EMA model during training, as the "non-EMA" UNet should be checked for training progress instead. I'd argue it's easier to see directly with the non-EMA model whether training is overfitting or not.

Wdyt?

Contributor

Yeah I agree, no harm in making EMAModel an instance of nn.Module.

I'd argue it's easier to see directly with the non-EMA model whether training is overfitting or not.

That's true, but if we are doing generation during training, ideally we would want to use the ema model, as that's the final model we use for actual inference.
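
For illustration, here is one hedged way the save hook could also give the EMA weights a diffusers-style folder. This is a sketch only, assuming `unet`, `ema_unet`, and `accelerator` from the training script; it is not necessarily what this PR ends up merging:

import copy
import os

def save_model_hook(models, weights, output_dir):
    # save the trained UNet(s) in diffusers format
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # pop the weights so accelerate does not serialize the same model again
        weights.pop()

    # copy the EMA parameters onto a throwaway clone of the UNet and save that too
    ema_copy = copy.deepcopy(accelerator.unwrap_model(unet))
    ema_unet.copy_to(ema_copy.parameters())
    ema_copy.save_pretrained(os.path.join(output_dir, "unet_ema"))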

Comment on lines 271 to 293
if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        for i, model in enumerate(models):
            model.save_pretrained(os.path.join(output_dir, "unet"))

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

    def load_model_hook(models, input_dir):
        for i in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
Contributor

same comment about the ema model as above.

@pcuenca (Member) left a comment

Looks great, I'll test it out for real.

@@ -609,6 +611,37 @@ def main(args):
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # `accelerate` 0.15.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
Member

Suggested change:
- if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
+ if version.parse(accelerate.__version__) >= version.parse("0.15.0"):

nit: I think it's enough to refer to the official version (this check will succeed and it's maybe less intimidating :)

Member

Actually, this doesn't work properly with either 0.15.0 or 0.15.0.dev0.

The current pip version is 0.15.0, which is >= "0.15.0.dev0" but doesn't have the hook registration functions, so the script will raise an exception.
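
As a purely illustrative alternative (not necessarily the fix adopted in this PR), the guard could test for the new API directly instead of comparing version strings:

# register the hooks only if this accelerate version actually exposes them
if hasattr(accelerator, "register_save_state_pre_hook"):
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)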

@pcuenca (Member) left a comment

Another comment: this method will be incompatible with state restoration from previous training runs. Maybe we should create a nice error message explaining it?
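
Such a message could look roughly like this (a sketch only; `checkpoint_path` stands in for whatever path the script resumes from):

try:
    accelerator.load_state(checkpoint_path)
except Exception as err:
    raise RuntimeError(
        "Failed to load this checkpoint. Checkpoints written before the new "
        "save/load hooks use a different layout and cannot be restored by this "
        "version of the script."
    ) from err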

@@ -566,8 +612,6 @@ def collate_fn(examples):
    # Move text_encode and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
Contributor (Author)

done above

@patrickvonplaten (Contributor, Author)

Update:

As discussed with @pcuenca and @patil-suraj, this PR adds a new file structure for EMA training, which now looks as follows for a checkpoint directory:

.
├── optimizer.bin
├── random_states_0.pkl
├── scheduler.bin
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.bin
└── unet_ema
    ├── config.json
    └── diffusion_pytorch_model.bin

2 directories, 7 files

The EMAModel config params are saved inside the UNet's config.json, which allows both correctly reloading the EMAModel and loading the diffusion model via UNet.from_pretrained(...):

{
  "_class_name": "UNet2DModel",
  "_diffusers_version": "0.13.0.dev0",
  "act_fn": "silu",
  "add_attention": true,
  "attention_head_dim": 8,
  "block_out_channels": [
    128,
    128,
    256,
    256,
    512,
    512
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "decay": 0.9999,
  "down_block_types": [
    "DownBlock2D",
    "DownBlock2D",
    "DownBlock2D",
    "DownBlock2D",
    "AttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 3,
  "inv_gamma": 1.0,
  "layers_per_block": 2,
  "mid_block_scale_factor": 1,
  "min_decay": 0.9999,
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_class_embeds": null,
  "optimization_step": 100,
  "out_channels": 3,
  "power": 0.75,
  "resnet_time_scale_shift": "default",
  "sample_size": 64,
  "time_embedding_type": "positional",
  "up_block_types": [
    "UpBlock2D",
    "AttnUpBlock2D",
    "UpBlock2D",
    "UpBlock2D",
    "UpBlock2D",
    "UpBlock2D"
  ],
  "update_after_step": 0,
  "use_ema_warmup": true
}
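
For example, both UNets can then be reloaded in diffusers style from such a checkpoint directory (a minimal sketch; "checkpoint-100" is just a placeholder for a checkpoint directory saved with this layout, and the extra EMA entries in config.json are handled as described above):

from diffusers import UNet2DModel

# load the trained model and its EMA counterpart from the checkpoint directory
unet = UNet2DModel.from_pretrained("checkpoint-100", subfolder="unet")
unet_ema = UNet2DModel.from_pretrained("checkpoint-100", subfolder="unet_ema")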

@pcuenca (Member) left a comment

Awesome, looks great! I'll try to test again later.

@patrickvonplaten patrickvonplaten merged commit f5ccffe into main Feb 7, 2023
@patil-suraj patil-suraj deleted the better_accelerated_saving branch February 8, 2023 08:20
yiyixuxu pushed a commit to evinpinar/diffusers-attend-and-excite-pipeline that referenced this pull request Feb 16, 2023
…ture (huggingface#2048)

* better accelerated saving

* up

* finish

* finish

* uP

* up

* up

* fix

* Apply suggestions from code review

* correct ema

* Remove @

* up

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* Update docs/source/en/training/dreambooth.mdx

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Co-authored-by: Pedro Cuenca <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024