diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index 4b388d92a195..387755624729 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -284,9 +284,9 @@ TPU_TYPE=v4-8
 VM_NAME=hg_flax
 
 gcloud alpha compute tpus tpu-vm create $VM_NAME \
-    --zone $ZONE \
-    --accelerator-type $TPU_TYPE \
-    --version tpu-vm-v4-base
+  --zone $ZONE \
+  --accelerator-type $TPU_TYPE \
+  --version tpu-vm-v4-base
 
 gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
 ```
@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` now
 pip install wandb
 ```
 
+
 Now let's download two conditioning images that we will use to run validation during training in order to track our progress:
 
 ```
@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set
 
 ```bash
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
-export OUTPUT_DIR="control_out"
-export HUB_MODEL_ID="fill-circle-controlnet"
+export OUTPUT_DIR="runs/fill-circle-{timestamp}"
+export HUB_MODEL_ID="controlnet-fill-circle"
 ```
 
 And finally start the training
@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \
  --revision="non-ema" \
  --from_pt \
  --report_to="wandb" \
- --max_train_steps=10000 \
+ --tracker_project_name=$HUB_MODEL_ID \
+ --num_train_epochs=11 \
  --push_to_hub \
  --hub_model_id=$HUB_MODEL_ID
 ```
 
 Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your Hugging Face account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the Hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
 
-Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
 
 ```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
+export HUB_MODEL_ID="controlnet-uncanny-faces"
+
 python3 train_controlnet_flax.py \
-    --pretrained_model_name_or_path=$MODEL_DIR \
-    --output_dir=$OUTPUT_DIR \
-    --dataset_name=multimodalart/facesyntheticsspigacaptioned \
-    --streaming \
-    --conditioning_image_column=spiga_seg \
-    --image_column=image \
-    --caption_column=image_caption \
-    --resolution=512 \
-    --max_train_samples 50 \
-    --max_train_steps 5 \
-    --learning_rate=1e-5 \
-    --validation_steps=2 \
-    --train_batch_size=1 \
-    --revision="flax" \
-    --report_to="wandb"
+  --pretrained_model_name_or_path=$MODEL_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+  --streaming \
+  --conditioning_image_column=spiga_seg \
+  --image_column=image \
+  --caption_column=image_caption \
+  --resolution=512 \
+  --max_train_samples 100000 \
+  --learning_rate=1e-5 \
+  --train_batch_size=1 \
+  --revision="flax" \
+  --report_to="wandb" \
+  --tracker_project_name=$HUB_MODEL_ID
 ```
 
 Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images.
 When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
 
 ```bash
-    --checkpointing_steps=500
+  --checkpointing_steps=500
 ```
 
 This will save the trained model in subfolders of your `output_dir`. Each subfolder is named after the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named `500`.
 You can then start your training from this saved checkpoint with
 
 ```bash
-    --controlnet_model_name_or_path="./control_out/500"
+  --controlnet_model_name_or_path="./control_out/500"
 ```
 
 We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
 
-We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
+We also support gradient accumulation, a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. Use the `--gradient_accumulation_steps` argument to set the number of accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
+
+You can **profile your code** with:
+
+```bash
+  --profile_steps=5
+```
+
+Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start TensorBoard with the profile plugin:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+The profile can then be inspected at http://localhost:6006/#profile.
+
+Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of TensorFlow/TensorBoard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
+
+Note that the debugging functionality of the TensorBoard `profile` plugin is still under active development. Not all views are fully functional, and, for example, the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you, for example, profile the compilation step by accident).
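For a sense of what `--profile_steps` does under the hood, here is a minimal, self-contained sketch of the same JAX profiling pattern the script changes below rely on. The `train_step` function and the trace directory are hypothetical stand-ins; `jax.profiler.start_trace`, `jax.profiler.stop_trace`, and `jax.profiler.StepTraceAnnotation` are the actual JAX APIs:

```python
# Minimal profiling sketch (hypothetical train_step; real jax.profiler API).
import jax
import jax.numpy as jnp


@jax.jit
def train_step(x):  # stand-in for the pmapped ControlNet train step
    return jnp.sin(x) ** 2


x = jnp.ones((8, 128))
train_step(x).block_until_ready()  # compile once so the trace excludes compilation

jax.profiler.start_trace("runs/profile-demo")  # writes a TensorBoard-readable trace
for step in range(5):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
        x = train_step(x)
x.block_until_ready()  # JAX dispatch is async; sync before closing the trace
jax.profiler.stop_trace()
```

The resulting trace in `runs/profile-demo` can then be opened with the TensorBoard `profile` plugin as described above.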
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 224a50bb7fbe..bef3e49ed007 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import time
 from pathlib import Path
 
 import jax
@@ -220,6 +221,28 @@ def parse_args():
         default=None,
         help="Revision of controlnet model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--profile_steps",
+        type=int,
+        default=0,
+        help="How many training steps to profile in the beginning.",
+    )
+    parser.add_argument(
+        "--profile_validation",
+        action="store_true",
+        help="Whether to profile the (last) validation.",
+    )
+    parser.add_argument(
+        "--profile_memory",
+        action="store_true",
+        help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
+    )
+    parser.add_argument(
+        "--ccache",
+        type=str,
+        default=None,
+        help="Enables compilation cache.",
+    )
     parser.add_argument(
         "--controlnet_from_pt",
         action="store_true",
@@ -234,8 +257,9 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-model",
-        help="The output directory where the model predictions and checkpoints will be written.",
+        default="runs/{timestamp}",
+        help="The output directory where the model predictions and checkpoints will be written. "
+        "Can contain placeholders: {timestamp}.",
     )
     parser.add_argument(
         "--cache_dir",
@@ -317,15 +341,6 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
-    parser.add_argument(
-        "--logging_dir",
-        type=str,
-        default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
-    )
     parser.add_argument(
         "--logging_steps",
         type=int,
@@ -459,6 +474,8 @@ def parse_args():
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
     args = parser.parse_args()
+    args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -952,6 +969,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         metrics = {"loss": loss}
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
+        def l2(xs):
+            return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
+
+        metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
+
         return new_state, metrics, new_train_rng
 
     # Create parallel version of the train step
@@ -983,32 +1005,38 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
     logger.info(f"  Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
 
-    if jax.process_index() == 0:
+    if jax.process_index() == 0 and args.report_to == "wandb":
         wandb.define_metric("*", step_metric="train/step")
+        wandb.define_metric("train/step", step_metric="walltime")
         wandb.config.update(
             {
                 "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
+                "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
             }
         )
 
-    global_step = 0
+    global_step = step0 = 0
     epochs = tqdm(
         range(args.num_train_epochs),
         desc="Epoch ... ",
         position=0,
         disable=jax.process_index() > 0,
     )
+    if args.profile_memory:
+        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
+    t00 = t0 = time.monotonic()
     for epoch in epochs:
         # ======================== Training ================================
 
         train_metrics = []
+        train_metric = None
 
         steps_per_epoch = (
             args.max_train_samples // total_train_batch_size
-            if args.streaming
+            if args.streaming or args.max_train_samples
             else len(train_dataset) // total_train_batch_size
         )
         train_step_progress_bar = tqdm(
@@ -1020,10 +1048,18 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         )
         # train
         for batch in train_dataloader:
+            if args.profile_steps and global_step == 1:
+                train_metric["loss"].block_until_ready()
+                jax.profiler.start_trace(args.output_dir)
+            if args.profile_steps and global_step == 1 + args.profile_steps:
+                train_metric["loss"].block_until_ready()
+                jax.profiler.stop_trace()
+
             batch = shard(batch)
-            state, train_metric, train_rngs = p_train_step(
-                state, unet_params, text_encoder_params, vae_params, batch, train_rngs
-            )
+            with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
+                state, train_metric, train_rngs = p_train_step(
+                    state, unet_params, text_encoder_params, vae_params, batch, train_rngs
+                )
             train_metrics.append(train_metric)
 
             train_step_progress_bar.update(1)
@@ -1041,13 +1077,19 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
             if global_step % args.logging_steps == 0 and jax.process_index() == 0:
                 if args.report_to == "wandb":
+                    train_metrics = jax_utils.unreplicate(train_metrics)
+                    train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
                     wandb.log(
                         {
+                            "walltime": time.monotonic() - t00,
                             "train/step": global_step,
-                            "train/epoch": epoch,
-                            "train/loss": jax_utils.unreplicate(train_metric)["loss"],
+                            "train/epoch": global_step / dataset_length,
+                            "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
+                            **{f"train/{k}": v for k, v in train_metrics.items()},
                         }
                     )
+                    t0, step0 = time.monotonic(), global_step
+                    train_metrics = []
             if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
                 controlnet.save_pretrained(
                     f"{args.output_dir}/{global_step}",
@@ -1058,10 +1100,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         train_step_progress_bar.close()
         epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
 
-    # Create the pipeline using using the trained modules and save it.
+    # Final validation & store model.
     if jax.process_index() == 0:
         if args.validation_prompt is not None:
+            if args.profile_validation:
+                jax.profiler.start_trace(args.output_dir)
             image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+            if args.profile_validation:
+                jax.profiler.stop_trace()
         else:
             image_logs = None
 
@@ -1084,6 +1130,10 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
             ignore_patterns=["step_*", "epoch_*"],
         )
 
+    if args.profile_memory:
+        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
+    logger.info("Finished training.")
+
 
 if __name__ == "__main__":
     main()
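One addition the diff declares but this excerpt does not show being consumed is the `--ccache` flag. Below is a minimal sketch of how such a flag is typically wired up, assuming it points JAX's persistent compilation cache at a directory; `jax.experimental.compilation_cache` is the real module, while the helper function and call site are illustrative:

```python
# Hedged sketch: plausible wiring for the --ccache flag (not shown in this excerpt).
# initialize_cache() persists XLA compilations to disk so later runs can reuse them.
from jax.experimental.compilation_cache import compilation_cache as cc


def maybe_enable_compilation_cache(cache_dir):
    """Enable JAX's persistent compilation cache if a directory was given."""
    if cache_dir is not None:
        cc.initialize_cache(cache_dir)


# Illustrative call site, e.g. right after parse_args():
# maybe_enable_compilation_cache(args.ccache)
```

On TPU VMs this can avoid repeating the lengthy XLA compilation of the first training step after every restart.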