Conversation

@prathikr
Contributor

@prathikr prathikr commented Oct 20, 2022

Integrated ORTModule into examples/unconditional_image_generation/train_unconditional.py

Note: this required a change to unet_2d.py so that the model's forward() returns a raw tensor instead of a dataclass. I verified that removing the dataclass does not affect baseline execution.
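
For reference, here is a minimal sketch of what that change means for callers; the model name and call sites follow the public diffusers API, but the exact lines touched by this PR may differ:

import torch
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)
noisy_images = torch.randn(1, 3, 64, 64)
timesteps = torch.tensor([10])

# Before: forward() wraps the prediction in an output dataclass, so callers unwrap it.
noise_pred = model(noisy_images, timesteps).sample

# After the change described above, forward() returns the raw tensor directly,
# so downstream code drops the .sample access:
# noise_pred = model(noisy_images, timesteps)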

test command:

accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=16 \
  --num_epochs=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=no \
  --ort
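
For context, the --ort flag gates something like the following inside the training script. This is a minimal sketch assuming the torch-ort package (onnxruntime-training); it is not the exact diff in this PR:

import argparse
from diffusers import UNet2DModel

parser = argparse.ArgumentParser()
parser.add_argument("--ort", action="store_true", help="wrap the model with ORTModule")
args = parser.parse_args()

# Build the UNet the same way train_unconditional.py does.
model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)

if args.ort:
    # ORTModule wraps an nn.Module so its forward and backward passes run
    # through ONNX Runtime; the rest of the training loop is unchanged.
    from torch_ort import ORTModule
    model = ORTModule(model)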

@prathikr prathikr marked this pull request as ready for review October 20, 2022 22:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@patrickvonplaten
Contributor

Leaving it up to @anton-l; maybe you could take a look whenever you find some time :-)

ryanrussell and others added 15 commits October 26, 2022 15:09
* documenting `attention_flax.py` file

* documenting `embeddings_flax.py`

* documenting `unet_blocks_flax.py`

* Add new objs to doc page

* document `vae_flax.py`

* Apply suggestions from code review

* modify `unet_2d_condition_flax.py`

* make style

* Apply suggestions from code review

* make style

* Apply suggestions from code review

* fix indent

* fix typo

* fix indent unet

* Update src/diffusers/models/vae_flax.py

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
the result of running the pipeline is stored in StableDiffusionPipelineOutput.images
* refactor: pipelines readability improvements

Signed-off-by: Ryan Russell <[email protected]>

* docs: remove todo comment from flax pipeline

Signed-off-by: Ryan Russell <[email protected]>

Signed-off-by: Ryan Russell <[email protected]>
* docs: `src/diffusers` readability improvements

Signed-off-by: Ryan Russell <[email protected]>

* docs: `make style` lint

Signed-off-by: Ryan Russell <[email protected]>

Signed-off-by: Ryan Russell <[email protected]>
…ce#627)

fix formula for noise levels in karras scheduler and tests
…ce#447) (huggingface#472)

* Return encoded texts by DiffusionPipelines

* Updated README to show how to use encoded_text_input

* Reverted examples in README.md

* Reverted all

* Warning for long prompts

* Fix bugs

* Formatted
the link points to an old location of the train_unconditional.py file
* Remove deprecated `torch_device` kwarg.

* Remove unused imports.
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

* todo comment

* Fix imports

* Fix imports

* add dummies

* Fix empty init

* make pipeline work

* up

* Allow dtype to be overridden on model load.

This may be a temporary solution until huggingface#567 is addressed.

* Convert params to bfloat16 or fp16 after loading.

This deals with the weights, not the model.

* Use Flax schedulers (typing, docstring)

* PNDM: replace control flow with jax functions.

Otherwise jitting/parallelization don't work properly, as they don't know how to deal with traced objects (see the sketch after this commit list).

I temporarily removed `step_prk`.

* Pass latents shape to scheduler set_timesteps()

PNDMScheduler uses it to reserve space, other schedulers will just
ignore it.

* Wrap model imports inside availability checks.

* Optionally return state in from_config.

Useful for Flax schedulers.

* Do not convert model weights to dtype.

* Re-enable PRK steps with functional implementation.

Values returned still not verified for correctness.

* Remove left over has_state var.

* make style

* Apply suggestion list -> tuple

Co-authored-by: Suraj Patil <[email protected]>

* Apply suggestion list -> tuple

Co-authored-by: Suraj Patil <[email protected]>

* Remove unused comments.

* Use zeros instead of empty.

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
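
On the "PNDM: replace control flow with jax functions" commit above: under jax.jit, Python if/else on traced values raises an error, so branching has to go through functional primitives such as jnp.where or lax.cond. A minimal illustration of the idea (not the actual scheduler code):

import jax
import jax.numpy as jnp

@jax.jit
def step(counter, sample):
    # A Python `if counter < 1:` here would fail under jit, because `counter`
    # is a traced value; jnp.where stays functional and traceable.
    return jnp.where(counter < 1, sample * 2.0, sample * 0.5)

print(step(jnp.array(0), jnp.array(1.0)))  # 2.0
print(step(jnp.array(3), jnp.array(1.0)))  # 0.5
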
@prathikr prathikr marked this pull request as ready for review October 26, 2022 23:22
patrickvonplaten and others added 16 commits October 26, 2022 16:23
* fix accelerate for testing

* fix copies

* uP
* [CI] Add Apple M1 tests

* setup-python

* python build

* conda install

* remove branch

* only 3.8 is built for osx-arm

* try fetching prebuilt tokenizers

* use user cache

* update shells

* Reports and cleanup

* -> MPS

* Disable parallel tests

* Better naming

* investigate worker crash

* return xdist

* restart

* num_workers=2

* still crashing?

* faulthandler for segfaults

* faulthandler for segfaults

* remove restarts, stop on segfault

* torch version

* change installation order

* Use pre-RC version of PyTorch.

To be updated when it is released.

* Skip crashing test on MPS, add new one that works.

* Skip cuda tests in mps device.

* Actually use generator in test.

I think this was a typo.

* make style

Co-authored-by: Pedro Cuenca <[email protected]>
@prathikr
Contributor Author

Hi @anton-l, I've tried rebasing off main a couple of times to get the CI tests running, but for some reason it says I need to resolve conflicts in tests/test_models_unet.py even though I believe it is mergeable. Could you please let me know what I am doing wrong? Thank you.

@patrickvonplaten
Contributor

Hey @prathikr,

Thanks a lot for the PR. We try to keep our examples as easy to understand as possible, and I don't think we want to mix ORT and PyTorch in the same training example, so that the PyTorch script stays simple and independent of ORT.

Could we maybe add an ORT-compatible example as its own standalone script instead?

Cc @anton-l

@prathikr
Contributor Author

prathikr commented Oct 31, 2022

Hi @patrickvonplaten ,

Absolutely, I just made the changes, but the CI still seems to have issues with tests/test_models_unet.py. I'd also like to highlight my changes to unet_2d.py: is removing the dataclass output safe to do? I've done my best to verify that returning a raw tensor from forward() doesn't affect downstream code once the .sample access is removed, but it might be best to have someone more familiar with the codebase confirm.

Thanks again,
Prathik

@prathikr
Contributor Author

prathikr commented Nov 1, 2022

@JingyaHuang could you please take a look at this PR?

@anton-l anton-l requested a review from JingyaHuang November 1, 2022 01:11
Contributor

@patrickvonplaten patrickvonplaten left a comment

This looks good to me - @anton-l what do you think?

@patrickvonplaten
Contributor

@prathikr, however, it seems like the commit history of this PR is a bit messed up. Could you maybe open a new PR with the intended changes? 😅 Thanks!

@anton-l
Member

anton-l commented Nov 2, 2022

I'll need a second opinion from @JingyaHuang (we've discussed ORT training recently); I don't have enough experience with training myself yet :)

@prathikr
Contributor Author

prathikr commented Nov 2, 2022

@patrickvonplaten I've created a new PR and tagged the three of you: #1110

@prathikr prathikr closed this Nov 2, 2022