Description
Following our discussion about how well a large dataset would work with the ControlNet training script in JAX, I wanted to consolidate what I have been doing in this regard.
To mimic a mid-scale dataset for training, I created this dummy dataset: https://huggingface.co/datasets/sayakpaul/dummy-controlnet-100000-samples. It has 100000 samples.
Then I followed different training examples that use JAX to come up with a dummy script that loads the dataset in streaming mode and performs a dummy forward pass (no models needed for this):
- https://github.com/huggingface/transformers/blob/main/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
- https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_flax.py
This is the dummy script I used for testing: https://gist.github.com/sayakpaul/b7ba991bf15e302ad0c9bbfe6e1e7f83#file-train_dummy-py. Things seem to work fine, unless I am missing something. If so, please let me know.
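For readers who do not want to open the gist, here is a rough, dependency-free sketch of the general pattern (stream samples, group them into global batches, split each batch per device, run a placeholder forward pass). It is not the script from the gist: `fake_stream`, `batched`, `shard`, `dummy_forward`, the device count, and the batch sizes are all hypothetical names and values chosen for illustration. In a real run, the generator would presumably be replaced by the dataset loaded with `streaming=True` in `datasets.load_dataset`.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, List


def fake_stream(num_samples: int) -> Iterator[Dict[str, int]]:
    """Hypothetical stand-in for a streamed dataset: yields one sample at a time."""
    for i in range(num_samples):
        yield {"image": i, "conditioning_image": i, "caption": i}


def batched(stream: Iterable[dict], batch_size: int) -> Iterator[List[dict]]:
    """Group a stream into full batches without loading the whole dataset."""
    iterator = iter(stream)
    while True:
        batch = list(islice(iterator, batch_size))
        if len(batch) < batch_size:
            return  # drop the incomplete final batch
        yield batch


def shard(batch: List[dict], num_devices: int) -> List[List[dict]]:
    """Split one global batch into equal per-device chunks."""
    per_device = len(batch) // num_devices
    return [batch[d * per_device:(d + 1) * per_device] for d in range(num_devices)]


def dummy_forward(device_batch: List[dict]) -> int:
    """Placeholder for a model call: only counts the samples it received."""
    return len(device_batch)


num_devices = 4        # assumed device count
per_device_batch = 2   # assumed per-device batch size
global_batch = num_devices * per_device_batch

samples_seen = 0
for batch in batched(fake_stream(100), global_batch):
    samples_seen += sum(dummy_forward(chunk) for chunk in shard(batch, num_devices))

print(samples_seen)  # 96: twelve full batches of 8, the last 4 samples are dropped
```

Dropping the incomplete final batch keeps every per-device chunk the same size, which is the usual requirement when batches are split evenly across devices.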
Additionally, one can find this gist that has all the instructions and scripts I used to perform this testing.
Ccing @patrickvonplaten @yiyixuxu