 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
+from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
@@ -206,7 +207,7 @@ def parse_args():
     parser.add_argument(
         "--from_pt",
         action="store_true",
-        help="Load the pretrained model from a pytorch checkpoint.",
+        help="Load the pretrained model from a PyTorch checkpoint.",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -332,6 +333,7 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
+    parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
     parser.add_argument(
         "--dataset_config_name",
         type=str,
@@ -369,7 +371,7 @@ def parse_args():
         default=None,
         help=(
             "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "value if set. Needed if `streaming` is set to True. "
         ),
     )
     parser.add_argument(
@@ -453,10 +455,15 @@ def parse_args():
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )
 
+    # This idea comes from
+    # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+    if args.streaming and args.max_train_samples is None:
+        raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
     return args
 
 
-def make_train_dataset(args, tokenizer):
+def make_train_dataset(args, tokenizer, batch_size=None):
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 
@@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            streaming=args.streaming,
         )
     else:
         data_files = {}
@@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    column_names = dataset["train"].column_names
+    if isinstance(dataset["train"], IterableDataset):
+        column_names = next(iter(dataset["train"])).keys()
+    else:
+        column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
     if args.image_column is None:
@@ -565,9 +576,20 @@ def preprocess_train(examples):
 
     if jax.process_index() == 0:
         if args.max_train_samples is not None:
-            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+            if args.streaming:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+            else:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        if args.streaming:
+            train_dataset = dataset["train"].map(
+                preprocess_train,
+                batched=True,
+                batch_size=batch_size,
+                remove_columns=list(dataset["train"].features.keys()),
+            )
+        else:
+            train_dataset = dataset["train"].with_transform(preprocess_train)
 
     return train_dataset
 
@@ -661,12 +683,12 @@ def main():
         raise NotImplementedError("No tokenizer specified!")
 
     # Get the datasets: you can either provide your own training and evaluation files (see below)
-    train_dataset = make_train_dataset(args, tokenizer)
     total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+    train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        shuffle=True,
+        shuffle=not args.streaming,
         collate_fn=collate_fn,
         batch_size=total_train_batch_size,
         num_workers=args.dataloader_num_workers,
@@ -897,7 +919,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
     vae_params = jax_utils.replicate(vae_params)
 
     # Train!
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.streaming:
+        dataset_length = args.max_train_samples
+    else:
+        dataset_length = len(train_dataloader)
+    num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
 
     # Scheduler and math around the number of training steps.
     if args.max_train_steps is None:
@@ -906,7 +932,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
@@ -916,7 +942,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         wandb.define_metric("*", step_metric="train/step")
         wandb.config.update(
             {
-                "num_train_examples": len(train_dataset),
+                "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
@@ -935,7 +961,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
         train_metrics = []
 
-        steps_per_epoch = len(train_dataset) // total_train_batch_size
+        steps_per_epoch = (
+            args.max_train_samples // total_train_batch_size
+            if args.streaming
+            else len(train_dataset) // total_train_batch_size
+        )
         train_step_progress_bar = tqdm(
             total=steps_per_epoch,
             desc="Training...",
@@ -980,7 +1010,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
     # Create the pipeline using using the trained modules and save it.
     if jax.process_index() == 0:
-        image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+        if args.validation_prompt is not None:
+            image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
 
     controlnet.save_pretrained(
         args.output_dir,
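# ---------------------------------------------------------------------------
# For reference, a minimal standalone sketch of the streaming data path that the
# diff above wires into make_train_dataset()/main(). It mirrors the 🤗 Datasets
# and PyTorch calls used in the diff; the dataset name, sample cap, batch size,
# and the preprocess_train stub are illustrative placeholders, not values taken
# from the commit.

from datasets import load_dataset
from torch.utils.data import DataLoader


def preprocess_train(batch):
    # Placeholder for the script's real per-batch transform (image/conditioning
    # transforms and prompt tokenization).
    return batch


max_train_samples = 10_000  # required when streaming: the stream has no len()
total_train_batch_size = 8  # the script computes this from batch size * devices * grad. accumulation

# streaming=True returns an IterableDataset whose samples are fetched lazily from the Hub.
train_stream = load_dataset("fusing/fill50k", split="train", streaming=True)

# Iterable datasets expose no .column_names, so peek at one example to discover them.
column_names = list(next(iter(train_stream)).keys())

# Approximate shuffling via a buffer, then cap the stream so the epoch length is known.
train_stream = train_stream.shuffle(seed=0).take(max_train_samples)

# .with_transform() is unavailable for iterable datasets, so map the transform per batch.
train_dataset = train_stream.map(
    preprocess_train,
    batched=True,
    batch_size=total_train_batch_size,
    remove_columns=column_names,
)

# A DataLoader must not shuffle an iterable dataset; ordering comes from the stream
# itself, which is why the diff sets shuffle=not args.streaming.
train_dataloader = DataLoader(train_dataset, batch_size=total_train_batch_size, shuffle=False)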