ort integration #916
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Leaving it up to @anton-l, maybe you could take a look whenever you find some time :-)
Signed-off-by: Ryan Russell <[email protected]>
* documenting `attention_flax.py` file
* documenting `embeddings_flax.py`
* documenting `unet_blocks_flax.py`
* Add new objs to doc page
* document `vae_flax.py`
* Apply suggestions from code review
* modify `unet_2d_condition_flax.py`
* make style
* Apply suggestions from code review
* make style
* Apply suggestions from code review
* fix indent
* fix typo
* fix indent unet
* Update src/diffusers/models/vae_flax.py
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
the result of running the pipeline is stored in StableDiffusionPipelineOutput.images
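For context, a minimal sketch of the access pattern this refers to. The dataclass below is a simplified stand-in for `StableDiffusionPipelineOutput` (whose `images` field holds PIL images in the real library), not the actual class:

```python
from dataclasses import dataclass

# Simplified stand-in for diffusers' StableDiffusionPipelineOutput:
# pipelines return an output object whose `.images` field holds the
# generated images (PIL images in the real library; strings here).
@dataclass
class PipelineOutput:
    images: list

def run_pipeline(prompt: str) -> PipelineOutput:
    # stand-in for `pipe(prompt)`; a real pipeline runs the diffusion loop
    return PipelineOutput(images=[f"image for: {prompt}"])

result = run_pipeline("an astronaut riding a horse")
first_image = result.images[0]  # access pattern: output.images[i]
```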
* refactor: pipelines readability improvements
* docs: remove todo comment from flax pipeline

Signed-off-by: Ryan Russell <[email protected]>
Fix "ort is not defined" issue.
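A common shape for this kind of fix is a guarded optional import, so the module never references `ort` when onnxruntime is absent. This is a hedged sketch, not the PR's actual code; the `ORTModule` import path follows the onnxruntime-training package and is an assumption here:

```python
# Guarded optional import: if onnxruntime is missing, `ort` is still a
# defined name, avoiding "NameError: ort is not defined" at use sites.
try:
    import onnxruntime as ort
except ImportError:
    ort = None  # optional dependency absent: fall back to plain PyTorch

def wrap_with_ort(model):
    """Wrap `model` in ORTModule when onnxruntime-training is available."""
    if ort is None:
        return model  # graceful fallback, never touches ORT symbols
    from onnxruntime.training import ORTModule  # assumed import path
    return ORTModule(model)
```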
* docs: `src/diffusers` readability improvements
* docs: `make style` lint

Signed-off-by: Ryan Russell <[email protected]>
…ce#627) fix formula for noise levels in karras scheduler and tests
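The schedule in question follows Karras et al. (2022). A hedged sketch of the formula, using the paper's default `rho=7`; the scheduler's actual default values are an assumption here:

```python
# Karras et al. (2022) noise-level schedule:
#   sigma_i = (s_max^(1/rho) + i/(n-1) * (s_min^(1/rho) - s_max^(1/rho)))^rho
# interpolating from sigma_max down to sigma_min. rho=7 is the paper's
# default; the exact defaults in the diffusers scheduler are assumptions.
def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    min_inv = sigma_min ** (1 / rho)
    max_inv = sigma_max ** (1 / rho)
    return [
        (max_inv + i / (n - 1) * (min_inv - max_inv)) ** rho
        for i in range(n)
    ]

sigmas = karras_sigmas(10)  # decreasing, from 80.0 down to 0.002
```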
…ce#447) (huggingface#472)
* Return encoded texts by DiffusionPipelines
* Updated README to show how to use encoded_text_input
* Reverted examples in README.md
* Reverted all
* Warning for long prompts
* Fix bugs
* Formatted
the link points to an old location of the train_unconditional.py file
* Remove deprecated `torch_device` kwarg. * Remove unused imports.
Signed-off-by: Ryan Russell <[email protected]> Signed-off-by: Ryan Russell <[email protected]>
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
* todo comment
* Fix imports
* Fix imports
* add dummies
* Fix empty init
* make pipeline work
* up
* Allow dtype to be overridden on model load. This may be a temporary solution until huggingface#567 is addressed.
* Convert params to bfloat16 or fp16 after loading. This deals with the weights, not the model.
* Use Flax schedulers (typing, docstring)
* PNDM: replace control flow with jax functions. Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`.
* Pass latents shape to scheduler set_timesteps(). PNDMScheduler uses it to reserve space, other schedulers will just ignore it.
* Wrap model imports inside availability checks.
* Optionally return state in from_config. Useful for Flax schedulers.
* Do not convert model weights to dtype.
* Re-enable PRK steps with functional implementation. Values returned still not verified for correctness.
* Remove left over has_state var.
* make style
* Apply suggestion list -> tuple
* Apply suggestion list -> tuple
* Remove unused comments.
* Use zeros instead of empty.

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
* fix accelerate for testing * fix copies * uP
* [CI] Add Apple M1 tests
* setup-python
* python build
* conda install
* remove branch
* only 3.8 is built for osx-arm
* try fetching prebuilt tokenizers
* use user cache
* update shells
* Reports and cleanup
* -> MPS
* Disable parallel tests
* Better naming
* investigate worker crash
* return xdist
* restart
* num_workers=2
* still crashing?
* faulthandler for segfaults
* faulthandler for segfaults
* remove restarts, stop on segfault
* torch version
* change installation order
* Use pre-RC version of PyTorch. To be updated when it is released.
* Skip crashing test on MPS, add new one that works.
* Skip cuda tests in mps device.
* Actually use generator in test. I think this was a typo.
* make style

Co-authored-by: Pedro Cuenca <[email protected]>
Fix autoencoder test.
…ikr/diffusers into prathikrao/ort-integration
Hi @anton-l, I've tried rebasing off main a couple of times to get the CI tests running, but for some reason it says I need to resolve conflicts in tests/test_models_unet.py even though I believe it is mergeable. Could you please let me know what I am doing wrong? Thank you.
Hey @prathikr, thanks a lot for the PR. We are trying to keep our examples as easy to understand as possible, and I think we don't want to mix ORT and PyTorch in the same training example, to keep the PyTorch example simple and independent from ORT. Could we maybe instead add an ORT-compatible example as its own standalone script? Cc @anton-l
Hi @patrickvonplaten, absolutely, I just made the changes, but CI still seems to have issues with tests/test_models_unet.py. I'd also like to highlight my changes to the unet_2d.py class. Is this safe to do? I've done my best to verify that removing the dataclass for the output of forward() doesn't affect downstream code if .sample is removed, but it might be best to have someone more familiar with the codebase confirm. Thanks again.
@JingyaHuang could you please take a look at this PR?
patrickvonplaten
left a comment
This looks good to me - @anton-l what do you think?
@prathikr however it seems like the commit history of this PR is a bit messed up - could you maybe open a new PR with the intended changes 😅 - thanks!
I'll need a second opinion from @JingyaHuang (we've discussed ORT training recently), as I don't have enough experience with training myself yet :)
@patrickvonplaten I've created a new PR and tagged you three: #1110
Integrated ORTModule into examples/unconditional_image_generation/train_unconditional.py
Note: We required a change to unet_2d.py to return a raw tensor as output of the model instead of a dataclass. I verified that the changes to remove this dataclass do not affect baseline execution.
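To illustrate the shape of that change, here is a hedged sketch. `UNet2DOutput` below is a simplified stand-in for the removed dataclass, and a plain list stands in for the tensor so the example stays self-contained; the stated motivation in this PR is ORTModule compatibility:

```python
from dataclasses import dataclass

# Before the change: forward() wrapped its result in a dataclass and
# callers read `.sample`. After: forward() returns the value directly,
# which plays better with wrappers that trace/export the forward pass.
@dataclass
class UNet2DOutput:  # simplified stand-in for the removed dataclass
    sample: list     # a torch.Tensor in the real model

def forward_with_dataclass(x):
    return UNet2DOutput(sample=[v * 2 for v in x])  # old behaviour

def forward_raw(x):
    return [v * 2 for v in x]  # new behaviour: raw output, no `.sample`

# downstream code sees the same values either way
assert forward_with_dataclass([1, 2]).sample == forward_raw([1, 2]) == [2, 4]
```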
test command: