Adds profiling flags, computes train metrics average. #3053

Merged · 26 commits · Apr 12, 2023
Changes from 19 commits
Commits
81ac95b
WIP controlnet training
andsteing Apr 6, 2023
cd7c0d2
Adds final logging statement.
andsteing Apr 7, 2023
c15af70
Sets train epochs to 11.
andsteing Apr 8, 2023
f338696
Removes --logging_dir (it's not used).
andsteing Apr 8, 2023
26ceff2
Adds --profile flags.
andsteing Apr 8, 2023
56d777c
Updates --output_dir=runs/fill-circle-{timestamp}.
andsteing Apr 8, 2023
34a63f5
Compute mean of `train_metrics`.
andsteing Apr 8, 2023
f02e482
Improves logging a bit.
andsteing Apr 8, 2023
f7a2f28
Adds --ccache (doesn't really help though).
andsteing Apr 8, 2023
95d5b18
minor fix in controlnet flax example (#2986)
yiyixuxu Apr 6, 2023
0874a9d
Bugfix --profile_steps
andsteing Apr 8, 2023
1a74f34
Sets `TRACKER_PROJECT_NAME='controlnet_fill50k'`.
andsteing Apr 8, 2023
7f9a3c3
Logs fractional epoch.
andsteing Apr 8, 2023
4d20a04
Adds relative `walltime` metric.
andsteing Apr 8, 2023
d1f5994
Adds `StepTraceAnnotation` and uses `global_step` instead of `step`.
andsteing Apr 8, 2023
4d83f6f
Applied `black`.
andsteing Apr 11, 2023
4fed793
Streamlines commands in README a bit.
andsteing Apr 11, 2023
92b7bef
Removes `--ccache`.
andsteing Apr 11, 2023
a827b34
Re-ran `black`.
andsteing Apr 11, 2023
cb410f6
Update examples/controlnet/README.md
andsteing Apr 11, 2023
00a394e
Converts spaces to tab.
andsteing Apr 11, 2023
6dc0c68
Removes repeated args.
andsteing Apr 12, 2023
f4af068
Skips first step (compilation) in profiling
andsteing Apr 12, 2023
cd084f5
Updates README with profiling instructions.
andsteing Apr 12, 2023
5ef3b5c
Unifies tabs/spaces in README.
andsteing Apr 12, 2023
94c6faf
Re-ran style & quality.
andsteing Apr 12, 2023
examples/controlnet/README.md (21 changes: 13 additions & 8 deletions)
@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
pip install wandb
```


Now let's download two conditioning images that we will use to run validation during training in order to track our progress

```
@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
export HUB_MODEL_ID="controlnet-fill-circle"
```

And finally start the training
@@ -363,16 +364,21 @@ python3 train_controlnet_flax.py \
--revision="non-ema" \
--from_pt \
--report_to="wandb" \
--max_train_steps=10000 \
--tracker_project_name=$HUB_MODEL_ID \
--num_train_epochs=11 \
--push_to_hub \
--hub_model_id=$HUB_MODEL_ID
```

Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).

Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this Blog article](https://huggingface.co/blog/train-your-controlnet)):

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
export HUB_MODEL_ID="controlnet-uncanny-faces"

python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
@@ -382,13 +388,12 @@ python3 train_controlnet_flax.py \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 50 \
--max_train_steps 5 \
--max_train_samples 100000 \
--learning_rate=1e-5 \
--validation_steps=2 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb"
--report_to="wandb" \
--tracker_project_name=$HUB_MODEL_ID
```

Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
examples/controlnet/train_controlnet_flax.py (98 changes: 78 additions & 20 deletions)
@@ -18,6 +18,7 @@
import math
import os
import random
import time
from pathlib import Path

import jax
@@ -225,6 +226,39 @@ def parse_args():
action="store_true",
help="Load the controlnet model from a PyTorch checkpoint.",
)
parser.add_argument(
"--profile_steps",
type=int,
default=0,
help="How many training steps to profile in the beginning.",
)
parser.add_argument(
"--profile_validation",
action="store_true",
help="Whether to profile the (last) validation.",
)
parser.add_argument(
"--profile_memory",
action="store_true",
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
)
parser.add_argument(
"--ccache",
type=str,
default=None,
help="Enables compilation cache.",
)
parser.add_argument(
"--controlnet_revision",
type=str,
default=None,
help="Revision of controlnet model identifier from huggingface.co/models.",
)
parser.add_argument(
"--controlnet_from_pt",
action="store_true",
help="Load the controlnet model from a PyTorch checkpoint.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -234,8 +268,9 @@
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
default="runs/{timestamp}",
help="The output directory where the model predictions and checkpoints will be written. "
"Can contain placeholders: {timestamp}.",
)
parser.add_argument(
"--cache_dir",
@@ -317,15 +352,6 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--logging_steps",
type=int,
@@ -459,6 +485,8 @@ def parse_args():
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

args = parser.parse_args()
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
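As a side note, the `{timestamp}` placeholder handling above is a plain string substitution on `args.output_dir`; a minimal sketch of the effect (the resulting path is only an example):

```python
import time

# Same substitution as in parse_args(); the concrete value depends on the current time.
output_dir = "runs/fill-circle-{timestamp}"
output_dir = output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
print(output_dir)  # e.g. runs/fill-circle-20230412_153000
```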
@@ -952,6 +980,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")

def l2(xs):
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))

metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))

return new_state, metrics, new_train_rng

# Create parallel version of the train step
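For readers skimming the diff: the new `l2_grads` metric is the global L2 norm over all leaves of the gradient pytree. A small standalone sketch with a toy pytree (the names below are illustrative, not part of the script):

```python
import jax
import jax.numpy as jnp

def l2(xs):
    # Global L2 norm across every leaf of a (possibly nested) pytree.
    return jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)))

# Toy gradient pytree standing in for `grad` from the train step.
toy_grad = {"kernel": jnp.ones((2, 2)), "bias": jnp.array([3.0, 4.0])}
print(l2(toy_grad))  # sqrt(4 * 1 + 9 + 16) = sqrt(29) ≈ 5.39
```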
@@ -983,32 +1016,37 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")

if jax.process_index() == 0:
if jax.process_index() == 0 and args.report_to == "wandb":
wandb.define_metric("*", step_metric="train/step")
wandb.define_metric("train/step", step_metric="walltime")
wandb.config.update(
{
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
}
)

global_step = 0
global_step = step0 = 0
epochs = tqdm(
range(args.num_train_epochs),
desc="Epoch ... ",
position=0,
disable=jax.process_index() > 0,
)
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
t00 = t0 = time.monotonic()
for epoch in epochs:
# ======================== Training ================================

train_metrics = []

steps_per_epoch = (
args.max_train_samples // total_train_batch_size
if args.streaming
if args.streaming or args.max_train_samples
else len(train_dataset) // total_train_batch_size
)
train_step_progress_bar = tqdm(
@@ -1020,10 +1058,16 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
)
# train
for batch in train_dataloader:
if args.profile_steps and global_step == 0:
jax.profiler.start_trace(args.output_dir)
if args.profile_steps and args.profile_steps == global_step:
jax.profiler.stop_trace()

batch = shard(batch)
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
train_metrics.append(train_metric)

train_step_progress_bar.update(1)
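A condensed, hedged sketch of the tracing pattern used above, with a stand-in step function (`train_step`, the trace directory, and the step count are illustrative):

```python
import jax
import jax.numpy as jnp

profile_steps = 3                    # mirrors --profile_steps
trace_dir = "/tmp/controlnet_trace"  # the script passes args.output_dir

@jax.jit
def train_step(x):
    # Stand-in for the real pmapped training step.
    return x + 1.0

x = jnp.zeros(())
for global_step in range(10):
    if profile_steps and global_step == 0:
        jax.profiler.start_trace(trace_dir)  # start capturing a trace
    if profile_steps and profile_steps == global_step:
        jax.profiler.stop_trace()            # stop after `profile_steps` steps
    # Each iteration appears as a named "train" step in the trace viewer.
    with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
        x = train_step(x)
```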
@@ -1041,13 +1085,19 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb":
train_metrics = jax_utils.unreplicate(train_metrics)
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
wandb.log(
{
"walltime": time.monotonic() - t00,
"train/step": global_step,
"train/epoch": epoch,
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
"train/epoch": global_step / dataset_length,
"train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
**{f"train/{k}": v for k, v in train_metrics.items()},
}
)
t0, step0 = time.monotonic(), global_step
train_metrics = []
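The metric averaging added here (the second half of this PR's title) unreplicates the collected per-step metrics and then takes a per-key mean; a toy sketch of the reduction (values are made up):

```python
import jax
import jax.numpy as jnp

# Toy per-step metric dicts, standing in for the (already unreplicated) train_metrics list.
train_metrics = [
    {"loss": jnp.array(0.9), "l2_grads": jnp.array(2.0)},
    {"loss": jnp.array(0.7), "l2_grads": jnp.array(1.0)},
]

# Same reduction as in the logging branch: stack each key across steps and average.
mean_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
print(mean_metrics)  # loss ≈ 0.8, l2_grads = 1.5
```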
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained(
f"{args.output_dir}/{global_step}",
@@ -1058,10 +1108,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

# Create the pipeline using using the trained modules and save it.
# Final validation & store model.
if jax.process_index() == 0:
if args.validation_prompt is not None:
if args.profile_validation:
jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.profile_validation:
jax.profiler.stop_trace()
else:
image_logs = None

@@ -1084,6 +1138,10 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
ignore_patterns=["step_*", "epoch_*"],
)

if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
logger.info("Finished training.")


if __name__ == "__main__":
main()
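Finally, a hedged sketch of the `--profile_memory` dumps in isolation; per the JAX docs, the resulting `.prof` files can be inspected with the `pprof` tool (the file name below is illustrative):

```python
import jax
import jax.numpy as jnp

# Allocate something on the device so the snapshot has content.
x = jnp.ones((1024, 1024))

# Writes a device-memory snapshot in pprof format, analogous to the
# memory_initial.prof / memory_final.prof files written by the script.
jax.profiler.save_device_memory_profile("memory_example.prof")
```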