
[Examples] Add streaming support to the ControlNet training example in JAX #2859


Merged
sayakpaul merged 10 commits into main from feat/streaming-controlnet
Mar 29, 2023

@sayakpaul (Member) commented Mar 28, 2023

Potentially closes #2840.

Example command:

python3 train_controlnet_flax.py \
	--pretrained_model_name_or_path=$MODEL_DIR \
	--output_dir=$OUTPUT_DIR \
	--dataset_name=multimodalart/facesyntheticsspigacaptioned \
	--streaming \
	--conditioning_image_column=spiga_seg \
	--image_column=image \
	--caption_column=image_caption \
	--resolution=512 \
	--max_train_samples 50 \
	--max_train_steps 5 \
	--learning_rate=1e-5 \
	--validation_steps=2 \
	--train_batch_size=1 \
	--revision="flax" \
	--report_to="wandb"
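
The command above passes `--streaming` together with `--max_train_samples`. A streamed dataset has no known length, so a training script generally needs the sample cap to be given explicitly. Below is a minimal sketch of how such a flag pair could be parsed and validated with `argparse`. The function name `parse_streaming_args`, the reduced set of flags, and the error message are assumptions for illustration, not code copied from `train_controlnet_flax.py`.

```python
import argparse
from typing import List, Optional


def parse_streaming_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse a small, illustrative subset of the training flags (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Sketch of streaming-related CLI flags.")
    parser.add_argument("--dataset_name", type=str, required=True)
    # store_true: the flag takes no value, matching the bare `--streaming` in the command.
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--train_batch_size", type=int, default=1)
    args = parser.parse_args(argv)

    # Assumption: a streamed dataset has no length, so an explicit cap is required.
    if args.streaming and args.max_train_samples is None:
        parser.error("--max_train_samples is required when --streaming is set")
    return args
```

Note that `argparse` accepts both `--max_train_samples 50` and `--max_train_samples=50`, so the mixed spelling in the example command parses the same way.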


@patrickvonplaten (Contributor) left a comment

Looks cool, @yiyixuxu can you take a look as well?

@sayakpaul (Member, Author)

I forgot to add a note about how to handle num_update_steps_per_epoch when streaming is enabled, since a streamed dataset has no known length.

I couldn't find a good resource that shows how to deal with this. Most implementations I've seen use a pre-specified number of samples, as in https://github.com/borisdayma/dalle-mini/blob/main/src/dalle_mini/data.py. I think this is fine, no?
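
To make the "pre-specified number" idea concrete, here is a rough sketch of how steps per epoch could be derived when the dataset is streamed: the sample count comes from a user-supplied cap instead of `len(dataset)`. The helper name `steps_per_epoch` and its parameters are hypothetical, and rounding up assumes the last partial batch is kept; the actual script may compute this differently.

```python
import math
from typing import Optional


def steps_per_epoch(
    total_batch_size: int,
    streaming: bool,
    max_train_samples: Optional[int] = None,
    dataset_length: Optional[int] = None,
) -> int:
    """Return the number of update steps in one epoch (hypothetical helper)."""
    if total_batch_size <= 0:
        raise ValueError("total_batch_size must be positive")

    if streaming:
        # A streamed dataset cannot report its length, so rely on the cap.
        if max_train_samples is None:
            raise ValueError("max_train_samples is required when streaming")
        num_samples = max_train_samples
    else:
        if dataset_length is None:
            raise ValueError("dataset_length is required when not streaming")
        # An optional cap can still truncate a regular, map-style dataset.
        num_samples = (
            dataset_length
            if max_train_samples is None
            else min(dataset_length, max_train_samples)
        )

    # Assumption: the final partial batch counts as a step, hence ceil.
    return math.ceil(num_samples / total_batch_size)
```

With the values from the example command (`--max_train_samples 50`, `--train_batch_size=1`, one device), `steps_per_epoch(1, True, 50)` gives 50 steps per epoch.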

@yiyixuxu (Collaborator)

@sayakpaul

I think it's an easy fix :)
I will get it working and will probably push one commit to your PR.

@yiyixuxu (Collaborator)

@sayakpaul

ok, tested here https://wandb.ai/yiyixu/train_controlnet_flax/runs/np6ktkj4/logs?workspace=user-yiyixu
Thanks a lot for the PR! super cool!

feel free to merge whenever:)

@sayakpaul (Member, Author)

Thanks so much, @yiyixuxu, for taking care of the issue.

Merging.

@sayakpaul merged commit d82b032 into main Mar 29, 2023
@sayakpaul deleted the feat/streaming-controlnet branch March 29, 2023 01:12
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
…n JAX (huggingface#2859)

* improve stable unclip doc.

* feat: add streaming support to controlnet flax training script.

* fix: CLI arg.

* fix: torch dataloader shuffle setting.

* fix: dataset length.

* fix: wandb config.

* fix: steps_per_epoch in the training loop.

* add: entry about streaming in the readme

* get column names from iterable dataset + fix final logging

---------

Co-authored-by: yiyixuxu <[email protected]>
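
Two of the commits above ("fix: torch dataloader shuffle setting" and "get column names from iterable dataset") touch on general properties of iterable datasets: they cannot be indexed, so column names have to be read from the first example, and PyTorch's `DataLoader` rejects `shuffle=True` for an `IterableDataset`. The sketch below illustrates both ideas in plain Python. `ToyStream`, `peek_column_names`, and `loader_shuffle_flag` are hypothetical names, and this is not the code from those commits.

```python
from typing import Any, Dict, Iterable, Iterator, List


class ToyStream:
    """Stand-in for a re-iterable streamed dataset (hypothetical, no network access)."""

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # Each call to iter() restarts the stream from the beginning.
        yield {"image": "img0", "spiga_seg": "seg0", "image_caption": "a face"}
        yield {"image": "img1", "spiga_seg": "seg1", "image_caption": "another face"}


def peek_column_names(examples: Iterable[Dict[str, Any]]) -> List[str]:
    """Read column names from the first example instead of dataset metadata."""
    # Note: with a one-shot generator this would consume the first example;
    # it is safe here because the dataset is assumed to be re-iterable.
    first = next(iter(examples))
    return list(first.keys())


def loader_shuffle_flag(streaming: bool) -> bool:
    """Only request loader-level shuffling for map-style (non-streamed) data."""
    return not streaming
```

For example, `peek_column_names(ToyStream())` returns the three column names used in the example command, and `loader_shuffle_flag(True)` is `False`.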
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…n JAX (huggingface#2859)

Successfully merging this pull request may close these issues.

[Examples] Discussions on using streaming datasets for ControlNet training in JAX
4 participants