Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b3f6343
add flax img2img pipeline
dhruvrnaik Nov 21, 2022
e637e83
update pipeline
dhruvrnaik Nov 21, 2022
30293cc
Merge branch 'main' into flax-img2img
dhruvrnaik Nov 21, 2022
20b6ce3
Merge branch 'main' into flax-img2img
dhruvrnaik Nov 22, 2022
e80efe9
black format file
dhruvrnaik Nov 22, 2022
9603d75
remove argg from get_timesteps
dhruvrnaik Nov 22, 2022
a1b27ef
update get_timesteps
dhruvrnaik Nov 22, 2022
7122386
fix bug: make use of timesteps for for_loop
dhruvrnaik Nov 25, 2022
72fdb95
Merge branch 'main' into flax-img2img
dhruvrnaik Nov 25, 2022
727fa1d
black file
dhruvrnaik Nov 25, 2022
e7d6687
black, isort, flake8
dhruvrnaik Nov 25, 2022
c8787c8
update docstring
dhruvrnaik Nov 26, 2022
94a3e93
update readme
dhruvrnaik Dec 3, 2022
7696bf7
update flax img2img readme
dhruvrnaik Dec 3, 2022
c5a4275
update sd pipeline init
dhruvrnaik Dec 3, 2022
83f2f77
Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_…
dhruvrnaik Dec 5, 2022
bc4abd6
Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_…
dhruvrnaik Dec 6, 2022
ff2be25
update inits
dhruvrnaik Dec 6, 2022
6cec0b8
revert change
dhruvrnaik Dec 6, 2022
e124ed9
update var name to image, typo
dhruvrnaik Dec 11, 2022
c1e6996
update readme
dhruvrnaik Dec 12, 2022
932a74e
return new t_start instead of modified timestep
dhruvrnaik Dec 12, 2022
6602ec3
Merge branch 'main' into flax-img2img
dhruvrnaik Dec 12, 2022
e8dff83
black format
dhruvrnaik Dec 12, 2022
d06f8c3
isort files
dhruvrnaik Dec 12, 2022
b9fa9b7
update docs
dhruvrnaik Dec 12, 2022
fb0af34
fix-copies
dhruvrnaik Dec 19, 2022
cb88c46
update prng_seed typing
dhruvrnaik Dec 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,55 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```

Diffusers also has a Image-to-Image generation pipeline with Flax/Jax
```python
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import requests
from io import BytesIO
from PIL import Image
from diffusers import FlaxStableDiffusionImg2ImgPipeline

def create_key(seed=0):
return jax.random.PRNGKey(seed)
rng = create_key(0)

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))

prompts = "A fantasy landscape, trending on artstation"

pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax",
dtype=jnp.bfloat16,
)

num_samples = jax.device_count()
rng = jax.random.split(rng, jax.device_count())
prompt_ids, processed_image = pipeline.prepare_inputs(prompt=[prompts]*num_samples, image = [init_img]*num_samples)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)

output = pipeline(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
strength=0.75,
num_inference_steps=50,
jit=True,
height=512,
width=768).images

output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```

### Image-to-Image text-guided generation with Stable Diffusion

The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@
FlaxScoreSdeVeScheduler,
)


try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import FlaxStableDiffusionPipeline
from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .stable_diffusion import FlaxStableDiffusionPipeline
from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):

from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
Loading