-
Notifications
You must be signed in to change notification settings - Fork 6.6k
integrate ort #1110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
integrate ort #1110
Conversation
anton-l
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @prathikr! Looks like you need to make changes to the UNet outputs, but that would be a breaking change and we need to maintain the pytorch API for the models. Will the integration work with return_dict=False passed to the UNet instead, to return a tuple?
|
The ORT training part looks good to me. |
|
The documentation is not available anymore as the PR was closed or merged. |
|
@anton-l yes, integration works with return_dict=False. @JingyaHuang thank you as well for your review. I've gone ahead and updated the script. One inquiry I still have is how can I ensure train_unconditional_ort.py stays up to date with changes made to train_unconditional.py? Perhaps now that I've confirmed return_dict=False solves the issue, can we revisit the idea of adding --ort as a flag to train_unconditional.py like I've done in #1225? It shouldn't break any downstream tasks as the --ort flag wouldn't be included in your unit tests (or perhaps I could add an optional test for ort/stable-diffusion which would help with test code coverage for the return_dict=False parameter). |
|
@anton-l @patrickvonplaten any updates on this? |
|
@anton-l what do you think? The PR is ok for me, I'd maybe leave a comment in the README of examples/unconditional_image_generation though that this script is not maintained and for questions to please ping @prathikr ? I highly doubt we will have time to maintain it 😅 Would that be ok for you @prathikr ? |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good for merge to me if we leave an example & statement about limited maintenance in the README :-)
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
Thanks @patrickvonplaten, I've added an example/contacts to the readme. We do not expect you to maintain this script. We plan to run this script from our daily monitoring pipelines, so we will be privy to any issues quickly and will triage on our end. Should any issues arise, the first debugging step for us will be to make sure train_unconditional.py and train_unconditional_ort.py are the same. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
anton-l
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great @prathikr! Making a couple of updates to make the tests pass now :)
* integrate ort * use return_dict=False * revert unet return value change * revert unet return value change * add note to readme * adjust readme * add contact * `make style` Co-authored-by: Prathik Rao <[email protected]> Co-authored-by: Anton Lozhkov <[email protected]>
Integrated ORTModule into examples/unconditional_image_generation/train_unconditional.py
Note: We required a change to unet_2d.py to return a raw tensor as output of the model instead of a dataclass. I verified that the changes to remove this dataclass do not affect baseline execution.
test command: