JAX Integration
This issue will be used as a tracker for integrating Stable Diffusion in JAX natively into diffusers. This will enable many cool use cases, notably running Stable Diffusion on a Google Colab.
General design:
We will loosen the forced PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes to be JAX compatible:
ModelMixin: patil-suraj/stable-diffusion-jax#10 we should add a FlaxModelMixin class here.
FlaxDiffusionPipeline: diffusers/src/diffusers/pipeline_utils.py, line 76 in 25a51b6 (class DiffusionPipeline(ConfigMixin):) should be mirrored as a FlaxDiffusionPipeline here.
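The "either PyTorch or JAX" requirement could be checked at import time roughly as follows. This is only a sketch: `is_available` and `require_backend` are hypothetical helper names made up for illustration, not existing diffusers functions.

```python
import importlib.util


def is_available(package):
    """Return True if `package` can be imported, without importing it."""
    return importlib.util.find_spec(package) is not None


def require_backend(candidates=("torch", "jax")):
    """Return the installed backends, or raise if none is present.

    Hypothetical helper: the real check may live elsewhere and look different.
    """
    found = [name for name in candidates if is_available(name)]
    if not found:
        raise ImportError(
            "At least one of the following must be installed: "
            + ", ".join(candidates)
        )
    return found
```

With something like this, PyTorch-only and Flax-only classes can each be guarded by their own availability check instead of a hard `torch` import at the top of the package.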
Note: ModelMixin should be made stateless by default, e.g. weights will not be saved. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the UNetConditionModel and UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?
TODO:
- 1. Make diffusers framework independent. This will require some general changes to setup.py and our automation tools
- 2. Add FlaxModelMixin: Implement FlaxModelMixin #493 Here we can take a lot from https://github.com/patil-suraj/stable-diffusion-jax/pull/10/files but I'm not sure we should follow the transformers design here 1-to-1. Will also ask some google-folks here
- 3. Add all the modeling code under unet_2d_condition_flax.py ...
- 4. Add automatic conversion PT <=> Flax script. #478
- 5. Add PNDM scheduler under scheduling_pndm_flax.py
- 6. Tests
- 7. Create pipeline and also FlaxDiffusionPipeline
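For item 4, the core of a PT => Flax conversion is renaming keys and transposing weights, since PyTorch stores linear weights as (out, in) and conv weights as (out, in, H, W), while Flax expects kernels of shape (in, out) and (H, W, in, out). A rough sketch, assuming the state dict has already been turned into NumPy arrays; `pt_to_flax` and the key names are hypothetical, and embedding layers are ignored:

```python
import numpy as np


def pt_to_flax(pt_state):
    """Convert a flat PyTorch-style state dict into a nested Flax-style dict.

    Illustrative only: handles linear, conv and norm weights, not embeddings.
    """
    flax_params = {}
    for key, value in pt_state.items():
        *path, leaf = key.split(".")
        if leaf == "weight" and value.ndim == 2:
            # Linear: (out, in) -> (in, out)
            leaf, value = "kernel", value.T
        elif leaf == "weight" and value.ndim == 4:
            # Conv2d: (out, in, H, W) -> (H, W, in, out)
            leaf, value = "kernel", value.transpose(2, 3, 1, 0)
        elif leaf == "weight" and value.ndim == 1:
            # Norm layers: Flax names the learned gain "scale"
            leaf = "scale"
        node = flax_params
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return flax_params


pt_state = {
    "proj.weight": np.zeros((5, 3)),
    "proj.bias": np.zeros(5),
    "conv.weight": np.zeros((8, 4, 3, 3)),
    "norm.weight": np.ones(8),
}
converted = pt_to_flax(pt_state)
```

The reverse direction (Flax => PT) would apply the inverse transposes and flatten the nested dict back into dotted keys.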
Happy to take over 1. and finish today and then look into 4. once 3. is done.
@mishig25 do you want to do 2.? (happy to guide you here a bit if you have questions. Also we need to discuss the design here a bit offline maybe)
3. & 5. @pcuenca do you want to take these? (I think 3. is more important here)
The other parts we can see tomorrow maybe :-)