Use accelerate save & loading hooks to have better checkpoint structure #2048
Conversation
The documentation is not available anymore as the PR was closed or merged.
@pcuenca @patil-suraj could you take a look here and give some feedback? Doing this saving & loading would help quite a bit to make this possible:

```python
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

unet = UNet2DConditionModel.from_pretrained(...)
text_encoder = CLIPTextModel.from_pretrained(...)

pipe = DiffusionPipeline.from_pretrained(..., unet=unet, text_encoder=text_encoder)
...
```
Looks cool, will try it out!

Can now be used with

Will open a PR tomorrow to finish it.

Just tried it out, works like a charm!
@patil-suraj @pcuenca this is ready for a review :-)
That is very cool, I tried it out last week, and it works well. I just left a comment about EMA weight checkpoints.
```python
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    for i, model in enumerate(models):
        model.save_pretrained(os.path.join(output_dir, "unet"))

        # make sure to pop weight so that corresponding model is not saved again
        weights.pop()

def load_model_hook(models, input_dir):
    for i in range(len(models)):
        # pop models so that they are not loaded again
        model = models.pop()

        # load diffusers style into model
        load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
```
Is it also possible to extend this for ema_unet here? It's registered for checkpointing using register_for_checkpointing and will be saved as custom_checkpoints0.bin. For text_to_image, we need to load the EMA weights for inference.
That's a very good point. I tested both train_text_to_image and train_unconditional and I saw that:
- The ema model is not sent to the save_model_hook function despite being registered for checkpointing. I suppose the hook only deals with weights being trained.
- In train_unconditional no custom_checkpoints0.bin is saved. I don't know the reason why.

If we need to save ema weights then we might need to save the model ourselves.
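For illustration, a minimal sketch of saving the EMA weights manually inside the hook; args.use_ema, ema_unet, and the unet_ema.bin filename are assumptions taken from the training-script context, not code from this PR:

```python
import os

import torch


def save_model_hook(models, weights, output_dir):
    # manually serialize the EMA weights, since they are not passed to the hook
    if args.use_ema:
        torch.save(ema_unet.state_dict(), os.path.join(output_dir, "unet_ema.bin"))

    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # make sure to pop the weights so that the corresponding model is not saved again
        weights.pop()
```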
I think it's not passed to save_model_hook because it's not registered as an nn.Module with accelerator.prepare. We could instead make EMAModel an instance of nn.Module and pass it to prepare so it'll work with save_model_hook.
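A rough sketch of that idea; the constructor signature, the shadow_params name, and the decay default here are illustrative assumptions, not the actual EMAModel API:

```python
import torch


class EMAModel(torch.nn.Module):
    # Subclassing nn.Module means the EMA weights are picked up by accelerate's
    # checkpointing just like any other model passed to `accelerator.prepare(...)`.
    def __init__(self, parameters, decay=0.9999):
        super().__init__()
        self.decay = decay
        # keep the EMA copies as frozen Parameters so they show up in `state_dict()`
        self.shadow_params = torch.nn.ParameterList(
            [torch.nn.Parameter(p.detach().clone(), requires_grad=False) for p in parameters]
        )

    @torch.no_grad()
    def step(self, parameters):
        # standard EMA update: shadow = decay * shadow + (1 - decay) * param
        for s_param, param in zip(self.shadow_params, parameters):
            s_param.mul_(self.decay).add_(param, alpha=1.0 - self.decay)
```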
Oh, you are right! Do you think it's better to make it an nn.Module or handle it as a special case in the training scripts?
Nice catch @patil-suraj!
I would be in favor of making the EMAModel a torch.nn.Module so that it's recognized automatically by accelerate here: https://github.com/huggingface/accelerate/blob/d0df263b0917773935110411a347cc3628213342/src/accelerate/checkpointing.py#L138
The EMA model class already defines very nice save & load state dict functions in diffusers/src/diffusers/training_utils.py (line 201 at 87cf88e: def state_dict(self) -> dict:).
Then I'd slightly need to adapt the hooks to not pop the EMAModel and let accelerate handle the saving and loading. I think that's a nice solution because:
- the EMA model has no config
- IMO it's not strictly necessary to have a nice folder structure for the EMA model during training, as the "non-ema" UNet should be checked for training progress instead. I'd argue it's easier to directly see with the non-EMA model whether training is overfitting or not?

Wdyt?
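A hedged sketch of what such an adapted save hook could look like, assuming EMAModel has been made an nn.Module and is prepared together with the UNet (the final implementation in this PR may differ):

```python
import os

from diffusers.training_utils import EMAModel


def save_model_hook(models, weights, output_dir):
    # iterate in reverse so that `weights.pop(i)` stays aligned with `models[i]`
    for i in reversed(range(len(models))):
        if isinstance(models[i], EMAModel):
            # do not pop: leave the EMA entry in place so accelerate serializes it itself
            continue
        models[i].save_pretrained(os.path.join(output_dir, "unet"))
        # pop this model's weights so accelerate does not save them a second time
        weights.pop(i)
```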
Yeah I agree, no harm in making EMAModel an instance of nn.Module.

> I'd argue it's easier to directly see with the non-EMA model whether training is overfitting or not?

That's true, but if we are doing generation during training, ideally we would want to use the EMA model, as that's the final model we use for actual inference.
```python
if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        for i, model in enumerate(models):
            model.save_pretrained(os.path.join(output_dir, "unet"))

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

    def load_model_hook(models, input_dir):
        for i in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
```
same comment about the ema model as above.
Looks great, I'll test it out for real.
```
@@ -609,6 +611,37 @@ def main(args):
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# `accelerate` 0.15.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
```
Suggested change:

```
- if version.parse(accelerate.__version__) >= version.parse("0.15.0.dev0"):
+ if version.parse(accelerate.__version__) >= version.parse("0.15.0"):
```

nit: I think it's enough to refer to the official version (this check will succeed and it's maybe less intimidating :)
Actually, this doesn't work properly with either 0.15.0 or 0.15.0.dev0. The current pip version is 0.15.0, which is >= "0.15.0.dev0" but doesn't have the hook registration functions, so the script will raise an exception.
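One possible workaround (an editorial suggestion, not necessarily what this PR adopts) is to feature-detect the hook methods instead of comparing version strings; accelerator, save_model_hook, and load_model_hook are assumed to be defined as in the script above:

```python
# only register the custom hooks if the running accelerate actually exposes them,
# regardless of what its version string claims
if hasattr(accelerator, "register_save_state_pre_hook") and hasattr(
    accelerator, "register_load_state_pre_hook"
):
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
```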
Another comment: this method will be incompatible with state restoration from previous training runs. Maybe we should create a nice error message explaining it?
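A hedged sketch of what such an error message could look like inside load_model_hook; the folder name and wording are only illustrative:

```python
import os

# inside `load_model_hook`, before loading, guard against the old flat checkpoint layout
if not os.path.isdir(os.path.join(input_dir, "unet")):
    raise ValueError(
        f"Could not find a 'unet' subfolder in {input_dir}. Checkpoints saved before this "
        "change use the previous flat layout and cannot be restored with the new loading "
        "hooks; please resume from a checkpoint created with the updated script."
    )
```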
```
@@ -566,8 +612,6 @@ def collate_fn(examples):
    # Move text_encode and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
```
done above
Update: As discussed with @pcuenca and @patil-suraj, this PR adds a new file structure for EMA training, which now looks as follows for a checkpoint directory:

The EMAModel config params are saved inside the UNet's config.json, which allows both correctly reloading the EMAModel and loading the diffusion model via
Awesome, looks great! I'll try to test again later.
…ture (huggingface#2048) * better accelerated saving * up * finish * finish * uP * up * up * fix * Apply suggestions from code review * correct ema * Remove @ * up * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Update docs/source/en/training/dreambooth.mdx Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Pedro Cuenca <[email protected]>
This PR is a showcase of how accelerate saving & loading hooks (see PR here) could be used to make sure intermediate checkpointing in diffusers has a nicer, more usable structure.

Can be adapted and merged once 991 in accelerate gets merged.

This is ready for review.
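For context, a minimal usage sketch of how the registered hooks come into play; the paths are illustrative, and accelerator plus the hooks are assumed to be set up as in the snippets above:

```python
# register the hooks once after creating the Accelerator
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# saving a training state now also writes a diffusers-style subfolder, e.g. checkpoint-500/unet/
accelerator.save_state("checkpoint-500")

# resuming routes through `load_model_hook`, which restores the UNet from that subfolder
accelerator.load_state("checkpoint-500")
```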