From aecfb8fe3fc36b5952ea4d9cc437c1055d3883d2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 5 Apr 2023 17:08:15 +0000 Subject: [PATCH 1/5] fix the error when push_to_hub but not log validation --- examples/controlnet/README.md | 6 ++++++ examples/controlnet/train_controlnet_flax.py | 21 +++++++++++--------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 4e6856560bde..c003cb43a430 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -320,6 +320,12 @@ Then cd in the example folder and run pip install -U -r requirements_flax.txt ``` +If you want to use Weights and Bias logging, you should also install `wandb` now + +```bash +pip install wandb +``` + Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress ``` diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 8d316fd048b9..c6df1031ebde 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -154,15 +154,16 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" - for i, log in enumerate(image_logs): - images = log["images"] - validation_prompt = log["validation_prompt"] - validation_image = log["validation_image"] - validation_image.save(os.path.join(repo_folder, "image_control.png")) - img_str += f"prompt: {validation_prompt}\n" - images = [validation_image] + images - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) - img_str += f"![images_{i})](./images_{i}.png)\n" + if image_logs is not None: + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + 
validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" yaml = f""" --- @@ -1021,6 +1022,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): if jax.process_index() == 0: if args.validation_prompt is not None: image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + else: + image_logs = None controlnet.save_pretrained( args.output_dir, From c4858e029ca24b64e6155889549fe0728a0bb3ea Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 5 Apr 2023 20:55:08 +0000 Subject: [PATCH 2/5] fix typo --- examples/controlnet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index c003cb43a430..03973f6ca7e6 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -320,7 +320,7 @@ Then cd in the example folder and run pip install -U -r requirements_flax.txt ``` -If you want to use Weights and Bias logging, you should also install `wandb` now +If you want to use Weights and Biases logging, you should also install `wandb` now ```bash pip install wandb From 3dbe56facbc269e322ec1f0746b36618b8cd36a1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 5 Apr 2023 20:55:36 +0000 Subject: [PATCH 3/5] contronet_from_pt & controlnet_revision --- examples/controlnet/train_controlnet_flax.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index c6df1031ebde..e9d767ce5e60 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -214,6 +214,17 @@ def parse_args(): action="store_true", help="Load the pretrained model from a PyTorch 
checkpoint.", ) + parser.add_argument( + "--controlnet_revision", + type=str, + default=None, + help="Revision of controlnet model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_from_pt", + action="store_true", + help="Load the controlnet model from a PyTorch checkpoint.", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -732,7 +743,10 @@ def main(): if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32 + args.controlnet_model_name_or_path, + revision=args.controlnet_revision, + from_pt=args.controlnet_from_pt, + dtype=jnp.float32, ) else: logger.info("Initializing controlnet weights from unet") From f2b417f6cb55ea56b208d4903577ed85d35d36c3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 5 Apr 2023 20:57:56 +0000 Subject: [PATCH 4/5] make style --- examples/controlnet/train_controlnet_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index e9d767ce5e60..292b665a8a42 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -743,9 +743,9 @@ def main(): if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - args.controlnet_model_name_or_path, + args.controlnet_model_name_or_path, revision=args.controlnet_revision, - from_pt=args.controlnet_from_pt, + from_pt=args.controlnet_from_pt, dtype=jnp.float32, ) else: From a46ae132021161e754317ba8a916ad102fd4b1ef Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 5 Apr 2023 21:14:10 +0000 Subject: [PATCH 5/5] add intermediate checkpointing to the guide --- examples/controlnet/README.md | 15 ++++++++++++++- 1 file changed, 14 
insertions(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 03973f6ca7e6..f3621ac61309 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -395,4 +395,17 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream * [Webdataset](https://webdataset.github.io/webdataset/) * [TorchData](https://github.com/pytorch/data) -* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) \ No newline at end of file +* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) + +When working with a larger dataset, you may need to run the training process for a long time, and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing: + +```bash + --checkpointing_steps=500 +``` +This will save the trained model in subfolders of your output_dir. Each subfolder is named after the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named 500. + +You can then start your training from this saved checkpoint with + +```bash + --controlnet_model_name_or_path="./control_out/500" +``` \ No newline at end of file