README.md: 8 additions & 8 deletions

````diff
@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```

 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```

 If you are limited by GPU memory, you might want to consider using the model in `fp16`.
@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```

 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]

 image.save("astronaut_rides_horse.png")
 ```
@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"

 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images

 images[0].save("fantasy_landscape.png")
 ```
@@ -228,7 +228,7 @@ pipe = pipe.to(device)

 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images

 images[0].save("cat_on_bench.png")
 ```
@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)

 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
+images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images

 # save images
 for idx, image in enumerate(images):
@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
 ddpm = DDPMPipeline.from_pretrained(model_id)  # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference

 # run pipeline in inference (sample random noise and denoise)
-image = ddpm()["sample"]
+image = ddpm().images

 # save image
 image[0].save("ddpm_generated_image.png")
````
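Every pipeline call in the README now returns an output object with named fields instead of a raw dict, so `["sample"][0]` becomes `.images[0]`. A minimal sketch of the new access pattern; the checkpoint id and auth flag are assumptions, since the diff elides the `from_pretrained` line:

```python
from torch import autocast
from diffusers import StableDiffusionPipeline

# checkpoint id and use_auth_token assumed for illustration; the diff does not show this line
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    output = pipe(prompt)

image = output.images[0]  # named-field access replaces output["sample"][0]
image.save("astronaut_rides_horse.png")
```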
examples/textual_inversion/README.md: 1 addition & 1 deletion

````diff
@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
 prompt = "A <cat-toy> backpack"

 with autocast("cuda"):
-    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

 image.save("cat-backpack.png")
 ```
````
examples/textual_inversion/textual_inversion.py: 2 additions & 2 deletions

```diff
@@ -498,7 +498,7 @@ def main():
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).sample().detach()
+                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
@@ -515,7 +515,7 @@ def main():
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
```
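`AutoencoderKL.encode` now returns an `AutoencoderKLOutput` whose `latent_dist` field is the `DiagonalGaussianDistribution` that used to be returned directly (see the `vae.py` changes below). A minimal sketch of the updated latent-encoding step with the script's 0.18215 scaling; the checkpoint id and `subfolder` argument are assumptions for illustration:

```python
import torch
from diffusers import AutoencoderKL

# checkpoint id / subfolder assumed for illustration
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

pixel_values = torch.randn(1, 3, 512, 512)  # dummy batch standing in for batch["pixel_values"]
latents = vae.encode(pixel_values).latent_dist.sample().detach()  # was: .sample() on the returned distribution
latents = latents * 0.18215  # same latent scaling as the training script
```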
(file name not shown in the capture)

```diff
@@ -139,7 +139,7 @@ def transforms(examples):

             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps)["sample"]
+                noise_pred = model(noisy_images, timesteps).sample
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)

@@ -174,7 +174,7 @@ def transforms(examples):

             generator = torch.manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
-            images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+            images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images

             # denormalize the images and save to tensorboard
             images_processed = (images * 255).round().astype("uint8")
```
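The same two patterns in miniature: the model forward's `sample` field during training, and the pipeline's `images` field at eval time. A runnable sketch of the training half, assuming a deliberately tiny `UNet2DModel` config (the config values are illustrative, not from this PR):

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DModel

# tiny illustrative config, not the training script's real one
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=1000)

clean_images = torch.randn(2, 3, 32, 32)
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, 1000, (2,))
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

noise_pred = model(noisy_images, timesteps).sample  # was: model(...)["sample"]
loss = F.mse_loss(noise_pred, noise)
```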
scripts/generate_logits.py: 1 addition & 1 deletion

```diff
@@ -119,7 +119,7 @@
     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     time_step = torch.tensor([10] * noise.shape[0])
     with torch.no_grad():
-        logits = model(noise, time_step)["sample"]
+        logits = model(noise, time_step).sample

     assert torch.allclose(
         logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
```
src/diffusers/hub_utils.py: 1 addition & 1 deletion

```diff
@@ -19,9 +19,9 @@
 from pathlib import Path
 from typing import Optional

-from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami

+from .pipeline_utils import DiffusionPipeline
 from .utils import is_modelcards_available, logging
```
src/diffusers/models/unet_2d.py: 22 additions & 5 deletions

```diff
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class UNet2DOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -118,8 +131,11 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

     def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[UNet2DOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -181,6 +197,7 @@ def forward(
         timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
         sample = sample / timesteps

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DOutput(sample=sample)
```
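The `return_dict` flag preserves a positional escape hatch during migration: with `return_dict=False` the forward returns a plain tuple. A minimal sketch of both access styles, assuming a tiny illustrative config:

```python
import torch
from diffusers import UNet2DModel

# tiny illustrative config; not taken from this PR
model = UNet2DModel(
    sample_size=16,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(16, 32),
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)

x = torch.randn(1, 3, 16, 16)
t = torch.tensor([10])

out = model(x, t)    # UNet2DOutput
sample = out.sample  # named-field access replaces out["sample"]

(sample_tuple,) = model(x, t, return_dict=False)  # tuple access for positional code paths
assert torch.equal(sample, sample_tuple)          # same forward pass result, two access styles
```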
src/diffusers/models/unet_2d_condition.py: 19 additions & 4 deletions

```diff
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -125,7 +138,8 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-    ) -> Dict[str, torch.FloatTensor]:
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -183,6 +197,7 @@ def forward(
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DConditionOutput(sample=sample)
```
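The same pattern applies to the cross-attention variant, whose output is additionally conditioned on `encoder_hidden_states`. A sketch with an assumed tiny config; in practice the text encoder supplies the hidden states, and the constructor arguments here are assumptions for illustration:

```python
import torch
from diffusers import UNet2DConditionModel

# tiny illustrative config; argument names assumed from the model's config
model = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

latents = torch.randn(1, 4, 16, 16)
t = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)  # stands in for text-encoder output

noise_pred = model(latents, t, encoder_hidden_states).sample  # was: model(...)["sample"]
```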
src/diffusers/models/vae.py: 87 additions & 17 deletions

```diff
@@ -1,14 +1,56 @@
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class DecoderOutput(BaseOutput):
+    """
+    Output of decoding method.
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Decoded output sample of the model. Output of the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"
+
+
 class Encoder(nn.Module):
     def __init__(
         self,
@@ -369,26 +411,40 @@ def __init__(
         act_fn=act_fn,
     )

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         h = self.quant_conv(h)
-        return h

-    def decode(self, h, force_not_quantize=False):
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         # also go through quantization layer
         if not force_not_quantize:
             quant, emb_loss, info = self.quantize(h)
         else:
             quant = h
         quant = self.post_quant_conv(quant)
         dec = self.decoder(quant)
-        return dec

-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        h = self.encode(x)
-        dec = self.decode(h)
-        return dec
+        h = self.encode(x).latents
+        dec = self.decode(h).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)


 class AutoencoderKL(ModelMixin, ConfigMixin):
@@ -431,23 +487,37 @@ def __init__(
         self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
-        return posterior

-    def decode(self, z):
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
-        return dec

-    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        posterior = self.encode(x)
+        posterior = self.encode(x).latent_dist
         if sample_posterior:
             z = posterior.sample()
         else:
             z = posterior.mode()
-        dec = self.decode(z)
-        return dec
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
```
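Downstream code now unwraps each stage explicitly: `encode(...).latent_dist`, then `decode(...).sample`. A minimal round-trip sketch against `AutoencoderKL`; the checkpoint id and `subfolder` argument are assumptions for illustration:

```python
import torch
from diffusers import AutoencoderKL

# checkpoint id / subfolder assumed for illustration
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    posterior = vae.encode(x).latent_dist  # AutoencoderKLOutput -> DiagonalGaussianDistribution
    z = posterior.sample()                 # or posterior.mode() for the distribution mean
    reconstruction = vae.decode(z).sample  # DecoderOutput -> tensor

    # tuple-style access remains available for positional code paths
    (reconstruction_tuple,) = vae.decode(z, return_dict=False)
```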