Skip to content

Commit e6fa703

Browse files
yiyixuxuandsteing
authored andcommitted
minor fix in controlnet flax example (huggingface#2986)
* fix the error when push_to_hub but not log validation * contronet_from_pt & controlnet_revision * add intermediate checkpointing to the guide
1 parent cdb3cc6 commit e6fa703

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

examples/controlnet/README.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,12 @@ Then cd in the example folder and run
320320
pip install -U -r requirements_flax.txt
321321
```
322322

323+
If you want to use Weights and Biases logging, you should also install `wandb` now
324+
325+
```bash
326+
pip install wandb
327+
```
328+
323329
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
324330

325331
```
@@ -389,4 +395,17 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
389395

390396
* [Webdataset](https://webdataset.github.io/webdataset/)
391397
* [TorchData](https://github.com/pytorch/data)
392-
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
398+
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
399+
400+
When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
401+
402+
```bash
403+
--checkpointing_steps=500
404+
```
405+
This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
406+
407+
You can then start your training from this saved checkpoint with
408+
409+
```bash
410+
--controlnet_model_name_or_path="./control_out/500"
411+
```

examples/controlnet/train_controlnet_flax.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,16 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
156156

157157
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
158158
img_str = ""
159-
for i, log in enumerate(image_logs):
160-
images = log["images"]
161-
validation_prompt = log["validation_prompt"]
162-
validation_image = log["validation_image"]
163-
validation_image.save(os.path.join(repo_folder, "image_control.png"))
164-
img_str += f"prompt: {validation_prompt}\n"
165-
images = [validation_image] + images
166-
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
167-
img_str += f"![images_{i})](./images_{i}.png)\n"
159+
if image_logs is not None:
160+
for i, log in enumerate(image_logs):
161+
images = log["images"]
162+
validation_prompt = log["validation_prompt"]
163+
validation_image = log["validation_image"]
164+
validation_image.save(os.path.join(repo_folder, "image_control.png"))
165+
img_str += f"prompt: {validation_prompt}\n"
166+
images = [validation_image] + images
167+
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
168+
img_str += f"![images_{i})](./images_{i}.png)\n"
168169

169170
yaml = f"""
170171
---
@@ -215,6 +216,17 @@ def parse_args():
215216
action="store_true",
216217
help="Load the pretrained model from a PyTorch checkpoint.",
217218
)
219+
parser.add_argument(
220+
"--controlnet_revision",
221+
type=str,
222+
default=None,
223+
help="Revision of controlnet model identifier from huggingface.co/models.",
224+
)
225+
parser.add_argument(
226+
"--controlnet_from_pt",
227+
action="store_true",
228+
help="Load the controlnet model from a PyTorch checkpoint.",
229+
)
218230
parser.add_argument(
219231
"--profile_steps",
220232
type=int,
@@ -751,7 +763,10 @@ def main():
751763
if args.controlnet_model_name_or_path:
752764
logger.info("Loading existing controlnet weights")
753765
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
754-
args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32
766+
args.controlnet_model_name_or_path,
767+
revision=args.controlnet_revision,
768+
from_pt=args.controlnet_from_pt,
769+
dtype=jnp.float32,
755770
)
756771
else:
757772
logger.info("Initializing controlnet weights from unet")
@@ -1066,6 +1081,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
10661081
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
10671082
if args.profile_validation:
10681083
jax.profiler.stop_trace()
1084+
else:
1085+
image_logs = None
10691086

10701087
controlnet.save_pretrained(
10711088
args.output_dir,

0 commit comments

Comments
 (0)