[Examples] Discussions on using streaming datasets for ControlNet training in JAX #2840

@sayakpaul

Since we discussed how well a large dataset would work with the ControlNet training script in JAX, I thought I would consolidate what I have tried so far.

To mimic a mid-scale training dataset, I created this dummy dataset with 100,000 samples: https://huggingface.co/datasets/sayakpaul/dummy-controlnet-100000-samples.

Then I followed different training examples that use JAX to come up with a dummy script that loads the dataset in streaming mode and performs a dummy forward pass (no models needed for this):

https://gist.github.com/sayakpaul/b7ba991bf15e302ad0c9bbfe6e1e7f83#file-train_dummy-py

Things seem to work fine, unless I am missing something. If so, please let me know.

Additionally, one can find this gist that has all the instructions and scripts I used to perform this testing.
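
For anyone who does not want to open the gist, here is a rough sketch of the idea, not the actual script. It assumes a synthetic generator in place of the streamed dataset so that it runs offline; the column names `pixel_values` and `conditioning_pixel_values`, the helper names, and the tiny resolution are all made up for illustration.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def synthetic_stream(num_samples, resolution=32):
    """Stand-in for a streamed dataset: yields one example dict at a time.

    With the real data, this would presumably be replaced by something like
    datasets.load_dataset(<repo id>, split="train", streaming=True)
    (the split name is an assumption), followed by preprocessing.
    """
    rng = np.random.default_rng(0)
    shape = (3, resolution, resolution)
    for _ in range(num_samples):
        yield {
            # Hypothetical column names, chosen for illustration only.
            "pixel_values": rng.standard_normal(shape).astype(np.float32),
            "conditioning_pixel_values": rng.standard_normal(shape).astype(np.float32),
        }


def batched(stream, batch_size):
    """Group an iterable of examples into stacked NumPy batches.

    The final partial batch is dropped so every batch has the same shape,
    which avoids recompiling the jitted function.
    """
    it = iter(stream)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if len(chunk) < batch_size:
            return
        yield {key: np.stack([ex[key] for ex in chunk]) for key in chunk[0]}


@jax.jit
def dummy_forward(batch):
    # No model involved: reduce both inputs to a single scalar.
    return jnp.mean(batch["pixel_values"]) + jnp.mean(batch["conditioning_pixel_values"])


def run(num_samples=10, batch_size=4):
    """Iterate over the stream once and return the number of steps taken."""
    steps = 0
    for batch in batched(synthetic_stream(num_samples), batch_size):
        dummy_forward(batch).block_until_ready()
        steps += 1
    return steps
```

Because the examples are consumed lazily through an iterator, nothing here needs the full dataset in memory; only one batch is materialized at a time.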

Ccing @patrickvonplaten @yiyixuxu
