diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 793010e87f67..eba9259b8201 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput): Output of AutoencoderKL encoding method. Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. - `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + latent_dist (`FlaxDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. + `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. """ - latent_dist: "DiagonalGaussianDistribution" + latent_dist: "FlaxDiagonalGaussianDistribution" -class Upsample2D(nn.Module): +class FlaxUpsample2D(nn.Module): in_channels: int dtype: jnp.dtype = jnp.float32 @@ -66,7 +66,7 @@ def __call__(self, hidden_states): return hidden_states -class Downsample2D(nn.Module): +class FlaxDownsample2D(nn.Module): in_channels: int dtype: jnp.dtype = jnp.float32 @@ -86,7 +86,7 @@ def __call__(self, hidden_states): return hidden_states -class ResnetBlock2D(nn.Module): +class FlaxResnetBlock2D(nn.Module): in_channels: int out_channels: int = None dropout_prob: float = 0.0 @@ -144,7 +144,7 @@ def __call__(self, hidden_states, deterministic=True): return hidden_states + residual -class AttentionBlock(nn.Module): +class FlaxAttentionBlock(nn.Module): channels: int num_head_channels: int = None dtype: jnp.dtype = jnp.float32 @@ -201,7 +201,7 @@ def __call__(self, hidden_states): return hidden_states -class DownEncoderBlock2D(nn.Module): +class FlaxDownEncoderBlock2D(nn.Module): in_channels: int out_channels: int dropout: float = 0.0 @@ -214,7 +214,7 @@ def setup(self): for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels - res_block = ResnetBlock2D( + res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout_prob=self.dropout, @@ -224,19 +224,19 @@ def setup(self): self.resnets = resnets if self.add_downsample: - self.downsample = Downsample2D(self.out_channels, dtype=self.dtype) + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_downsample: - hidden_states = self.downsample(hidden_states) + hidden_states = self.downsamplers_0(hidden_states) return hidden_states -class UpEncoderBlock2D(nn.Module): +class FlaxUpEncoderBlock2D(nn.Module): in_channels: int out_channels: int dropout: float = 0.0 @@ -248,7 +248,7 @@ def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels - res_block = ResnetBlock2D( + res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout_prob=self.dropout, @@ -259,19 +259,19 @@ def setup(self): self.resnets = resnets if self.add_upsample: - self.upsample = Upsample2D(self.out_channels, dtype=self.dtype) + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_upsample: - hidden_states = self.upsample(hidden_states) + hidden_states = self.upsamplers_0(hidden_states) return hidden_states -class UNetMidBlock2D(nn.Module): +class FlaxUNetMidBlock2D(nn.Module): in_channels: int dropout: float = 0.0 num_layers: int = 1 @@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module): def setup(self): # there is always at least one resnet resnets = [ - ResnetBlock2D( + FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout_prob=self.dropout, @@ -292,12 +292,12 @@ def setup(self): attentions = [] for _ in range(self.num_layers): - attn_block = AttentionBlock( + attn_block = FlaxAttentionBlock( channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype ) attentions.append(attn_block) - res_block = ResnetBlock2D( + res_block = FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout_prob=self.dropout, @@ -317,7 +317,7 @@ def __call__(self, hidden_states, deterministic=True): return hidden_states -class Encoder(nn.Module): +class FlaxEncoder(nn.Module): in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) @@ -347,7 +347,7 @@ def setup(self): output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - down_block = DownEncoderBlock2D( + down_block = FlaxDownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers=self.layers_per_block, @@ -358,7 +358,7 @@ def setup(self): self.down_blocks = down_blocks # middle - self.mid_block = UNetMidBlock2D( + self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype ) @@ -392,7 +392,7 @@ def __call__(self, sample, deterministic: bool = True): return sample -class Decoder(nn.Module): +class FlaxDecoder(nn.Module): dtype: jnp.dtype = jnp.float32 in_channels: int = 3 out_channels: int = 3 @@ -415,7 +415,7 @@ def setup(self): ) # middle - self.mid_block = UNetMidBlock2D( + self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype ) @@ -429,7 +429,7 @@ def setup(self): is_final_block = i == len(block_out_channels) - 1 - up_block = UpEncoderBlock2D( + up_block = FlaxUpEncoderBlock2D( in_channels=prev_output_channel, out_channels=output_channel, num_layers=self.layers_per_block + 1, @@ -469,7 +469,7 @@ def __call__(self, sample, deterministic: bool = True): return sample -class DiagonalGaussianDistribution(object): +class FlaxDiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): # Last axis to account for channels-last self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) @@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 def setup(self): - self.encoder = Encoder( + self.encoder = FlaxEncoder( in_channels=self.config.in_channels, out_channels=self.config.latent_channels, down_block_types=self.config.down_block_types, @@ -532,7 +532,7 @@ def setup(self): double_z=True, dtype=self.dtype, ) - self.decoder = Decoder( + self.decoder = FlaxDecoder( in_channels=self.config.latent_channels, out_channels=self.config.out_channels, up_block_types=self.config.up_block_types, @@ -572,7 +572,7 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True): hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) - posterior = DiagonalGaussianDistribution(moments) + posterior = FlaxDiagonalGaussianDistribution(moments) if not return_dict: return (posterior,)