Commit f5ccffe

Use accelerate save & loading hooks to have better checkpoint structure (#2048)
* better accelerated saving
* up
* finish
* finish
* uP
* up
* up
* fix
* Apply suggestions from code review
* correct ema
* Remove @
* up
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* Update docs/source/en/training/dreambooth.mdx

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent e619db2 commit f5ccffe

File tree — 5 files changed (+189 −15 lines changed):

- docs/source/en/training/dreambooth.mdx
- examples/dreambooth/train_dreambooth.py
- examples/text_to_image/train_text_to_image.py
- examples/unconditional_image_generation/train_unconditional.py
- src/diffusers/training_utils.py

docs/source/en/training/dreambooth.mdx

Lines changed: 29 additions & 2 deletions
````diff
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wish.
 
 Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
 
-You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+
+unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")
+
+# if you have trained with `--train_text_encoder` make sure to also load the text encoder
+text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+If you have installed `"accelerate<0.16.0"` you need to first convert it to an inference pipeline. This is how you could do it:
 
 ```python
 from accelerate import Accelerator
@@ -271,6 +294,10 @@ accelerate launch train_dreambooth.py \
 
 Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).
 
+**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
+
 ```python
 from diffusers import StableDiffusionPipeline
 import torch
@@ -284,4 +311,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 image.save("dog-bucket.png")
 ```
 
-You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
+You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
````
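The `accelerate<0.16.0` conversion snippet is truncated in this hunk. As a rough, hypothetical sketch of that older flow (not the verbatim docs text), the idea is to restore the flat accelerator checkpoint into a freshly loaded UNet and then save a regular pipeline; the paths below are examples only:

```python
from accelerate import Accelerator
from diffusers import DiffusionPipeline, UNet2DConditionModel

# Hypothetical reconstruction of the pre-0.16.0 flow
accelerator = Accelerator()
model_id = "CompVis/stable-diffusion-v1-4"

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet = accelerator.prepare(unet)

# Restore the flat `accelerate` checkpoint into the prepared model
accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")

# Rebuild and save a normal inference pipeline
pipeline = DiffusionPipeline.from_pretrained(model_id, unet=accelerator.unwrap_model(unet))
pipeline.save_pretrained("dreambooth-pipeline")
```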

examples/dreambooth/train_dreambooth.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
+import accelerate
 import diffusers
 import transformers
 from accelerate import Accelerator
@@ -38,6 +39,7 @@
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -606,6 +608,37 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            for model in models:
+                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
+                model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            while len(models) > 0:
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                if type(model) == type(text_encoder):
+                    # load transformers style into model
+                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+                    model.config = load_model.config
+                else:
+                    # load diffusers style into model
+                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                    model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
```
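The "custom saving & loading hooks" comment is the heart of this change: `save_state` and `load_state` now delegate per-model serialization to user hooks. A minimal, self-contained sketch of the same `accelerate>=0.16.0` hook API, with a toy `torch.nn.Linear` standing in for the UNet and a hypothetical `toy-checkpoint` directory:

```python
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)


def save_model_hook(models, weights, output_dir):
    # runs inside `accelerator.save_state(...)` before the default serialization
    for m in models:
        torch.save(m.state_dict(), os.path.join(output_dir, "toy_model.bin"))
        # pop the weights so accelerate does not also write pytorch_model.bin
        weights.pop()


def load_model_hook(models, input_dir):
    # runs inside `accelerator.load_state(...)` before the default loading
    while len(models) > 0:
        m = models.pop()  # pop so accelerate does not try to load it again
        m.load_state_dict(torch.load(os.path.join(input_dir, "toy_model.bin")))


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

model = accelerator.prepare(model)
accelerator.save_state("toy-checkpoint")  # writes toy-checkpoint/toy_model.bin
accelerator.load_state("toy-checkpoint")  # restored through load_model_hook
```

Optimizer, scheduler, and RNG state still go through the default `save_state`/`load_state` machinery; only the popped models are handled by the hooks.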

examples/text_to_image/train_text_to_image.py

Lines changed: 48 additions & 4 deletions
```diff
@@ -26,6 +26,7 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
+import accelerate
 import datasets
 import diffusers
 import transformers
@@ -36,9 +37,10 @@
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, deprecate
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -319,6 +321,16 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
 
 def main():
     args = parse_args()
+
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
 
     accelerator = Accelerator(
@@ -396,6 +408,39 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+                ema_unet.load_state_dict(load_model.state_dict())
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -552,8 +597,9 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
+
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_unet)
+        ema_unet.to(accelerator.device)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -566,8 +612,6 @@ def collate_fn(examples):
     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
-    if args.use_ema:
-        ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
```
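Because the EMA weights are now written with `EMAModel.save_pretrained` to a `unet_ema` subfolder, they can be restored directly from a checkpoint. A short sketch, assuming a hypothetical `checkpoint-500` directory produced by the hooks above:

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# "checkpoint-500/unet_ema" is an example path; the hook chose the subfolder name
ema_unet = EMAModel.from_pretrained("checkpoint-500/unet_ema", UNet2DConditionModel)

# Load the raw training weights, then overwrite them with the averaged copies
unet = UNet2DConditionModel.from_pretrained("checkpoint-500", subfolder="unet")
ema_unet.copy_to(unet.parameters())
```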

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -9,6 +9,7 @@
 import torch
 import torch.nn.functional as F
 
+import accelerate
 import datasets
 import diffusers
 from accelerate import Accelerator
@@ -19,6 +20,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 
@@ -271,6 +273,40 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+                ema_model.load_state_dict(load_model.state_dict())
+                ema_model.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -336,6 +372,8 @@ def main(args):
         use_ema_warmup=True,
         inv_gamma=args.ema_inv_gamma,
         power=args.ema_power,
+        model_cls=UNet2DModel,
+        model_config=model.config,
     )
 
     # Initialize the scheduler
@@ -411,7 +449,6 @@ def transform_images(examples):
     )
 
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_model)
         ema_model.to(accelerator.device)
 
     # We need to initialize the trackers we use, and also store our configuration.
```
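The new `model_cls`/`model_config` arguments passed to `EMAModel` here are what make `ema_model.save_pretrained(...)` in the save hook legal. A minimal sketch of the wiring, using a default-constructed `UNet2DModel` as a stand-in for a real training model:

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel()  # stand-in; a real script configures or loads one

# model_cls/model_config let the EMA wrapper rebuild a model of the right
# class and shape when it serializes its averaged weights
ema_model = EMAModel(
    model.parameters(),
    use_ema_warmup=True,
    model_cls=UNet2DModel,
    model_config=model.config,
)
ema_model.save_pretrained("unet_ema")  # raises if model_cls/model_config unset
```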

src/diffusers/training_utils.py

Lines changed: 41 additions & 8 deletions
```diff
@@ -1,7 +1,7 @@
 import copy
 import os
 import random
-from typing import Iterable, Union
+from typing import Any, Dict, Iterable, Optional, Union
 
 import numpy as np
 import torch
@@ -57,6 +57,8 @@ def __init__(
         use_ema_warmup: bool = False,
         inv_gamma: Union[float, int] = 1.0,
         power: Union[float, int] = 2 / 3,
+        model_cls: Optional[Any] = None,
+        model_config: Dict[str, Any] = None,
         **kwargs,
     ):
         """
@@ -123,6 +125,35 @@ def __init__(
         self.power = power
         self.optimization_step = 0
 
+        self.model_cls = model_cls
+        self.model_config = model_config
+
+    @classmethod
+    def from_pretrained(cls, path, model_cls) -> "EMAModel":
+        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        model = model_cls.from_pretrained(path)
+
+        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+
+        ema_model.load_state_dict(ema_kwargs)
+        return ema_model
+
+    def save_pretrained(self, path):
+        if self.model_cls is None:
+            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
+
+        if self.model_config is None:
+            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
+
+        model = self.model_cls.from_config(self.model_config)
+        state_dict = self.state_dict()
+        state_dict.pop("shadow_params", None)
+        state_dict.pop("collected_params", None)
+
+        model.register_to_config(**state_dict)
+        self.copy_to(model.parameters())
+        model.save_pretrained(path)
+
     def get_decay(self, optimization_step: int) -> float:
         """
         Compute the decay factor for the exponential moving average.
@@ -184,7 +215,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         parameters = list(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
+            param.data.copy_(s_param.to(param.device).data)
 
     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
@@ -257,13 +288,15 @@ def load_state_dict(self, state_dict: dict) -> None:
         if not isinstance(self.power, (float, int)):
             raise ValueError("Invalid power")
 
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
+        shadow_params = state_dict.get("shadow_params", None)
+        if shadow_params is not None:
+            self.shadow_params = shadow_params
+            if not isinstance(self.shadow_params, list):
+                raise ValueError("shadow_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+                raise ValueError("shadow_params must all be Tensors")
 
-        self.collected_params = state_dict["collected_params"]
+        self.collected_params = state_dict.get("collected_params", None)
         if self.collected_params is not None:
             if not isinstance(self.collected_params, list):
                 raise ValueError("collected_params must be a list")
```