Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions src/diffusers/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
Output of AutoencoderKL encoding method.

Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
latent_dist (`FlaxDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""

latent_dist: "DiagonalGaussianDistribution"
latent_dist: "FlaxDiagonalGaussianDistribution"


class Upsample2D(nn.Module):
class FlaxUpsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32

Expand All @@ -66,7 +66,7 @@ def __call__(self, hidden_states):
return hidden_states


class Downsample2D(nn.Module):
class FlaxDownsample2D(nn.Module):
in_channels: int
dtype: jnp.dtype = jnp.float32

Expand All @@ -86,7 +86,7 @@ def __call__(self, hidden_states):
return hidden_states


class ResnetBlock2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
Expand Down Expand Up @@ -144,7 +144,7 @@ def __call__(self, hidden_states, deterministic=True):
return hidden_states + residual


class AttentionBlock(nn.Module):
class FlaxAttentionBlock(nn.Module):
channels: int
num_head_channels: int = None
dtype: jnp.dtype = jnp.float32
Expand Down Expand Up @@ -201,7 +201,7 @@ def __call__(self, hidden_states):
return hidden_states


class DownEncoderBlock2D(nn.Module):
class FlaxDownEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
Expand All @@ -214,7 +214,7 @@ def setup(self):
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels

res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
Expand All @@ -224,19 +224,19 @@ def setup(self):
self.resnets = resnets

if self.add_downsample:
self.downsample = Downsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)

if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)

return hidden_states


class UpEncoderBlock2D(nn.Module):
class FlaxUpEncoderBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
Expand All @@ -248,7 +248,7 @@ def setup(self):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
Expand All @@ -259,19 +259,19 @@ def setup(self):
self.resnets = resnets

if self.add_upsample:
self.upsample = Upsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)

if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)

return hidden_states


class UNetMidBlock2D(nn.Module):
class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
Expand All @@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
def setup(self):
# there is always at least one resnet
resnets = [
ResnetBlock2D(
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
Expand All @@ -292,12 +292,12 @@ def setup(self):
attentions = []

for _ in range(self.num_layers):
attn_block = AttentionBlock(
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
)
attentions.append(attn_block)

res_block = ResnetBlock2D(
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
Expand All @@ -317,7 +317,7 @@ def __call__(self, hidden_states, deterministic=True):
return hidden_states


class Encoder(nn.Module):
class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
Expand Down Expand Up @@ -347,7 +347,7 @@ def setup(self):
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1

down_block = DownEncoderBlock2D(
down_block = FlaxDownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
Expand All @@ -358,7 +358,7 @@ def setup(self):
self.down_blocks = down_blocks

# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)

Expand Down Expand Up @@ -392,7 +392,7 @@ def __call__(self, sample, deterministic: bool = True):
return sample


class Decoder(nn.Module):
class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
in_channels: int = 3
out_channels: int = 3
Expand All @@ -415,7 +415,7 @@ def setup(self):
)

# middle
self.mid_block = UNetMidBlock2D(
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)

Expand All @@ -429,7 +429,7 @@ def setup(self):

is_final_block = i == len(block_out_channels) - 1

up_block = UpEncoderBlock2D(
up_block = FlaxUpEncoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
Expand Down Expand Up @@ -469,7 +469,7 @@ def __call__(self, sample, deterministic: bool = True):
return sample


class DiagonalGaussianDistribution(object):
class FlaxDiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
# Last axis to account for channels-last
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
Expand Down Expand Up @@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32

def setup(self):
self.encoder = Encoder(
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
down_block_types=self.config.down_block_types,
Expand All @@ -532,7 +532,7 @@ def setup(self):
double_z=True,
dtype=self.dtype,
)
self.decoder = Decoder(
self.decoder = FlaxDecoder(
in_channels=self.config.latent_channels,
out_channels=self.config.out_channels,
up_block_types=self.config.up_block_types,
Expand Down Expand Up @@ -572,7 +572,7 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):

hidden_states = self.encoder(sample, deterministic=deterministic)
moments = self.quant_conv(hidden_states)
posterior = DiagonalGaussianDistribution(moments)
posterior = FlaxDiagonalGaussianDistribution(moments)

if not return_dict:
return (posterior,)
Expand Down