Skip to content

[Examples] Add streaming support to the ControlNet training example in JAX #2859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 29, 2023
31 changes: 29 additions & 2 deletions examples/controlnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,15 @@ huggingface-cli login

Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:

```
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
```

And finally start the training

```
```bash
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
Expand All @@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \
```

Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).

Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:

```bash
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 50 \
--max_train_steps 5 \
--learning_rate=1e-5 \
--validation_steps=2 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb"
```

Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:

* [Webdataset](https://webdataset.github.io/webdataset/)
* [TorchData](https://github.com/pytorch/data)
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
57 changes: 44 additions & 13 deletions examples/controlnet/train_controlnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import IterableDataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
Expand Down Expand Up @@ -206,7 +207,7 @@ def parse_args():
parser.add_argument(
"--from_pt",
action="store_true",
help="Load the pretrained model from a pytorch checkpoint.",
help="Load the pretrained model from a PyTorch checkpoint.",
)
parser.add_argument(
"--tokenizer_name",
Expand Down Expand Up @@ -332,6 +333,7 @@ def parse_args():
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
parser.add_argument(
"--dataset_config_name",
type=str,
Expand Down Expand Up @@ -369,7 +371,7 @@ def parse_args():
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"value if set. Needed if `streaming` is set to True."
),
)
parser.add_argument(
Expand Down Expand Up @@ -453,10 +455,15 @@ def parse_args():
" or the same number of `--validation_prompt`s and `--validation_image`s"
)

# This idea comes from
# https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
if args.streaming and args.max_train_samples is None:
raise ValueError("You must specify `max_train_samples` when using dataset streaming.")

return args


def make_train_dataset(args, tokenizer):
def make_train_dataset(args, tokenizer, batch_size=None):
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

Expand All @@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
streaming=args.streaming,
)
else:
data_files = {}
Expand All @@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
if isinstance(dataset["train"], IterableDataset):
column_names = next(iter(dataset["train"])).keys()
else:
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
if args.image_column is None:
Expand Down Expand Up @@ -565,9 +576,20 @@ def preprocess_train(examples):

if jax.process_index() == 0:
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
if args.streaming:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
else:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
if args.streaming:
train_dataset = dataset["train"].map(
preprocess_train,
batched=True,
batch_size=batch_size,
remove_columns=list(dataset["train"].features.keys()),
)
else:
train_dataset = dataset["train"].with_transform(preprocess_train)

return train_dataset

Expand Down Expand Up @@ -661,12 +683,12 @@ def main():
raise NotImplementedError("No tokenizer specified!")

# Get the datasets: you can either provide your own training and evaluation files (see below)
train_dataset = make_train_dataset(args, tokenizer)
total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
shuffle=not args.streaming,
collate_fn=collate_fn,
batch_size=total_train_batch_size,
num_workers=args.dataloader_num_workers,
Expand Down Expand Up @@ -897,7 +919,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
vae_params = jax_utils.replicate(vae_params)

# Train!
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.streaming:
dataset_length = args.max_train_samples
else:
dataset_length = len(train_dataloader)
num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)

# Scheduler and math around the number of training steps.
if args.max_train_steps is None:
Expand All @@ -906,7 +932,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
Expand All @@ -916,7 +942,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
wandb.define_metric("*", step_metric="train/step")
wandb.config.update(
{
"num_train_examples": len(train_dataset),
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
Expand All @@ -935,7 +961,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

train_metrics = []

steps_per_epoch = len(train_dataset) // total_train_batch_size
steps_per_epoch = (
args.max_train_samples // total_train_batch_size
if args.streaming
else len(train_dataset) // total_train_batch_size
)
train_step_progress_bar = tqdm(
total=steps_per_epoch,
desc="Training...",
Expand Down Expand Up @@ -980,7 +1010,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

# Create the pipeline using using the trained modules and save it.
if jax.process_index() == 0:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.validation_prompt is not None:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)

controlnet.save_pretrained(
args.output_dir,
Expand Down