
Commit d82b032

sayakpaul and yiyixuxu authored
[Examples] Add streaming support to the ControlNet training example in JAX (#2859)
* improve stable unclip doc.
* feat: add streaming support to controlnet flax training script.
* fix: CLI arg.
* fix: torch dataloader shuffle setting.
* fix: dataset length.
* fix: wandb config.
* fix: steps_per_epoch in the training loop.
* add: entry about streaming in the readme
* get column names from iterable dataset + fix final logging

Co-authored-by: yiyixuxu <[email protected]>
1 parent 40a7b86 commit d82b032

File tree: 2 files changed (+73, -15 lines)


examples/controlnet/README.md

Lines changed: 29 additions & 2 deletions
@@ -335,15 +335,15 @@ huggingface-cli login
 
 Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
 
-```
+```bash
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
 export OUTPUT_DIR="control_out"
 export HUB_MODEL_ID="fill-circle-controlnet"
 ```
 
 And finally start the training
 
-```
+```bash
 python3 train_controlnet_flax.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
@@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \
 ```
 
 Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
+
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+ --streaming \
+ --conditioning_image_column=spiga_seg \
+ --image_column=image \
+ --caption_column=image_caption \
+ --resolution=512 \
+ --max_train_samples 50 \
+ --max_train_steps 5 \
+ --learning_rate=1e-5 \
+ --validation_steps=2 \
+ --train_batch_size=1 \
+ --revision="flax" \
+ --report_to="wandb"
+```
+
+Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
+
+* [Webdataset](https://webdataset.github.io/webdataset/)
+* [TorchData](https://github.com/pytorch/data)
+* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
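
As a point of comparison for the options the README lists, here is a minimal, hypothetical sketch (not part of this commit) of streaming (image, caption) pairs with WebDataset. The shard URL pattern and the `jpg`/`txt` keys are assumptions about how the data would be repacked into tar shards:

```python
# Sketch only: streaming image/caption pairs with WebDataset instead of
# `datasets`. Assumes the dataset has been repacked into tar shards whose
# members use "jpg" (image bytes) and "txt" (caption) keys.
import webdataset as wds

shards = "https://example.org/shards/train-{000000..000009}.tar"  # hypothetical URLs

dataset = (
    wds.WebDataset(shards)
    .shuffle(1000)           # buffer-based shuffle, as with streaming `datasets`
    .decode("pil")           # decode image bytes to PIL images
    .to_tuple("jpg", "txt")  # yield (image, caption) pairs
)

for image, caption in dataset:
    # examples arrive lazily; no full download is required
    break
```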

examples/controlnet/train_controlnet_flax.py

Lines changed: 44 additions & 13 deletions
@@ -35,6 +35,7 @@
 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
+from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
@@ -206,7 +207,7 @@ def parse_args():
     parser.add_argument(
         "--from_pt",
         action="store_true",
-        help="Load the pretrained model from a pytorch checkpoint.",
+        help="Load the pretrained model from a PyTorch checkpoint.",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -332,6 +333,7 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
+    parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
     parser.add_argument(
         "--dataset_config_name",
         type=str,
@@ -369,7 +371,7 @@ def parse_args():
         default=None,
         help=(
             "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "value if set. Needed if `streaming` is set to True."
         ),
     )
     parser.add_argument(
@@ -453,10 +455,15 @@ def parse_args():
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )
 
+    # This idea comes from
+    # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+    if args.streaming and args.max_train_samples is None:
+        raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
     return args
 
 
-def make_train_dataset(args, tokenizer):
+def make_train_dataset(args, tokenizer, batch_size=None):
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 
@@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            streaming=args.streaming,
         )
     else:
         data_files = {}
@@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    column_names = dataset["train"].column_names
+    if isinstance(dataset["train"], IterableDataset):
+        column_names = next(iter(dataset["train"])).keys()
+    else:
+        column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
     if args.image_column is None:
@@ -565,9 +576,20 @@ def preprocess_train(examples):
 
     if jax.process_index() == 0:
         if args.max_train_samples is not None:
-            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+            if args.streaming:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+            else:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        if args.streaming:
+            train_dataset = dataset["train"].map(
+                preprocess_train,
+                batched=True,
+                batch_size=batch_size,
+                remove_columns=list(dataset["train"].features.keys()),
+            )
+        else:
+            train_dataset = dataset["train"].with_transform(preprocess_train)
 
     return train_dataset
 
@@ -661,12 +683,12 @@ def main():
         raise NotImplementedError("No tokenizer specified!")
 
     # Get the datasets: you can either provide your own training and evaluation files (see below)
-    train_dataset = make_train_dataset(args, tokenizer)
     total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+    train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        shuffle=True,
+        shuffle=not args.streaming,
         collate_fn=collate_fn,
         batch_size=total_train_batch_size,
         num_workers=args.dataloader_num_workers,
@@ -897,7 +919,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
     vae_params = jax_utils.replicate(vae_params)
 
     # Train!
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.streaming:
+        dataset_length = args.max_train_samples
+    else:
+        dataset_length = len(train_dataloader)
+    num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
 
     # Scheduler and math around the number of training steps.
     if args.max_train_steps is None:
@@ -906,7 +932,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
@@ -916,7 +942,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         wandb.define_metric("*", step_metric="train/step")
         wandb.config.update(
             {
-                "num_train_examples": len(train_dataset),
+                "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
@@ -935,7 +961,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
         train_metrics = []
 
-        steps_per_epoch = len(train_dataset) // total_train_batch_size
+        steps_per_epoch = (
+            args.max_train_samples // total_train_batch_size
+            if args.streaming
+            else len(train_dataset) // total_train_batch_size
+        )
         train_step_progress_bar = tqdm(
             total=steps_per_epoch,
             desc="Training...",
@@ -980,7 +1010,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
     # Create the pipeline using using the trained modules and save it.
     if jax.process_index() == 0:
-        image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+        if args.validation_prompt is not None:
+            image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
 
         controlnet.save_pretrained(
             args.output_dir,
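
Taken together, the script changes above thread a streaming branch through `make_train_dataset` and the training loop. Below is a minimal standalone sketch (not part of the commit) of that data path, using the dataset name and sample count from the README example; `preprocess` is a hypothetical stand-in for the script's real `preprocess_train`:

```python
# Sketch of the streaming data path the diff wires up.
from datasets import load_dataset
from torch.utils.data import IterableDataset

# streaming=True returns an iterable dataset that downloads examples lazily.
dataset = load_dataset("multimodalart/facesyntheticsspigacaptioned", streaming=True)
train = dataset["train"]

# An iterable dataset has no `column_names`; peek at the first example instead,
# exactly as the diff does.
if isinstance(train, IterableDataset):
    column_names = list(next(iter(train)).keys())
else:
    column_names = train.column_names

def preprocess(examples):
    return examples  # placeholder for the real image/caption transforms

# Buffer-based shuffle plus take(n) stands in for shuffle().select(range(n)),
# and a batched map() replaces with_transform(), which iterable datasets do not
# support; the map is applied on the fly as examples stream in. (The script
# also passes remove_columns to drop the raw columns after preprocessing.)
train = train.shuffle(seed=0).take(50)
train = train.map(preprocess, batched=True, batch_size=8)

for example in train:
    print(sorted(example.keys()))
    break
```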
