JAX Integration #475

@patrickvonplaten

This issue will be used to track the native integration of Stable Diffusion in JAX into diffusers. This will enable many cool use cases, notably running Stable Diffusion on Google Colab.

General design:

We will loosen the hard PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes so that they are JAX compatible:

ModelMixin: see patil-suraj/stable-diffusion-jax#10. We should add a FlaxModelMixin class here.
DiffusionPipeline:

class DiffusionPipeline(ConfigMixin):

We should add a FlaxDiffusionPipeline here.
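For the dependency change above (require either PyTorch or JAX instead of PyTorch unconditionally), here is a minimal sketch of how the library could detect which backend is installed without importing it. The helper names `is_available` and `available_backends` are hypothetical, not existing diffusers functions; the only real API used is `importlib.util.find_spec` from the standard library.

```python
import importlib.util


def is_available(package: str) -> bool:
    """Return True if a top-level `package` is importable, without importing it."""
    return importlib.util.find_spec(package) is not None


def available_backends() -> list:
    """Hypothetical check: list the installed backends, requiring at least one."""
    backends = [name for name in ("torch", "jax") if is_available(name)]
    if not backends:
        raise ImportError(
            "diffusers needs either PyTorch (`pip install torch`) "
            "or JAX (`pip install jax`) to be installed."
        )
    return backends
```

Using `find_spec` keeps the check cheap: neither framework is imported just to find out whether it exists, so a JAX-only install never touches PyTorch.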

Note: ModelMixin should be stateless by default, e.g. the weights will not be stored on the model itself. Also, unlike transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the separate UNetConditionModel / UNetConditionModule design; we could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?

TODO:

Happy to take over 1. and finish it today, then look into 4. once 3. is done.

@mishig25, do you want to do 2.? (Happy to guide you a bit if you have questions. We should probably also discuss the design offline.)

3. & 5.: @pcuenca, do you want to take these? (I think 3. is more important here.)

We can look at the other parts tomorrow, maybe :-)
