Skip to content

Conversation

@prathikr
Copy link
Contributor

@prathikr prathikr commented Nov 2, 2022

Integrated ORTModule into examples/unconditional_image_generation/train_unconditional.py

Note: We required a change to unet_2d.py to return a raw tensor as output of the model instead of a dataclass. I verified that the changes to remove this dataclass do not affect baseline execution.

test command:

accelerate launch train_unconditional_ort.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=16 \
  --num_epochs=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=fp16

@prathikr
Copy link
Contributor Author

prathikr commented Nov 2, 2022

@prathikr prathikr mentioned this pull request Nov 2, 2022
Copy link
Member

@anton-l anton-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @prathikr! Looks like you need to make changes to the UNet outputs, but that would be a breaking change and we need to maintain the pytorch API for the models. Will the integration work with return_dict=False passed to the UNet instead, to return a tuple?

@JingyaHuang
Copy link
Contributor

The ORT training part looks good to me.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@prathikr
Copy link
Contributor Author

prathikr commented Nov 9, 2022

@anton-l yes, integration works with return_dict=False. @JingyaHuang thank you as well for your review. I've gone ahead and updated the script.

One inquiry I still have is how can I ensure train_unconditional_ort.py stays up to date with changes made to train_unconditional.py? Perhaps now that I've confirmed return_dict=False solves the issue, can we revisit the idea of adding --ort as a flag to train_unconditional.py like I've done in #1225? It shouldn't break any downstream tasks as the --ort flag wouldn't be included in your unit tests (or perhaps I could add an optional test for ort/stable-diffusion which would help with test code coverage for the return_dict=False parameter).

@prathikr
Copy link
Contributor Author

@anton-l @patrickvonplaten any updates on this?

@patrickvonplaten
Copy link
Contributor

@anton-l what do you think? The PR is ok for me, I'd maybe leave a comment in the README of examples/unconditional_image_generation though that this script is not maintained and for questions to please ping @prathikr ? I highly doubt we will have time to maintain it 😅

Would that be ok for you @prathikr ?

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good for merge to me if we leave an example & statement about limited maintenance in the README :-)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@prathikr
Copy link
Contributor Author

Thanks @patrickvonplaten, I've added an example/contacts to the readme.

We do not expect you to maintain this script. We plan to run this script from our daily monitoring pipelines, so we will be privy to any issues quickly and will triage on our end. Should any issues arise, the first debugging step for us will be to make sure train_unconditional.py and train_unconditional_ort.py are the same.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Member

@anton-l anton-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great @prathikr! Making a couple of updates to make the tests pass now :)

@anton-l anton-l merged commit 3346ec3 into huggingface:main Nov 17, 2022
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* integrate ort

* use return_dict=False

* revert unet return value change

* revert unet return value change

* add note to readme

* adjust readme

* add contact

* `make style`

Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants