diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 525548e7c302..c92fdccb8333 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] AutoencoderKL + +## FlaxModelMixin +[[autodoc]] FlaxModelMixin + +## FlaxUNet2DConditionOutput +[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput + +## FlaxUNet2DConditionModel +[[autodoc]] FlaxUNet2DConditionModel + +## FlaxDecoderOutput +[[autodoc]] models.vae_flax.FlaxDecoderOutput + +## FlaxAutoencoderKLOutput +[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput + +## FlaxAutoencoderKL +[[autodoc]] FlaxAutoencoderKL diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 461fb8b0ac33..1745265b91e1 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -17,6 +17,22 @@ class FlaxAttentionBlock(nn.Module): + r""" + A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 + + Parameters: + query_dim (:obj:`int`): + Input hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): + Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ query_dim: int heads: int = 8 dim_head: int = 64 @@ -74,6 +90,23 @@ def __call__(self, hidden_states, context=None, deterministic=True): class FlaxBasicTransformerBlock(nn.Module): + r""" + A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + https://arxiv.org/abs/1706.03762 + + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ dim: int n_heads: int d_head: int @@ -110,6 +143,25 @@ def __call__(self, hidden_states, context, deterministic=True): class FlaxSpatialTransformer(nn.Module): + r""" + A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: + https://arxiv.org/pdf/1506.02025.pdf + + + Parameters: + in_channels (:obj:`int`): + Input number of channels + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): + Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int n_heads: int d_head: int @@ -162,6 +214,18 @@ def __call__(self, hidden_states, context, deterministic=True): class FlaxGluFeedForward(nn.Module): + r""" + Flax module that encapsulates two Linear layers separated by a gated linear unit activation from: + https://arxiv.org/abs/2002.05202 + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 @@ -179,6 +243,18 @@ def __call__(self, hidden_states, deterministic=True): class FlaxGEGLU(nn.Module): + r""" + Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from + https://arxiv.org/abs/2002.05202. + + Parameters: + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 50ccc2386c92..e2d607499c79 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): class FlaxTimestepEmbedding(nn.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @@ -49,6 +58,13 @@ def __call__(self, temb): class FlaxTimesteps(nn.Module): + r""" + Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 + + Args: + dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + """ dim: int = 32 freq_shift: float = 1 diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index cab229f21803..9e84da068757 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) + Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" @@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. - dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): + The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. """ sample_size: int = 32 diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index fe6cc3b194e3..c23b7844c3fd 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -19,6 +19,26 @@ class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -73,6 +93,23 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -113,6 +150,26 @@ def __call__(self, hidden_states, temb, deterministic=True): class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -170,6 +227,25 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -214,6 +290,21 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int dropout: float = 0.0 num_layers: int = 1 diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 1471e715bc7c..b3261b11cf6c 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -23,6 +23,8 @@ class FlaxDecoderOutput(BaseOutput): Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): Decoded output sample of the model. Output of the last layer of the model. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ sample: jnp.ndarray @@ -43,6 +45,16 @@ class FlaxAutoencoderKLOutput(BaseOutput): class FlaxUpsample2D(nn.Module): + """ + Flax implementation of 2D Upsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int dtype: jnp.dtype = jnp.float32 @@ -67,6 +79,16 @@ def __call__(self, hidden_states): class FlaxDownsample2D(nn.Module): + """ + Flax implementation of 2D Downsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int dtype: jnp.dtype = jnp.float32 @@ -87,6 +109,22 @@ def __call__(self, hidden_states): class FlaxResnetBlock2D(nn.Module): + """ + Flax implementation of 2D Resnet Block. + + Args: + in_channels (`int`): + Input channels + out_channels (`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): + Whether to use `nin_shortcut`. This activates a new layer inside ResNet block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int out_channels: int = None dropout: float = 0.0 @@ -145,6 +183,18 @@ def __call__(self, hidden_states, deterministic=True): class FlaxAttentionBlock(nn.Module): + r""" + Flax Convolutional based multi-head attention block for diffusion-based VAE. + + Parameters: + channels (:obj:`int`): + Input channels + num_head_channels (:obj:`int`, *optional*, defaults to `None`): + Number of attention heads + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ channels: int num_head_channels: int = None dtype: jnp.dtype = jnp.float32 @@ -202,6 +252,23 @@ def __call__(self, hidden_states): class FlaxDownEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -237,6 +304,23 @@ def __call__(self, hidden_states, deterministic=True): class FlaxUpEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -272,6 +356,21 @@ def __call__(self, hidden_states, deterministic=True): class FlaxUNetMidBlock2D(nn.Module): + r""" + Flax Unet Mid-Block module. + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): + Number of attention heads for each attention block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int dropout: float = 0.0 num_layers: int = 1 @@ -318,6 +417,39 @@ def __call__(self, hidden_states, deterministic=True): class FlaxEncoder(nn.Module): + r""" + Flax Implementation of VAE Encoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) @@ -393,7 +525,39 @@ def __call__(self, sample, deterministic: bool = True): class FlaxDecoder(nn.Module): - dtype: jnp.dtype = jnp.float32 + r""" + Flax Implementation of VAE Decoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 up_block_types: Tuple[str] = ("UpDecoderBlock2D",) @@ -401,6 +565,7 @@ class FlaxDecoder(nn.Module): layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" + dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels @@ -508,6 +673,44 @@ def mode(self): @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational + Bayes by Diederik P. Kingma and Max Welling. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + latent_channels (:obj:`int`, *optional*, defaults to `4`): + Latent space channels + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + Norm num group + sample_size (:obj:`int`, *optional*, defaults to `32`): + Sample input size + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",)