Adds profiling flags, computes train metrics average. #3053

Merged
26 commits merged on Apr 12, 2023
Commits
81ac95b
WIP controlnet training
andsteing Apr 6, 2023
cd7c0d2
Adds final logging statement.
andsteing Apr 7, 2023
c15af70
Sets train epochs to 11.
andsteing Apr 8, 2023
f338696
Removes --logging_dir (it's not used).
andsteing Apr 8, 2023
26ceff2
Adds --profile flags.
andsteing Apr 8, 2023
56d777c
Updates --output_dir=runs/fill-circle-{timestamp}.
andsteing Apr 8, 2023
34a63f5
Compute mean of `train_metrics`.
andsteing Apr 8, 2023
f02e482
Improves logging a bit.
andsteing Apr 8, 2023
f7a2f28
Adds --ccache (doesn't really help though).
andsteing Apr 8, 2023
95d5b18
minor fix in controlnet flax example (#2986)
yiyixuxu Apr 6, 2023
0874a9d
Bugfix --profile_steps
andsteing Apr 8, 2023
1a74f34
Sets `TRACKER_PROJECT_NAME='controlnet_fill50k'`.
andsteing Apr 8, 2023
7f9a3c3
Logs fractional epoch.
andsteing Apr 8, 2023
4d20a04
Adds relative `walltime` metric.
andsteing Apr 8, 2023
d1f5994
Adds `StepTraceAnnotation` and uses `global_step` instead of `step`.
andsteing Apr 8, 2023
4d83f6f
Applied `black`.
andsteing Apr 11, 2023
4fed793
Streamlines commands in README a bit.
andsteing Apr 11, 2023
92b7bef
Removes `--ccache`.
andsteing Apr 11, 2023
a827b34
Re-ran `black`.
andsteing Apr 11, 2023
cb410f6
Update examples/controlnet/README.md
andsteing Apr 11, 2023
00a394e
Converts spaces to tab.
andsteing Apr 11, 2023
6dc0c68
Removes repeated args.
andsteing Apr 12, 2023
f4af068
Skips first step (compilation) in profiling
andsteing Apr 12, 2023
cd084f5
Updates README with profiling instructions.
andsteing Apr 12, 2023
5ef3b5c
Unifies tabs/spaces in README.
andsteing Apr 12, 2023
94c6faf
Re-ran style & quality.
andsteing Apr 12, 2023
74 changes: 49 additions & 25 deletions examples/controlnet/README.md
@@ -284,9 +284,9 @@ TPU_TYPE=v4-8
VM_NAME=hg_flax

gcloud alpha compute tpus tpu-vm create $VM_NAME \
  --zone $ZONE \
  --accelerator-type $TPU_TYPE \
  --version tpu-vm-v4-base

gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
```
@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
pip install wandb
```


Now let's download two conditioning images that we will use to run validation during training in order to track our progress:

```
@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
export HUB_MODEL_ID="controlnet-fill-circle"
```
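
Note that the `{timestamp}` placeholder in `OUTPUT_DIR` is expanded by the training script at startup using `time.strftime("%Y%m%d_%H%M%S")` (see the `parse_args` change below), so every run writes into its own directory. For illustration (the resulting name is just an example and depends on when you start the run):

```bash
# same format the script uses to fill in {timestamp}
date +%Y%m%d_%H%M%S   # e.g. 20230411_165612 -> runs/fill-circle-20230411_165612
```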

And finally start the training
@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \
--revision="non-ema" \
--from_pt \
--report_to="wandb" \
--max_train_steps=10000 \
--tracker_project_name=$HUB_MODEL_ID \
--num_train_epochs=11 \
--push_to_hub \
--hub_model_id=$HUB_MODEL_ID
```

Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).

Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
export HUB_MODEL_ID="controlnet-uncanny-faces"

python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 50 \
--max_train_steps 5 \
--learning_rate=1e-5 \
--validation_steps=2 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb"
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 100000 \
--learning_rate=1e-5 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb" \
--tracker_project_name=$HUB_MODEL_ID
```

Note, however, that TPU performance might be bottlenecked, since streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints along the way. You can use the following argument to enable intermediate checkpointing:

```bash
--checkpointing_steps=500
```
This will save the trained model in subfolders of your `output_dir`. The subfolder name is the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named `500`.

You can then start your training from this saved checkpoint with

```bash
--controlnet_model_name_or_path="./control_out/500"
```

We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
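
For example, appended to the training command above (the value `5.0` is simply the recommendation from the paper, not a value tuned for this dataset):

```bash
--snr_gamma=5.0
```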

We also support gradient accumulation, a technique that lets you use a larger effective batch size than your machine would normally be able to fit into memory. You can use the `--gradient_accumulation_steps` argument to set the number of gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
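
For example, appended to the training command (the value `4` is purely illustrative; pick whatever fits your memory budget):

```bash
--gradient_accumulation_steps=4
```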

You can **profile your code** with:

```bash
--profile_steps=5
```
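
This PR also adds `--profile_validation` (profile the final validation pass) and `--profile_memory` (dump a device-memory profile before the training loop and another at program end); a sketch combining all three, appended to the training command above:

```bash
--profile_steps=5 \
--profile_validation \
--profile_memory
```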

Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:

```bash
pip install tensorflow tensorboard-plugin-profile
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
```

The profile can then be inspected at http://localhost:6006/#profile

Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).

Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional yet, and, for example, the `trace_viewer` cuts off events after 1M events (which can result in all your device traces getting lost if you accidentally profile the compilation step).
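
With `--profile_memory`, the script writes `memory_initial.prof` and `memory_final.prof` into the output directory. These are pprof-format dumps; a minimal sketch for viewing one, assuming the `pprof` tool from https://github.com/google/pprof (plus a Go toolchain, and graphviz for the `--web` view) is available:

```bash
# install the pprof viewer (assumes Go is installed)
go install github.com/google/pprof@latest
# render the memory profile as a graph in the browser
pprof --web runs/fill-circle-100steps-20230411_165612/memory_final.prof
```
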
90 changes: 70 additions & 20 deletions examples/controlnet/train_controlnet_flax.py
@@ -18,6 +18,7 @@
import math
import os
import random
import time
from pathlib import Path

import jax
@@ -220,6 +221,28 @@ def parse_args():
default=None,
help="Revision of controlnet model identifier from huggingface.co/models.",
)
parser.add_argument(
"--profile_steps",
type=int,
default=0,
help="How many training steps to profile in the beginning.",
)
parser.add_argument(
"--profile_validation",
action="store_true",
help="Whether to profile the (last) validation.",
)
parser.add_argument(
"--profile_memory",
action="store_true",
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
)
parser.add_argument(
"--ccache",
type=str,
default=None,
help="Enables compilation cache.",
)
parser.add_argument(
"--controlnet_from_pt",
action="store_true",
@@ -234,8 +257,9 @@ def parse_args():
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
default="runs/{timestamp}",
help="The output directory where the model predictions and checkpoints will be written. "
"Can contain placeholders: {timestamp}.",
)
parser.add_argument(
"--cache_dir",
@@ -317,15 +341,6 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--logging_steps",
type=int,
@@ -459,6 +474,8 @@ def parse_args():
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

args = parser.parse_args()
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -952,6 +969,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")

def l2(xs):
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))

metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))

return new_state, metrics, new_train_rng

# Create parallel version of the train step
@@ -983,32 +1005,38 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")

if jax.process_index() == 0:
if jax.process_index() == 0 and args.report_to == "wandb":
wandb.define_metric("*", step_metric="train/step")
wandb.define_metric("train/step", step_metric="walltime")
wandb.config.update(
{
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
}
)

global_step = 0
global_step = step0 = 0
epochs = tqdm(
range(args.num_train_epochs),
desc="Epoch ... ",
position=0,
disable=jax.process_index() > 0,
)
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
t00 = t0 = time.monotonic()
for epoch in epochs:
# ======================== Training ================================

train_metrics = []
train_metric = None

steps_per_epoch = (
args.max_train_samples // total_train_batch_size
if args.streaming
if args.streaming or args.max_train_samples
else len(train_dataset) // total_train_batch_size
)
train_step_progress_bar = tqdm(
@@ -1020,10 +1048,18 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
)
# train
for batch in train_dataloader:
if args.profile_steps and global_step == 1:
train_metric["loss"].block_until_ready()
jax.profiler.start_trace(args.output_dir)
if args.profile_steps and global_step == 1 + args.profile_steps:
train_metric["loss"].block_until_ready()
jax.profiler.stop_trace()

batch = shard(batch)
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
train_metrics.append(train_metric)

train_step_progress_bar.update(1)
@@ -1041,13 +1077,19 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb":
train_metrics = jax_utils.unreplicate(train_metrics)
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
wandb.log(
{
"walltime": time.monotonic() - t00,
"train/step": global_step,
"train/epoch": epoch,
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
"train/epoch": global_step / dataset_length,
"train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
**{f"train/{k}": v for k, v in train_metrics.items()},
}
)
t0, step0 = time.monotonic(), global_step
train_metrics = []
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained(
f"{args.output_dir}/{global_step}",
@@ -1058,10 +1100,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

# Create the pipeline using using the trained modules and save it.
# Final validation & store model.
if jax.process_index() == 0:
if args.validation_prompt is not None:
if args.profile_validation:
jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.profile_validation:
jax.profiler.stop_trace()
else:
image_logs = None

@@ -1084,6 +1130,10 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
ignore_patterns=["step_*", "epoch_*"],
)

if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
logger.info("Finished training.")


if __name__ == "__main__":
main()