From d672b841e03f80ebb1e3b0f6dd87eca646dc44c1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Apr 2025 03:20:46 +0000 Subject: [PATCH 01/39] add jit-friendly dropout w rate in call --- algoperf/jax_utils.py | 121 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 algoperf/jax_utils.py diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py new file mode 100644 index 000000000..ddafd77c6 --- /dev/null +++ b/algoperf/jax_utils.py @@ -0,0 +1,121 @@ +from collections.abc import Sequence + +import jax +import jax.numpy as jnp +from jax import lax, random + +import flax.linen as nn +from flax.linen.module import Module, compact, merge_param +from flax.typing import PRNGKey + + +# Custom Layers +class Dropout(Module): + """Create a dropout layer. + Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes to dropout rate during training by: + 1) adding rate argument to the __call__ method + 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code + + .. note:: + When using :meth:`Module.apply() `, make sure + to include an RNG seed named ``'dropout'``. Dropout isn't necessary for + variable initialization. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class MLP(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(4)(x) + ... x = nn.Dropout(0.5, deterministic=not train)(x) + ... return x + + >>> model = MLP() + >>> x = jnp.ones((1, 3)) + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + >>> model.apply(variables, x, train=False) # don't use dropout + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and + masked, whereas if true, no mask is applied and the inputs are returned as + is. + rng_collection: the rng collection name to use when requesting an rng key. + """ + + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. + + Args: + inputs: the inputs that should be randomly masked. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and + masked, whereas if true, no mask is applied and the inputs are returned + as is. + rate: the dropout probability. (_not_ the keep rate!) + rng: an optional PRNGKey used as the random key, if not specified, one + will be generated using ``make_rng`` with the ``rng_collection`` name. + + Returns: + The masked inputs reweighted to preserve mean. 
+ """ + deterministic = merge_param("deterministic", self.deterministic, deterministic) + + # Override self.rate if rate is passed to __call__ + if not (self.rate is not None and rate is not None): + rate = merge_param("rate", self.rate, rate) + + if self.legacy: + if rate == 0.0: + return inputs + + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) + + if deterministic: + return inputs + + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + + return lax.select( + mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs) + ) + + +# Utilities for debugging +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, + ) + print(tabulate_fn(fake_inputs, train=False)) From aa25e208ef7f283d2ee3d4ebc304e45d6c24fa8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 08:28:01 +0000 Subject: [PATCH 02/39] remove nan_to_num convertion --- algoperf/jax_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index ddafd77c6..3ca3f1bfc 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -104,10 +104,7 @@ def __call__( broadcast_shape[dim] = 1 mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) - - return lax.select( - mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs) - ) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging From 85a35785a523ff44a1fca93e41d3cb082d72d4c2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 09:35:39 +0000 Subject: [PATCH 03/39] update models with custom dropout layer --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 6 +++--- algoperf/workloads/fastmri/fastmri_jax/models.py | 5 +++-- .../imagenet_vit/imagenet_jax/models.py | 13 +++++++------ .../librispeech_jax/models.py | 11 ++++++----- .../librispeech_jax/models.py | 5 +++-- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +++- algoperf/workloads/wmt/wmt_jax/models.py | 16 +++++++++------- 7 files changed, 34 insertions(+), 26 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 6d9a489ff..e89db0c86 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -1,11 +1,11 @@ """A JAX implementation of DLRM-Small.""" - from typing import Sequence import flax.linen as nn from jax import nn as jnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -89,7 +89,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) x = nn.relu(x) if self.dropout_rate and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + x = Dropout(rate=self.dropout_rate, deterministic=not train)(x) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. 
@@ -212,7 +212,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.LayerNorm()(top_mlp_input) if (self.dropout_rate is not None and self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( + top_mlp_input = Dropout( rate=self.dropout_rate, deterministic=not train)( top_mlp_input) logits = top_mlp_input diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 44bff0e21..f29e0be22 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp +from algoperf.jax_utils import Dropout def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -172,7 +173,7 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = nn.Dropout( + x = Dropout( self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x) x = nn.Conv( @@ -186,7 +187,7 @@ def __call__(self, x, train=True): else: x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) - x = nn.Dropout( + x = Dropout( self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7ce3a0395..902658fbe 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -11,6 +11,7 @@ import jax.numpy as jnp from algoperf import spec +from algoperf.jax_utils import Dropout def posemb_sincos_2d(h: int, @@ -53,7 +54,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = nn.Dropout(rate=self.dropout_rate)(x, train) + x = Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -76,7 +77,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) @@ -85,7 +86,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: use_glu=self.use_glu, dropout_rate=self.dropout_rate, name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y else: y = x @@ -95,7 +96,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -105,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: use_glu=self.use_glu, dropout_rate=self.dropout_rate, name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -205,7 +206,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: dropout_rate = self.dropout_rate if dropout_rate is None: dropout_rate = 0.0 - x = nn.Dropout(rate=dropout_rate)(x, not train) + x = 
Dropout(rate=dropout_rate)(x, not train) x = Encoder( depth=self.depth, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..51e93acc9 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -26,6 +26,7 @@ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter +from algoperf.jax_utils import Dropout @struct.dataclass @@ -129,7 +130,7 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( seq_length=outputs.shape[1]) - outputs = nn.Dropout( + outputs = Dropout( rate=self.input_dropout_rate, deterministic=not train)( outputs) @@ -217,7 +218,7 @@ def __call__(self, inputs, padding_mask=None, train=False): 'config.activation_function_name values, recieved ' f'{config.activation_function_name}') inputs = activation_fn(inputs) - inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( + inputs = Dropout(rate=config.feed_forward_dropout_rate)( inputs, deterministic=not train) inputs = inputs * padding_mask @@ -234,7 +235,7 @@ def __call__(self, inputs, padding_mask=None, train=False): else: feed_forward_residual_dropout_rate = ( config.feed_forward_residual_dropout_rate) - inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( + inputs = Dropout(rate=feed_forward_residual_dropout_rate)( inputs, deterministic=not train) return inputs @@ -416,7 +417,7 @@ def __call__(self, inputs, paddings, train): attention_residual_dropout_rate = 0.1 else: attention_residual_dropout_rate = config.attention_residual_dropout_rate - result = nn.Dropout( + result = Dropout( rate=attention_residual_dropout_rate, deterministic=not train)( result) @@ -578,7 +579,7 @@ def __call__(self, conv_residual_dropout_rate = 0.0 else: conv_residual_dropout_rate = config.conv_residual_dropout_rate - inputs = nn.Dropout( + inputs = Dropout( rate=conv_residual_dropout_rate, deterministic=not train)( inputs) return inputs diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..f937a1692 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -20,6 +20,7 @@ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter +from algoperf.jax_utils import Dropout Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -110,7 +111,7 @@ def __call__(self, inputs, output_paddings, train): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = nn.Dropout( + outputs = Dropout( rate=input_dropout_rate, deterministic=not train)( outputs) @@ -216,7 +217,7 @@ def __call__(self, inputs, input_paddings=None, train=False): feed_forward_dropout_rate = 0.1 else: feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = nn.Dropout(rate=feed_forward_dropout_rate)( + inputs = Dropout(rate=feed_forward_dropout_rate)( inputs, deterministic=not train) return inputs diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..f5710a3ab 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ 
b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -6,6 +6,8 @@ import jax.numpy as jnp import jraph +from algoperf.jax_utils import Dropout + def _make_embed(latent_dim, name): @@ -50,7 +52,7 @@ def __call__(self, graph, train): dropout_rate = 0.1 else: dropout_rate = self.dropout_rate - dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) + dropout = Dropout(rate=dropout_rate, deterministic=not train) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 97fee032f..04f46e8ac 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -11,6 +11,8 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout + @struct.dataclass class TransformerConfig: @@ -172,14 +174,14 @@ def __call__(self, inputs): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, bias_init=cfg.bias_init)( x) - output = nn.Dropout(rate=dropout_rate)( + output = Dropout(rate=dropout_rate)( output, deterministic=cfg.deterministic) return output @@ -229,7 +231,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -293,7 +295,7 @@ def __call__(self, dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -312,7 +314,7 @@ def __call__(self, deterministic=cfg.deterministic)( cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -366,7 +368,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x.astype(cfg.dtype) @@ -436,7 +438,7 @@ def __call__(self, dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y.astype(cfg.dtype) From 9354079c1a97b51fe722069b29608953f58aa107 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 10:36:48 +0000 Subject: [PATCH 04/39] add functional dropout for criteo, fastmri, and vit --- .../criteo1tb/criteo1tb_jax/models.py | 21 +++++---- .../workloads/fastmri/fastmri_jax/models.py | 20 ++++----- .../imagenet_vit/imagenet_jax/models.py | 45 ++++++++++--------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index e89db0c86..c56748bb1 100644 --- 
a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -28,7 +28,10 @@ class DLRMResNet(nn.Module): embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=None): + if not dropout_rate: + dropout_rate=self.dropout_rate + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -88,8 +91,8 @@ def scaled_init(key, shape, dtype=jnp.float_): stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( top_mlp_input) x = nn.relu(x) - if self.dropout_rate and layer_idx == num_layers_top - 2: - x = Dropout(rate=self.dropout_rate, deterministic=not train)(x) + if dropout_rate and layer_idx == num_layers_top - 2: + x = Dropout(deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. @@ -151,7 +154,10 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=None): + if not dropout_rate: + dropout_rate = self.dropout_rate + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -210,10 +216,9 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.relu(top_mlp_input) if self.use_layer_norm: top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) + top_mlp_input = Dropout(deterministic=not train)( + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index f29e0be22..3d5460c18 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -62,10 +62,9 @@ class UNet(nn.Module): use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True): - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 + def __call__(self, x, train=True, dropout_rate=None): + if not dropout_rate: + dropout_rate = self.dropout_rate # pylint: disable=invalid-name _ConvBlock = functools.partial( @@ -144,7 +143,7 @@ class ConvBlock(nn.Module): use_layer_norm: bool @nn.compact - def __call__(self, x, train=True): + def __call__(self, x, train=True, dropout_rate=None): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -153,6 +152,8 @@ def __call__(self, x, train=True): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. 
""" + if not dropout_rate: + dropout_rate=self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -173,9 +174,8 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout(broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate ) x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -188,8 +188,8 @@ def __call__(self, x, train=True): x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) x = Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 902658fbe..10a90f37d 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -39,8 +39,11 @@ class MlpBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: """Applies Transformer MlpBlock module.""" + if not dropout_rate: + dropout_rate = self.dropout_rate + inits = { 'kernel_init': nn.initializers.xavier_uniform(), 'bias_init': nn.initializers.normal(stddev=1e-6), @@ -54,7 +57,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = Dropout(rate=self.dropout_rate)(x, train) + x = Dropout()(x, train, rate=dropout_rate) x = nn.Dense(d, **inits)(x) return x @@ -68,7 +71,10 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: + if not dropout_rate: + dropout_rate=self.dropout_rate + if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( @@ -77,16 +83,15 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(rate=self.dropout_rate)(y, train) + y = Dropout()(y, train, dropout_rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = Dropout(rate=self.dropout_rate)(y, train) + name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) + y = Dropout()(y, train, rate=dropout_rate) x = x + y else: y = x @@ -96,7 +101,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(rate=self.dropout_rate)(y, train) + y = Dropout()(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -104,9 +109,8 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = Dropout(rate=self.dropout_rate)(y, train) + name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) + y = Dropout()(y, 
train)(rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -123,7 +127,10 @@ class Encoder(nn.Module): use_post_layer_norm: bool = False @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: + if not dropout_rate: + dropout_rate=self.dropout_rate + # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( @@ -132,7 +139,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=self.dropout_rate) + )(dropout_rate=dropout_rate) x = block(x, train) if not self.use_post_layer_norm: return nn.LayerNorm(name='encoder_layernorm')(x) @@ -187,7 +194,9 @@ def get_posemb(self, return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: + def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor: + if not dropout_rate: + dropout_rate = self.dropout_rate # Patch extraction x = nn.Conv( self.width, @@ -203,10 +212,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: # Add posemb before adding extra token. x = x + self.get_posemb((h, w), c, x.dtype) - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - x = Dropout(rate=dropout_rate)(x, not train) + x = Dropout()(x, not train, rate=dropout_rate) x = Encoder( depth=self.depth, @@ -214,9 +220,8 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate, name='Transformer')( - x, train=not train) + x, train=not train, dropout_rate=dropout_rate) if self.use_map: x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) From feb9cc5cfd57575e11c91039c0292d26014e99fe Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 10:48:34 +0000 Subject: [PATCH 05/39] add functional dropout for ogbg --- algoperf/workloads/ogbg/ogbg_jax/models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index f5710a3ab..6ced9bef5 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -47,12 +47,10 @@ class GNN(nn.Module): activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train): - if self.dropout_rate is None: - dropout_rate = 0.1 - else: + def __call__(self, graph, train, dropout_rate=None): + if not dropout_rate: dropout_rate = self.dropout_rate - dropout = Dropout(rate=dropout_rate, deterministic=not train) + dropout = Dropout(deterministic=not train, rate=dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) From 9bba078b5e17a7881ffaa294a551129d4acf5c65 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 May 2025 21:43:18 +0000 Subject: [PATCH 06/39] modify wmt model for dropout passing --- algoperf/workloads/wmt/wmt_jax/models.py | 599 +++++++++++---------- algoperf/workloads/wmt/wmt_jax/workload.py | 4 +- 2 files changed, 310 insertions(+), 293 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 04f46e8ac..a5b484320 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py 
+++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -140,325 +140,342 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. + """Transformer MLP / feed-forward block. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - out_dim: optionally specify out dimension. - """ - config: TransformerConfig - out_dim: Optional[int] = None + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + out_dim: optionally specify out dimension. + """ - @nn.compact - def __call__(self, inputs): - """Applies Transformer MlpBlock module.""" - cfg = self.config - actual_out_dim = ( - inputs.shape[-1] if self.out_dim is None else self.out_dim) - x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = x * y - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - x) - output = Dropout(rate=dropout_rate)( - output, deterministic=cfg.deterministic) - return output + config: TransformerConfig + out_dim: Optional[int] = None + + @nn.compact + def __call__(self, inputs, dropout_rate=None): + """Applies Transformer MlpBlock module.""" + cfg = self.config + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) + x = x * y + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(x) + output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic) + return output class Encoder1DBlock(nn.Module): - """Transformer encoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, inputs, encoder_mask=None): - """Applies Encoder1DBlock module. + """Transformer encoder layer. - Args: - inputs: input data. - encoder_mask: encoder self-attention mask. - - Returns: - output after transformer encoder block. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. """ - cfg = self.config - pre_ln = cfg.pre_ln - # Attention block. 
- assert inputs.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * x, x, mask=encoder_mask) - - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + inputs - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # MLP block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) - - return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) + config: TransformerConfig + + @nn.compact + def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + """Applies Encoder1DBlock module. + + Args: + inputs: input data. + encoder_mask: encoder self-attention mask. + + Returns: + output after transformer encoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Attention block. + assert inputs.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) + + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + inputs + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # MLP block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = MlpBlock(config=cfg)(y) + + return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) class EncoderDecoder1DBlock(nn.Module): - """Transformer encoder-decoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies EncoderDecoder1DBlock module. + """Transformer encoder-decoder layer. - Args: - targets: input data for decoder - encoded: input data from encoder - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output after transformer encoder-decoder block. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. """ - cfg = self.config - pre_ln = cfg.pre_ln - # Decoder block. 
- assert targets.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + config: TransformerConfig - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode)( - cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + targets - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # Encoder-Decoder block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - y = y + x - if not pre_ln: - y = nn.LayerNorm(dtype=cfg.dtype)(y) - - # MLP block. - z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) - - return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) + @nn.compact + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies EncoderDecoder1DBlock module. + + Args: + targets: input data for decoder + encoded: input data from encoder + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output after transformer encoder-decoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Decoder block. + assert targets.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + targets + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # Encoder-Decoder block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = y + x + if not pre_ln: + y = nn.LayerNorm(dtype=cfg.dtype)(y) + + # MLP block. 
+ z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y + z = MlpBlock(config=cfg)(z) + + return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, inputs, inputs_positions=None, encoder_mask=None): - """Applies Transformer model on the inputs. + """Transformer Model Encoder for sequence to sequence translation. - Args: - inputs: input data - inputs_positions: input subsequence positions for packed examples. - encoder_mask: decoder self-attention mask. - - Returns: - output of a transformer encoder. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. """ - cfg = self.config - assert inputs.ndim == 2 # (batch, len) - # Input Embedding - if self.shared_embedding is None: - input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - input_embed = self.shared_embedding - x = inputs.astype('int32') - x = input_embed(x) - x = AddPositionEmbs( - config=cfg, decode=False, name='posembed_input')( - x, inputs_positions=inputs_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - - x = x.astype(cfg.dtype) - - # Input Encoder - for lyr in range(cfg.num_layers): - x = Encoder1DBlock( - config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) - - encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) - if cfg.pre_ln else x) - - return encoded + config: TransformerConfig + shared_embedding: Any = None + + @nn.compact + def __call__( + self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None + ): + """Applies Transformer model on the inputs. + + Args: + inputs: input data + inputs_positions: input subsequence positions for packed examples. + encoder_mask: decoder self-attention mask. + + Returns: + output of a transformer encoder. + """ + cfg = self.config + assert inputs.ndim == 2 # (batch, len) + + # Input Embedding + if self.shared_embedding is None: + input_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + input_embed = self.shared_embedding + x = inputs.astype("int32") + x = input_embed(x) + x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")( + x, inputs_positions=inputs_positions + ) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = x.astype(cfg.dtype) + + # Input Encoder + for lyr in range(cfg.num_layers): + x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask) + + encoded = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) + if cfg.pre_ln + else x + ) + + return encoded class Decoder(nn.Module): - """Transformer Model Decoder for sequence to sequence translation. + """Transformer Model Decoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. 
- shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies Transformer model on the inputs. - - Args: - encoded: encoded input data from encoder. - targets: target inputs. - targets_positions: input subsequence positions for packed examples. - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output of a transformer decoder. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. """ - cfg = self.config - assert encoded.ndim == 3 # (batch, len, depth) - assert targets.ndim == 2 # (batch, len) + config: TransformerConfig + shared_embedding: Any = None - # Target Embedding - if self.shared_embedding is None: - output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - output_embed = self.shared_embedding - - y = targets.astype('int32') - if not cfg.decode: - y = shift_right(y) - y = output_embed(y) - y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name='posembed_output')( - y, inputs_positions=targets_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - - y = y.astype(cfg.dtype) - - # Target-Input Decoder - for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock( - config=cfg, name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) - y = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) - if cfg.pre_ln else y) - - # Use the transpose of embedding matrix for logit transform. - logits = output_embed.attend(y.astype(jnp.float32)) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - return logits + @nn.compact + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies Transformer model on the inputs. + + Args: + encoded: encoded input data from encoder. + targets: target inputs. + targets_positions: input subsequence positions for packed examples. + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output of a transformer decoder. 
+ """ + cfg = self.config + + assert encoded.ndim == 3 # (batch, len, depth) + assert targets.ndim == 2 # (batch, len) + + # Target Embedding + if self.shared_embedding is None: + output_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + output_embed = self.shared_embedding + + y = targets.astype("int32") + if not cfg.decode: + y = shift_right(y) + y = output_embed(y) + y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")( + y, inputs_positions=targets_positions + ) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + + y = y.astype(cfg.dtype) + + # Target-Input Decoder + for lyr in range(cfg.num_layers): + y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + ) + y = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) + if cfg.pre_ln + else y + ) + + # Use the transpose of embedding matrix for logit transform. + logits = output_embed.attend(y.astype(jnp.float32)) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + return logits class Transformer(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..240ad2c11 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -209,8 +209,8 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = 0.0, + aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" init_fake_batch_size = 2 From 31f601977335e170c54e50d37ece29ecaea9a314 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 May 2025 22:19:20 +0000 Subject: [PATCH 07/39] modify wmt model for dropout passing --- .../librispeech_jax/models.py | 2 +- algoperf/workloads/wmt/wmt_jax/models.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index f937a1692..003bf4ea7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -73,7 +73,7 @@ class Subsample(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train): + def __call__(self, inputs, output_paddings, train, dropout_rate=None): config = self.config outputs = jnp.expand_dims(inputs, axis=-1) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index a5b484320..b84fb6d96 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -236,7 +236,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): # MLP block. 
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) @@ -324,7 +324,7 @@ def __call__( # MLP block. z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) @@ -382,7 +382,7 @@ def __call__( # Input Encoder for lyr in range(cfg.num_layers): - x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask) + x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) encoded = ( nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) @@ -464,6 +464,7 @@ def __call__( encoded, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, ) y = ( nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) @@ -503,7 +504,7 @@ def setup(self): self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): + def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None): """Applies Transformer encoder-branch on the inputs. Args: @@ -528,7 +529,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): jnp.equal, dtype=cfg.dtype)) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) + inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate) def decode( self, @@ -595,7 +596,8 @@ def __call__(self, inputs_positions=None, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=None): """Applies Transformer model on the inputs. Args: @@ -612,7 +614,8 @@ def __call__(self, encoded = self.encode( inputs, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) return self.decode( encoded, @@ -620,4 +623,5 @@ def __call__(self, targets, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) + targets_segmentation=targets_segmentation, + dropout_rate=dropout_rate) From e36d29432a960283f9a44baf745644c5c3ddbdaa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:32:41 +0000 Subject: [PATCH 08/39] reformatting and dropout fixes to fastmri and vit --- .../criteo1tb/criteo1tb_jax/models.py | 7 +- .../workloads/fastmri/fastmri_jax/models.py | 18 +- .../workloads/fastmri/fastmri_jax/workload.py | 21 +- .../imagenet_vit/imagenet_jax/models.py | 75 ++- .../imagenet_vit/imagenet_jax/workload.py | 22 +- .../librispeech_jax/models.py | 4 +- algoperf/workloads/wmt/wmt_jax/models.py | 525 +++++++++--------- 7 files changed, 362 insertions(+), 310 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index c56748bb1..0b2126915 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,7 @@ from algoperf.jax_utils import Dropout + class DLRMResNet(nn.Module): """Define a DLRMResNet model. 
@@ -30,7 +31,7 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): if not dropout_rate: - dropout_rate=self.dropout_rate + dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -157,7 +158,7 @@ class DlrmSmall(nn.Module): def __call__(self, x, train, dropout_rate=None): if not dropout_rate: dropout_rate = self.dropout_rate - + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -219,6 +220,6 @@ def scaled_init(key, shape, dtype=jnp.float_): if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_input = Dropout(deterministic=not train)( - top_mlp_input, rate=dropout_rate) + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 3d5460c18..7ecca2add 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -21,6 +21,7 @@ from algoperf.jax_utils import Dropout + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points @@ -57,13 +58,13 @@ class UNet(nn.Module): num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: float = 0.0 use_tanh: bool = False use_layer_norm: bool = False @nn.compact def __call__(self, x, train=True, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate # pylint: disable=invalid-name @@ -138,7 +139,7 @@ class ConvBlock(nn.Module): dropout_rate: Dropout probability. """ out_channels: int - dropout_rate: float + dropout_rate: float = 0.0 use_tanh: bool use_layer_norm: bool @@ -152,8 +153,8 @@ def __call__(self, x, train=True, dropout_rate=None): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. 
""" - if not dropout_rate: - dropout_rate=self.dropout_rate + if dropout_rate is None: + dropout_rate = self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -174,8 +175,9 @@ def __call__(self, x, train=True, dropout_rate=None): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = Dropout(broadcast_dims=(1, 2), deterministic=not train)( - x, rate=dropout_rate ) + x = Dropout( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -188,7 +190,7 @@ def __call__(self, x, train=True, dropout_rate=None): x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) x = Dropout( - broadcast_dims=(1, 2), deterministic=not train)( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x, rate=dropout_rate) return x diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..17ce6b442 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -26,12 +26,21 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + if dropout_rate is None: + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) + else: + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + dropout_rate=dropout_rate) + params_rng, dropout_rng = jax.random.split(rng) variables = jax.jit( self._model.init)({'params': params_rng, 'dropout': dropout_rng}, diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 10a90f37d..227f7c297 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -39,9 +39,12 @@ class MlpBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=None) -> spec.Tensor: """Applies Transformer MlpBlock module.""" - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate inits = { @@ -57,7 +60,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = Dropout()(x, train, rate=dropout_rate) + x = Dropout(dropout_rate)(x, train, rate=dropout_rate) x = nn.Dense(d, **inits)(x) return x @@ -71,9 +74,12 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: - if not dropout_rate: - dropout_rate=self.dropout_rate + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=dropout_rate) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) @@ -83,15 
+89,14 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout()(y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) - y = Dropout()(y, train, rate=dropout_rate) + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y else: y = x @@ -101,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout()(y, train, rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -109,8 +114,10 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) - y = Dropout()(y, train)(rate=dropout_rate) + name='MlpBlock_3', + dropout_rate=dropout_rate)( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -127,9 +134,12 @@ class Encoder(nn.Module): use_post_layer_norm: bool = False @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: - if not dropout_rate: - dropout_rate=self.dropout_rate + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate # Input Encoder for lyr in range(self.depth): @@ -139,7 +149,8 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - )(dropout_rate=dropout_rate) + dropout_rate=dropout_rate)( + dropout_rate=dropout_rate) x = block(x, train) if not self.use_post_layer_norm: return nn.LayerNorm(name='encoder_layernorm')(x) @@ -151,9 +162,12 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 + dropout_rate: 0.0 @nn.compact - def __call__(self, x): + def __call__(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate n, _, d = x.shape probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), @@ -166,7 +180,7 @@ def __call__(self, x): kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) return x[:, 0] @@ -180,7 +194,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. 
+ dropout_rate: Optional[float] = 0.0 reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -194,8 +208,12 @@ def get_posemb(self, return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor: - if not dropout_rate: + def __call__(self, + x: spec.Tensor, + *, + train: bool = False, + dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: dropout_rate = self.dropout_rate # Patch extraction x = nn.Conv( @@ -212,7 +230,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> # Add posemb before adding extra token. x = x + self.get_posemb((h, w), c, x.dtype) - x = Dropout()(x, not train, rate=dropout_rate) + x = Dropout(dropout_rate)(x, not train, rate=dropout_rate) x = Encoder( depth=self.depth, @@ -220,11 +238,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer')( + name='Transformer', + dropout_rate=dropout_rate)( x, train=not train, dropout_rate=dropout_rate) if self.use_map: - x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + x = MAPHead( + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=dropout_rate)( + x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..5d07b5ff8 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -36,13 +36,21 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - self._model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + if dropout_rate is None: + self._model = models.ViT( + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) + else: + self._model = models.ViT( + dropout_rate=dropout_rate, + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 003bf4ea7..4cdb02ee1 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -111,9 +111,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = Dropout( - rate=input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index b84fb6d96..54a917a09 100644 --- 
a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -140,64 +140,68 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. + """Transformer MLP / feed-forward block. Attributes: config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ - config: TransformerConfig - out_dim: Optional[int] = None - - @nn.compact - def __call__(self, inputs, dropout_rate=None): - """Applies Transformer MlpBlock module.""" - cfg = self.config - actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim - x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(inputs) - x = x * y - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(x) - output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic) - return output + config: TransformerConfig + out_dim: Optional[int] = None + + @nn.compact + def __call__(self, inputs, dropout_rate=None): + """Applies Transformer MlpBlock module.""" + cfg = self.config + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + inputs) + x = x * y + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + x) + output = Dropout()( + output, rate=dropout_rate, deterministic=cfg.deterministic) + return output class Encoder1DBlock(nn.Module): - """Transformer encoder layer. + """Transformer encoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ - config: TransformerConfig + config: TransformerConfig - @nn.compact - def __call__(self, inputs, encoder_mask=None, dropout_rate=None): - """Applies Encoder1DBlock module. + @nn.compact + def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + """Applies Encoder1DBlock module. Args: inputs: input data. @@ -206,60 +210,60 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): Returns: output after transformer encoder block. """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Attention block. 
- assert inputs.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - )(cfg.attention_temp * x, x, mask=encoder_mask) - - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - x = x + inputs - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # MLP block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) - - return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) + cfg = self.config + pre_ln = cfg.pre_ln + + # Attention block. + assert inputs.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) + + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + inputs + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # MLP block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) + + return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) class EncoderDecoder1DBlock(nn.Module): - """Transformer encoder-decoder layer. + """Transformer encoder-decoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ - config: TransformerConfig + config: TransformerConfig - @nn.compact - def __call__( - self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=None, - ): - """Applies EncoderDecoder1DBlock module. + @nn.compact + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies EncoderDecoder1DBlock module. Args: targets: input data for decoder @@ -270,81 +274,83 @@ def __call__( Returns: output after transformer encoder-decoder block. """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Decoder block. - assert targets.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode, - )(cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - x = x + targets - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # Encoder-Decoder block. 
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) - y = y + x - if not pre_ln: - y = nn.LayerNorm(dtype=cfg.dtype)(y) - - # MLP block. - z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) - - return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) + cfg = self.config + pre_ln = cfg.pre_ln + + # Decoder block. + assert targets.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + targets + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # Encoder-Decoder block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = y + x + if not pre_ln: + y = nn.LayerNorm(dtype=cfg.dtype)(y) + + # MLP block. + z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) + + return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. + """Transformer Model Encoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ - config: TransformerConfig - shared_embedding: Any = None + config: TransformerConfig + shared_embedding: Any = None - @nn.compact - def __call__( - self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None - ): - """Applies Transformer model on the inputs. + @nn.compact + def __call__(self, + inputs, + inputs_positions=None, + encoder_mask=None, + dropout_rate=None): + """Applies Transformer model on the inputs. Args: inputs: input data @@ -354,67 +360,66 @@ def __call__( Returns: output of a transformer encoder. 
""" - cfg = self.config - assert inputs.ndim == 2 # (batch, len) - - # Input Embedding - if self.shared_embedding is None: - input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), - ) - else: - input_embed = self.shared_embedding - x = inputs.astype("int32") - x = input_embed(x) - x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")( - x, inputs_positions=inputs_positions - ) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - - x = x.astype(cfg.dtype) - - # Input Encoder - for lyr in range(cfg.num_layers): - x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) - - encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) - if cfg.pre_ln - else x - ) - - return encoded + cfg = self.config + assert inputs.ndim == 2 # (batch, len) + + # Input Embedding + if self.shared_embedding is None: + input_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + input_embed = self.shared_embedding + x = inputs.astype("int32") + x = input_embed(x) + x = AddPositionEmbs( + config=cfg, decode=False, name="posembed_input")( + x, inputs_positions=inputs_positions) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = x.astype(cfg.dtype) + + # Input Encoder + for lyr in range(cfg.num_layers): + x = Encoder1DBlock( + config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) + + encoded = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) + if cfg.pre_ln else x) + + return encoded class Decoder(nn.Module): - """Transformer Model Decoder for sequence to sequence translation. + """Transformer Model Decoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ - config: TransformerConfig - shared_embedding: Any = None + config: TransformerConfig + shared_embedding: Any = None - @nn.compact - def __call__( - self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=None, - ): - """Applies Transformer model on the inputs. + @nn.compact + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies Transformer model on the inputs. Args: encoded: encoded input data from encoder. @@ -426,57 +431,56 @@ def __call__( Returns: output of a transformer decoder. 
""" - cfg = self.config - - assert encoded.ndim == 3 # (batch, len, depth) - assert targets.ndim == 2 # (batch, len) - - # Target Embedding - if self.shared_embedding is None: - output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), - ) - else: - output_embed = self.shared_embedding - - y = targets.astype("int32") - if not cfg.decode: - y = shift_right(y) - y = output_embed(y) - y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")( - y, inputs_positions=targets_positions - ) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) - - y = y.astype(cfg.dtype) - - # Target-Input Decoder - for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=dropout_rate, - ) - y = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) - if cfg.pre_ln - else y - ) - - # Use the transpose of embedding matrix for logit transform. - logits = output_embed.attend(y.astype(jnp.float32)) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - return logits + cfg = self.config + + assert encoded.ndim == 3 # (batch, len, depth) + assert targets.ndim == 2 # (batch, len) + + # Target Embedding + if self.shared_embedding is None: + output_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + output_embed = self.shared_embedding + + y = targets.astype("int32") + if not cfg.decode: + y = shift_right(y) + y = output_embed(y) + y = AddPositionEmbs( + config=cfg, decode=cfg.decode, name="posembed_output")( + y, inputs_positions=targets_positions) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + + y = y.astype(cfg.dtype) + + # Target-Input Decoder + for lyr in range(cfg.num_layers): + y = EncoderDecoder1DBlock( + config=cfg, name=f"encoderdecoderblock_{lyr}")( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) + y = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) + if cfg.pre_ln else y) + + # Use the transpose of embedding matrix for logit transform. + logits = output_embed.attend(y.astype(jnp.float32)) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + return logits class Transformer(nn.Module): @@ -504,7 +508,11 @@ def setup(self): self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None): + def encode(self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + dropout_rate=None): """Applies Transformer encoder-branch on the inputs. 
Args: @@ -529,7 +537,10 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropou jnp.equal, dtype=cfg.dtype)) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate) + inputs, + inputs_positions=inputs_positions, + encoder_mask=encoder_mask, + dropout_rate=dropout_rate) def decode( self, From 363da8ac032c82b2ed8ac1c7ea64a25be243cbc6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:39:39 +0000 Subject: [PATCH 09/39] dropout fix for criteo1tb jax --- algoperf/workloads/criteo1tb/criteo1tb_jax/models.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 0b2126915..b7af15208 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -30,7 +30,7 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) @@ -93,7 +93,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) x = nn.relu(x) if dropout_rate and layer_idx == num_layers_top - 2: - x = Dropout(deterministic=not train)(x, rate=dropout_rate) + x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. @@ -156,7 +156,7 @@ class DlrmSmall(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) @@ -219,7 +219,8 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.LayerNorm()(top_mlp_input) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = Dropout(deterministic=not train)( - top_mlp_input, rate=dropout_rate) + top_mlp_input = Dropout( + dropout_rate, deterministic=not train)( + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits From 341bf8996de4e4897eb9559124b1123c115dd62a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:42:00 +0000 Subject: [PATCH 10/39] dropout fix for criteo1tb jax --- .../criteo1tb/criteo1tb_jax/workload.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..bad2f4390 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -82,15 +82,26 @@ def init_model_fn( model_class = models.DLRMResNet else: model_class = models.DlrmSmall - self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) + + if dropout_rate is None: + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + 
embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) + else: + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 From f0c385bcec139050fe11bee01f0bc0ba0b9194d9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:49:42 +0000 Subject: [PATCH 11/39] remove aux dropout option from conformer and from init_model_fn signature for fastmri, vit and criteo --- .../criteo1tb/criteo1tb_jax/workload.py | 2 -- .../workloads/fastmri/fastmri_jax/workload.py | 3 +- .../imagenet_vit/imagenet_jax/workload.py | 4 +-- .../librispeech_jax/workload.py | 28 ++++++++++++------- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index bad2f4390..e3864643b 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -73,11 +73,9 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate if self.use_resnet: model_class = models.DLRMResNet else: diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 17ce6b442..13ab5c1b8 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -22,9 +22,8 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) if dropout_rate is None: self._model = UNet( diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 5d07b5ff8..5107ed993 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -33,9 +33,7 @@ def initialized(self, key: spec.RandomState, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: if dropout_rate is None: self._model = models.ViT( num_classes=self._num_classes, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..4da70fc61 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -61,24 +61,32 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Conformer model init function. 
- Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as + Here we use dropout_rate as *_residual_dropout_rate, and for input_dropout_rate. """ if self.use_gelu: activation_function_name = 'gelu' else: activation_function_name = 'swish' - model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + if dropout_rate is None: + model_config = models.ConformerConfig( + attention_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) + else: + model_config = models.ConformerConfig( + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) + self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] From 7af5c941d81a7e66e3128afaf5b49c6f2730c302 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 31 May 2025 00:26:32 +0000 Subject: [PATCH 12/39] add dropout piping for conformer and deepspeech --- .../workloads/fastmri/fastmri_jax/workload.py | 2 +- .../librispeech_jax/models.py | 7 +-- .../librispeech_jax/workload.py | 2 +- .../librispeech_jax/models.py | 34 +++++++++------ .../librispeech_jax/workload.py | 43 +++++++++++-------- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 26 ++++++----- algoperf/workloads/wmt/wmt_jax/workload.py | 24 ++++++----- 8 files changed, 83 insertions(+), 59 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 13ab5c1b8..439b8d055 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -22,7 +22,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - ) -> spec.ModelInitState: + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" fake_batch = jnp.zeros((13, 320, 320)) if dropout_rate is None: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 51e93acc9..29c349e11 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -38,13 +38,10 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. + attention_residual_dropout_rate: Optional[float] = 0.0 conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. 
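As a side note on the init_model_fn changes, the recurring pattern is to omit the dropout field entirely when no rate is requested, so the model's own default applies. A minimal sketch of that pattern with a hypothetical config, not part of the patch:

import dataclasses
from typing import Optional


@dataclasses.dataclass
class ToyModelConfig:
  dropout_rate: float = 0.1  # hypothetical default baked into the model


def make_config(dropout_rate: Optional[float] = None) -> ToyModelConfig:
  if dropout_rate is None:
    # Leave the field unset so the dataclass default applies.
    return ToyModelConfig()
  return ToyModelConfig(dropout_rate=dropout_rate)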
- feed_forward_residual_dropout_rate: Optional[float] = 0.1 + feed_forward_residual_dropout_rate: Optional[float] = 0.0 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 4da70fc61..a54f52c04 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -61,7 +61,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - ) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Conformer model init function. Here we use dropout_rate as *_residual_dropout_rate, and for diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 4cdb02ee1..3ad31b532 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -75,6 +75,9 @@ class Subsample(nn.Module): @nn.compact def __call__(self, inputs, output_paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate + outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( @@ -111,7 +114,9 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs) + outputs = Dropout( + rate=input_dropout_rate, deterministic=not train, rate=dropout_rate)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -187,7 +192,13 @@ class FeedForwardModule(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): + def __call__(self, + inputs, + input_paddings=None, + train=False, + dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.config.feed_forward_dropout_rate padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config @@ -211,12 +222,8 @@ def __call__(self, inputs, input_paddings=None, train=False): inputs = nn.relu(inputs) inputs *= padding_mask - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = Dropout(rate=feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) return inputs @@ -472,8 +479,10 @@ def setup(self): ) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate outputs = inputs output_paddings = input_paddings @@ -493,7 +502,7 @@ def __call__(self, inputs, input_paddings, train): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train) + config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the lstm layers. 
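For reference, a sketch of how a parent module can resolve the single dropout_rate once and thread the same value into its children, mirroring the Subsample and FeedForwardModule changes above; Parent and Child are hypothetical names used only for illustration.

import flax.linen as nn

from algoperf.jax_utils import Dropout


class Child(nn.Module):

  @nn.compact
  def __call__(self, x, train, dropout_rate):
    x = nn.Dense(8)(x)
    return Dropout(
        dropout_rate, deterministic=not train)(x, rate=dropout_rate)


class Parent(nn.Module):
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, train=True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.dropout_rate  # fall back to the configured default
    x = Child()(x, train, dropout_rate)
    x = Child()(x, train, dropout_rate)
    return x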
for _ in range(config.num_lstm_layers): @@ -507,9 +516,8 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + FeedForwardModule(config=self.config)( outputs, output_paddings, train) else: - outputs = FeedForwardModule(config=self.config)(outputs, - output_paddings, - train) + outputs = FeedForwardModule(config=self.config)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the decoder which in this case is a trivial projection layer. if config.enable_decoder_layer_norm: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..3c9a96f99 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -18,24 +18,31 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. """ - model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) + if dropout_rate is None: + model_config = models.DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) + else: + model_config = models.DeepspeechConfig( + feed_forward_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + input_dropout_rate=dropout_rate, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] @@ -64,6 +71,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None + dropout_rate: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -75,7 +83,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 6ced9bef5..f6cb1c490 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ 
b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -48,9 +48,9 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is not None: dropout_rate = self.dropout_rate - dropout = Dropout(deterministic=not train, rate=dropout_rate) + dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..f7de3f982 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -20,18 +20,24 @@ class OgbgWorkload(BaseOgbgWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate rng, params_rng, dropout_rng = jax.random.split(rng, 3) - self._model = models.GNN( - self._num_outputs, - dropout_rate=dropout_rate, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + if dropout_rate is None: + self._model = models.GNN( + self._num_outputs, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) + else: + self._model = models.GNN( + self._num_outputs, + dropout_rate=dropout_rate, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 240ad2c11..b1f1e78a8 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -209,10 +209,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = 0.0, - aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - + dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -224,13 +221,20 @@ def init_model_fn( else: raise ValueError(f'Unknown activation function {self.activation}.') + if dropout_rate is None: + model_config = models.TransformerConfig( + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) + else: model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate, + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) From cbd065b490e1e57212d6b0112715ee73fcdf1a67 Mon Sep 17 00:00:00 2001 
From: Priya Kasimbeg Date: Sat, 31 May 2025 01:30:22 +0000 Subject: [PATCH 13/39] pipe dropout through model_fn --- .../criteo1tb/criteo1tb_jax/workload.py | 4 +- .../workloads/fastmri/fastmri_jax/workload.py | 14 ++-- .../imagenet_vit/imagenet_jax/workload.py | 6 +- .../librispeech_jax/models.py | 72 +++++++++---------- .../librispeech_jax/workload.py | 6 +- .../librispeech_jax/workload.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 6 +- 7 files changed, 60 insertions(+), 50 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index e3864643b..101e02c15 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -126,7 +126,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] @@ -134,6 +135,7 @@ def model_fn( apply_kwargs = {'train': train} if train: apply_kwargs['rngs'] = {'dropout': rng} + apply_kwargs['dropout_rate'] = dropout_rate logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 439b8d055..3d891cf8f 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -60,14 +60,18 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train) + + if train: + logits = self._model.apply({'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 5107ed993..89355ac6e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -66,14 +66,16 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _eval_model_on_split(self, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py 
b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 29c349e11..2ca0fffdc 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -37,7 +37,7 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 + dropout_rate: float = 0.1 attention_residual_dropout_rate: Optional[float] = 0.0 conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 @@ -51,8 +51,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -98,10 +96,12 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 - input_dropout_rate: float = 0.0 + dropout_rate: float = 0.0 @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) @@ -128,8 +128,8 @@ def __call__(self, inputs, input_paddings, train): seq_length=outputs.shape[1]) outputs = Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) + rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -196,9 +196,10 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=dropout_rate): config = self.config - + if dropout_rate is None: + dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( @@ -387,8 +388,11 @@ class MultiHeadedSelfAttention(nn.Module): config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train): + def __call__(self, inputs, paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate + mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) @@ -410,13 +414,9 @@ def __call__(self, inputs, paddings, train): deterministic=not train)( inputs_q=inputs, mask=attention_mask) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate result = Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)( - result) + rate=dropout_rate, deterministic=not train)( + result, rate=dropout_rate) return result @@ -526,8 +526,11 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average_bn): + use_running_average_bn, + dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( @@ -572,13 +575,9 @@ def __call__(self, config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( inputs) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = 
config.conv_residual_dropout_rate inputs = Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)( - inputs) + rate=dropout_rate, deterministic=not train)( + inputs, rate=dropout_rate) return inputs @@ -603,26 +602,28 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average): + use_running_average, + dropout_rate=None): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) + inputs, input_paddings, train, dropout_rate=dropout_rate) inputs = inputs + \ ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, - use_running_average + use_running_average, + dropout_rate ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) if config.use_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -656,7 +657,8 @@ def __call__(self, input_paddings, train, update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = None): config = self.config outputs = inputs @@ -681,15 +683,10 @@ def __call__(self, if train and config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - # Subsample input by a factor of 4 by performing strided convolutions. - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs, output_paddings = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate)( - outputs, output_paddings, train) + dropout_rate=dropout_rate)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): @@ -697,7 +694,8 @@ def __call__(self, output_paddings, train, update_batch_norm, - use_running_average_bn) + use_running_average_bn, + dropout_rate) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index a54f52c04..2e082cf07 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -116,7 +116,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -129,7 +130,8 @@ def model_fn( train=True, rngs={'dropout' : rng}, mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + use_running_average_bn=use_running_average_bn, + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3c9a96f99..825b470db 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -70,7 +70,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, dropout_rate: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index f7de3f982..3becd5599 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -63,7 +63,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: Optional[float]) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. 
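As a rough usage sketch of the new keyword (hypothetical helper and variable names, not part of the patch), a submission could thread a tuned rate through model_fn like this:

from algoperf import spec


def train_forward(workload, params, model_state, batch, rng, dropout_rate):
  # All arguments besides dropout_rate follow the existing model_fn
  # signature; dropout_rate is the keyword added in this patch series.
  output_batch, new_model_state = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=True,
      dropout_rate=dropout_rate)
  return output_batch, new_model_state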
if model_state is not None: @@ -74,7 +75,8 @@ def model_fn( logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _binary_cross_entropy_with_mask( From 31babfd9da6e2ad55156f55aeeb1c9cf10d88edc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 19:05:10 +0000 Subject: [PATCH 14/39] fix syntax --- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index b1f1e78a8..367c062cb 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -228,7 +228,7 @@ def init_model_fn( activation=activation, glu=self.glu) else: - model_config = models.TransformerConfig( + model_config = models.TransformerConfig( dropout_rate=dropout_rate, attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, From 95d67db14e4c0d68e4868b55240c2211b8b039af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:45:08 +0000 Subject: [PATCH 15/39] dropout changes wmt jax --- algoperf/workloads/wmt/wmt_jax/models.py | 67 +++++++++------------- algoperf/workloads/wmt/wmt_jax/workload.py | 6 +- 2 files changed, 30 insertions(+), 43 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 54a917a09..3947a1b81 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -28,10 +28,7 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - #If None, defaults to 0.1. dropout_rate: Optional[float] = 0.1 - #If None, defaults to 0.1. - attention_dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -154,6 +151,9 @@ class MlpBlock(nn.Module): def __call__(self, inputs, dropout_rate=None): """Applies Transformer MlpBlock module.""" cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( cfg.mlp_dim, @@ -172,12 +172,7 @@ def __call__(self, inputs, dropout_rate=None): )( inputs) x = x * y - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, @@ -185,7 +180,7 @@ def __call__(self, inputs, dropout_rate=None): bias_init=cfg.bias_init, )( x) - output = Dropout()( + output = Dropout(rate=dropout_rate)( output, rate=dropout_rate, deterministic=cfg.deterministic) return output @@ -211,16 +206,14 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): output after transformer encoder block. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + pre_ln = cfg.pre_ln # Attention block. 
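For the WMT blocks, determinism comes from the config rather than a train flag, since the workload builds separate train and eval configs. A minimal sketch of that variant, with hypothetical names and not part of the patch:

import flax.linen as nn
from flax import struct

from algoperf.jax_utils import Dropout


@struct.dataclass
class ToyTransformerConfig:
  dropout_rate: float = 0.1
  deterministic: bool = False


class ToyFeedForward(nn.Module):
  config: ToyTransformerConfig

  @nn.compact
  def __call__(self, x, dropout_rate=None):
    cfg = self.config
    if dropout_rate is None:
      dropout_rate = cfg.dropout_rate
    x = nn.Dense(32)(x)
    # Deterministic behaviour is driven by the config, so the same call site
    # serves the train config (deterministic=False) and the eval config
    # (deterministic=True).
    x = Dropout(rate=dropout_rate)(
        x, rate=dropout_rate, deterministic=cfg.deterministic)
    return x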
assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -233,7 +226,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): deterministic=cfg.deterministic, )(cfg.attention_temp * x, x, mask=encoder_mask) - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -275,17 +268,15 @@ def __call__( output after transformer encoder-decoder block. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + pre_ln = cfg.pre_ln # Decoder block. assert targets.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -298,11 +289,8 @@ def __call__( deterministic=cfg.deterministic, decode=cfg.decode, )(cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -321,7 +309,7 @@ def __call__( deterministic=cfg.deterministic, )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -361,6 +349,9 @@ def __call__(self, output of a transformer encoder. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + assert inputs.ndim == 2 # (batch, len) # Input Embedding @@ -377,12 +368,7 @@ def __call__(self, x = AddPositionEmbs( config=cfg, decode=False, name="posembed_input")( x, inputs_positions=inputs_positions) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x.astype(cfg.dtype) @@ -432,6 +418,8 @@ def __call__( output of a transformer decoder. 
""" cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) @@ -453,12 +441,7 @@ def __call__( y = AddPositionEmbs( config=cfg, decode=cfg.decode, name="posembed_output")( y, inputs_positions=targets_positions) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) y = y.astype(cfg.dtype) @@ -549,7 +532,8 @@ def decode( targets, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=None): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -598,7 +582,8 @@ def decode( targets, targets_positions=targets_positions, decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=droput_rate) return logits.astype(self.config.dtype) def __call__(self, diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 367c062cb..193732640 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -259,7 +259,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm @@ -286,7 +287,8 @@ def model_fn( targets_positions=targets_positions, inputs_segmentation=inputs_segmentations, targets_segmentation=targets_segmentations, - rngs={'dropout': rng}) + rngs={'dropout': rng}, + dropout_rate=None) return logits_batch, None def _normalize_eval_metrics( From 2c96b884eba3cd8500c8d3fd1de6feb28194fbe3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:49:15 +0000 Subject: [PATCH 16/39] modify dockerfile --- docker/Dockerfile | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..72e3a810f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ + && pip install 
-e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From 54786a6594bf70388e8791aab2b35c78b7cbf028 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:49:52 +0000 Subject: [PATCH 17/39] modify docker build script --- docker/build_docker_images.sh | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,27 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b -f # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; + p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf" ]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pythorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" From 246d68ee96d05d9b4b463dd08ae36e1715f6b3bb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 23:38:41 +0000 Subject: [PATCH 18/39] fsmall fixes --- algoperf/workloads/fastmri/fastmri_jax/models.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 2 +- .../workloads/librispeech_conformer/librispeech_jax/models.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 7ecca2add..b04510297 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -139,9 +139,9 @@ class ConvBlock(nn.Module): dropout_rate: Dropout probability. 
""" out_channels: int - dropout_rate: float = 0.0 use_tanh: bool use_layer_norm: bool + dropout_rate: float = 0.0 @nn.compact def __call__(self, x, train=True, dropout_rate=None): diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 227f7c297..8ffc0b610 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -162,7 +162,7 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 - dropout_rate: 0.0 + dropout_rate: float = 0.0 @nn.compact def __call__(self, x, dropout_rate=None): diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2ca0fffdc..2d0da15e5 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -196,7 +196,7 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=dropout_rate): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None): config = self.config if dropout_rate is None: dropout_rate = config.dropout_rate From 0c8dd14d617ff1a642915f34ccd1e504d5a8c0a1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Jun 2025 00:39:59 +0000 Subject: [PATCH 19/39] change docker base image to 12.1.1 --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 72e3a810f..f1fc99550 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" From a78fa6642aed752010e76e953d11ee4c54bafddd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Jun 2025 00:49:04 +0000 Subject: [PATCH 20/39] update base image --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f1fc99550..9926b0542 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" From f0019ac4ad04fb8bfe2c6430475a7043dec77ff3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:23:18 +0000 Subject: [PATCH 21/39] small fix --- .../workloads/librispeech_deepspeech/librispeech_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 3ad31b532..455366e5e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -115,7 +115,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): else: input_dropout_rate = config.input_dropout_rate outputs = Dropout( - rate=input_dropout_rate, deterministic=not train, rate=dropout_rate)( + rate=input_dropout_rate, deterministic=not 
train)( outputs, rate=dropout_rate) return outputs, output_paddings From 3cb012e919374e204f123145ebbe596bf72b4eac Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:34:36 +0000 Subject: [PATCH 22/39] remove aux_dropout from submission_runner.py --- submission_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..bb4a8c6cc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -229,13 +229,10 @@ def train_once( logging.info('Initializing model.') with profiler.profile('Initializing model'): dropout_rate = None - aux_dropout_rate = None if hasattr(hyperparameters, 'dropout_rate'): dropout_rate = hyperparameters.dropout_rate - if hasattr(hyperparameters, 'aux_dropout_rate'): - aux_dropout_rate = hyperparameters.aux_dropout_rate model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate, aux_dropout_rate) + model_init_rng, dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', From 5e192dd6397194d1435631752a1c29289b9a9888 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:02:11 +0000 Subject: [PATCH 23/39] remove dropout_rate from init_model_fn for all jax workloads --- .../workloads/cifar/cifar_jax/workload.py | 6 +-- .../criteo1tb/criteo1tb_jax/workload.py | 34 ++++++----------- .../workloads/fastmri/fastmri_jax/workload.py | 27 +++++-------- .../imagenet_resnet/imagenet_jax/workload.py | 7 +--- .../imagenet_vit/imagenet_jax/workload.py | 28 +++++--------- .../librispeech_jax/workload.py | 28 +++++--------- .../librispeech_jax/workload.py | 38 ++++++------------- .../workloads/mnist/mnist_jax/workload.py | 7 +--- algoperf/workloads/ogbg/ogbg_jax/workload.py | 29 +++++--------- algoperf/workloads/wmt/wmt_jax/workload.py | 26 ++++--------- 10 files changed, 71 insertions(+), 159 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..3f2397f8c 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -81,12 +81,8 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 101e02c15..dcb7b9a57 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -72,7 +72,6 @@ def loss_fn( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" @@ -81,27 +80,16 @@ def init_model_fn( else: model_class = models.DlrmSmall - if dropout_rate is None: - self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) - else: - 
self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) - - params_rng, dropout_rng = jax.random.split(rng) + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) + + params_rng, _= jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 @@ -109,7 +97,7 @@ def init_model_fn( input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( - {'params': params_rng, 'dropout': dropout_rng}, + {'params': params_rng,}, jnp.ones(input_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 3d891cf8f..bf0acfc8d 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -21,28 +21,19 @@ class FastMRIWorkload(BaseFastMRIWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" fake_batch = jnp.zeros((13, 320, 320)) - if dropout_rate is None: - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, ) - else: - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) - - params_rng, dropout_rng = jax.random.split(rng) - variables = jax.jit( - self._model.init)({'params': params_rng, 'dropout': dropout_rng}, + + params_rng, _ = jax.random.split(rng) + init_fn = functools.partial(self._model.init, train=False) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..2a255fee4 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -84,12 +84,7 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + rng: spec.RandomState,) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 89355ac6e..b8a870de5 100644 --- 
a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -23,32 +23,22 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) - params_rng, dropout_rng = jax.random.split(key) + params_rng, _ = jax.random.split(key) variables = jax.jit( - model.init)({'params': params_rng, 'dropout': dropout_rng}, + model.init)({'params': params_rng}, jnp.ones(input_shape)) model_state, params = pop(variables, "params") return params, model_state def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - if dropout_rate is None: - self._model = models.ViT( - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) - else: - self._model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + rng: spec.RandomState) -> spec.ModelInitState: + self._model = models.ViT( + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 2e082cf07..042dba7f4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -60,7 +60,6 @@ def attention_temperature(self) -> float: def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: """Conformer model init function. 
@@ -71,21 +70,14 @@ def init_model_fn( activation_function_name = 'gelu' else: activation_function_name = 'swish' - if dropout_rate is None: - model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) - else: - model_config = models.ConformerConfig( - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + model_config = models.ConformerConfig( + attention_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] @@ -93,8 +85,8 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) model_state, params = pop(variables, "params") diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 825b470db..2213f189e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -17,40 +17,26 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function. 
""" - if dropout_rate is None: - model_config = models.DeepspeechConfig( - use_specaug=self.use_specaug, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) - else: - model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) + model_config = models.DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng,}, *fake_input_batch) model_state = variables[ diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..5f3fdcf78 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -34,12 +34,7 @@ class MnistWorkload(BaseMnistWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() initial_params = self._model.init({'params': rng}, init_val, diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index 3becd5599..aaa5b4064 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -19,25 +19,14 @@ class OgbgWorkload(BaseOgbgWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - rng, params_rng, dropout_rng = jax.random.split(rng, 3) - if dropout_rate is None: - self._model = models.GNN( - self._num_outputs, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) - else: - self._model = models.GNN( - self._num_outputs, - dropout_rate=dropout_rate, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + rng: spec.RandomState) -> spec.ModelInitState: + rng, params_rng = jax.random.split(rng, 2) + self._model = models.GNN( + self._num_outputs, 
+ activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), @@ -47,7 +36,7 @@ def init_model_fn( globals=jnp.zeros((1, self._num_outputs)), senders=jnp.asarray([0]), receivers=jnp.asarray([0])) - params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) + params = init_fn({'params': params_rng}, fake_batch) params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 193732640..9e109dc86 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -208,8 +208,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -221,26 +220,17 @@ def init_model_fn( else: raise ValueError(f'Unknown activation function {self.activation}.') - if dropout_rate is None: - model_config = models.TransformerConfig( - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) - else: - model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + model_config = models.TransformerConfig( + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) - params_rng, dropout_rng = jax.random.split(rng) + params_rng, _ = jax.random.split(rng) initial_variables = jax.jit( - self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, + self._eval_model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) From 23828cdb0d54207a6714263b4d9f44531011f375 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:03:11 +0000 Subject: [PATCH 24/39] remove dropout from model initialization call in submission_runner.py --- submission_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index bb4a8c6cc..d076a1043 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,11 +228,8 @@ def train_once( global_batch_size=global_batch_size) logging.info('Initializing model.') with profiler.profile('Initializing model'): - dropout_rate = None - if hasattr(hyperparameters, 'dropout_rate'): - dropout_rate = hyperparameters.dropout_rate model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate) + model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', From 86b86245a3754d59b8e707eecacd3be0477419d7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:28:35 +0000 Subject: [PATCH 25/39] remove dropout check for None and use default instead if not passed 
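
The convention used from here on: no dropout rate in the model config or
constructor, a module-level DROPOUT_RATE default, and the rate threaded
through __call__ so it can change during training without re-initializing
the model or triggering a recompile. For illustration only, a toy model
following this convention might look like the sketch below (TinyMLP and
its sizes are hypothetical, not repo code; it assumes the call-time-rate
Dropout from algoperf/jax_utils.py):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from algoperf.jax_utils import Dropout  # accepts rate at call time

    DROPOUT_RATE = 0.1  # workload default, used when nothing is passed


    class TinyMLP(nn.Module):
      hidden_dim: int = 32

      @nn.compact
      def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        # Rate comes in through the call, so a new value does not require
        # rebuilding the module or its config.
        x = Dropout(rate=dropout_rate, deterministic=not train)(
            x, rate=dropout_rate)
        return nn.Dense(1)(x)


    model = TinyMLP()
    # Dropout is inactive at init (train=False), so only a params rng is needed.
    variables = model.init(
        {'params': jax.random.PRNGKey(0)}, jnp.ones((2, 8)), train=False)
    out = model.apply(
        variables,
        jnp.ones((2, 8)),
        train=True,
        dropout_rate=0.3,  # overrides DROPOUT_RATE without re-initializing
        rngs={'dropout': jax.random.PRNGKey(1)})
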
--- .../criteo1tb/criteo1tb_jax/models.py | 11 +++---- .../workloads/fastmri/fastmri_jax/models.py | 12 +++----- .../imagenet_vit/imagenet_jax/models.py | 26 +++++------------ .../librispeech_jax/models.py | 27 +++++------------ .../librispeech_jax/models.py | 21 ++++---------- algoperf/workloads/ogbg/ogbg_jax/models.py | 7 ++--- algoperf/workloads/wmt/wmt_jax/models.py | 29 +++++++------------ 7 files changed, 41 insertions(+), 92 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index b7af15208..57cb7f2d9 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -24,14 +25,12 @@ class DLRMResNet(nn.Module): mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_layer_norm: bool = False # Unused. embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -155,9 +154,7 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index b04510297..5850defa7 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -21,6 +21,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -58,15 +59,12 @@ class UNet(nn.Module): num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_tanh: bool = False use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate - + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): # pylint: disable=invalid-name _ConvBlock = functools.partial( ConvBlock, @@ -144,7 +142,7 @@ class ConvBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, train=True, dropout_rate=None): + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -153,8 +151,6 @@ def __call__(self, x, train=True, dropout_rate=None): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. 
""" - if dropout_rate is None: - dropout_rate = self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 8ffc0b610..7c5d7bd26 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -13,6 +13,7 @@ from algoperf import spec from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 def posemb_sincos_2d(h: int, w: int, @@ -36,17 +37,14 @@ class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE @nn.compact def __call__(self, x: spec.Tensor, train: bool = True, - dropout_rate=None) -> spec.Tensor: + dropout_rate=DROPOUT_RATE) -> spec.Tensor: """Applies Transformer MlpBlock module.""" - if dropout_rate is None: - dropout_rate = self.dropout_rate - inits = { 'kernel_init': nn.initializers.xavier_uniform(), 'bias_init': nn.initializers.normal(stddev=1e-6), @@ -78,8 +76,6 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) @@ -136,11 +132,7 @@ class Encoder(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - train: bool = True, - dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate - + train: bool = True) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( @@ -165,9 +157,7 @@ class MAPHead(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), @@ -194,7 +184,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 + dropout_rate: [float] = DROPOUT_RATE reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -212,9 +202,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, - dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + dropout_rate=DROPOUT_RATE) -> spec.Tensor: # Patch extraction x = nn.Conv( self.width, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2d0da15e5..f7beed914 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -28,6 +28,7 @@ spectrum_augmenter from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.1 @struct.dataclass class ConformerConfig: @@ -37,11 +38,7 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - dropout_rate: float = 0.1 - attention_residual_dropout_rate: Optional[float] = 0.0 - conv_residual_dropout_rate: Optional[float] = 0.0 - feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: Optional[float] = 0.0 + dropout_rate: float = DROPOUT_RATE convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -96,12 +93,8 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 - dropout_rate: float = 0.0 - @nn.compact - def __call__(self, inputs, input_paddings, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) @@ -196,7 +189,7 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): config = self.config if dropout_rate is None: dropout_rate = config.dropout_rate @@ -388,10 +381,8 @@ class MultiHeadedSelfAttention(nn.Module): config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train, dropout_rate=None): + def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( @@ -527,10 +518,8 @@ def __call__(self, train, update_batch_norm, use_running_average_bn, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( @@ -603,7 +592,7 @@ def __call__(self, train, update_batch_norm, use_running_average, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -658,7 +647,7 @@ def __call__(self, train, update_batch_norm: Optional[bool] = None, use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[float] = None): + dropout_rate: float = DROPOUT_RATE: config = self.config outputs = inputs diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py 
b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 455366e5e..84ba58ee2 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -1,4 +1,4 @@ -r"""Deepspeech. +"""Deepspeech. This model uses a deepspeech2 network to convert speech to text. paper : https://arxiv.org/abs/1512.02595 @@ -31,6 +31,8 @@ CarryHistory = Any Output = Any +DROPOUT_RATE=0.1 + @struct.dataclass class DeepspeechConfig: @@ -52,10 +54,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -73,11 +71,8 @@ class Subsample(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train, dropout_rate=None): + def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate - outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( @@ -196,9 +191,7 @@ def __call__(self, inputs, input_paddings=None, train=False, - dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.config.feed_forward_dropout_rate + dropout_rate=DROPOUT_RATE): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config @@ -479,10 +472,8 @@ def setup(self): ) @nn.compact - def __call__(self, inputs, input_paddings, train, dropout_rate=None): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index f6cb1c490..59d989284 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -8,6 +8,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE=0.1 def _make_embed(latent_dim, name): @@ -41,15 +42,11 @@ class GNN(nn.Module): num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) - # If None, defaults to 0.1. 
- dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train, dropout_rate=None): - if dropout_rate is not None: - dropout_rate = self.dropout_rate + def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 3947a1b81..e262214ac 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -14,6 +14,8 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.1 + @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -28,7 +30,6 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -148,11 +149,9 @@ class MlpBlock(nn.Module): out_dim: Optional[int] = None @nn.compact - def __call__(self, inputs, dropout_rate=None): + def __call__(self, inputs, dropout_rate=DROPOUT_RATE): """Applies Transformer MlpBlock module.""" cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( @@ -195,7 +194,7 @@ class Encoder1DBlock(nn.Module): config: TransformerConfig @nn.compact - def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): """Applies Encoder1DBlock module. Args: @@ -206,8 +205,6 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): output after transformer encoder block. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate pre_ln = cfg.pre_ln @@ -254,7 +251,7 @@ def __call__( encoded, decoder_mask=None, encoder_decoder_mask=None, - dropout_rate=None, + dropout_rate=DROPOUT_RATE, ): """Applies EncoderDecoder1DBlock module. @@ -268,8 +265,6 @@ def __call__( output after transformer encoder-decoder block. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate pre_ln = cfg.pre_ln @@ -337,7 +332,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. Args: @@ -349,8 +344,6 @@ def __call__(self, output of a transformer encoder. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate assert inputs.ndim == 2 # (batch, len) @@ -403,7 +396,7 @@ def __call__( targets_positions=None, decoder_mask=None, encoder_decoder_mask=None, - dropout_rate=None, + dropout_rate=DROPOUT_RATE, ): """Applies Transformer model on the inputs. @@ -418,8 +411,6 @@ def __call__( output of a transformer decoder. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) @@ -495,7 +486,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer encoder-branch on the inputs. 
Args: @@ -533,7 +524,7 @@ def decode( targets_positions=None, inputs_segmentation=None, targets_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -593,7 +584,7 @@ def __call__(self, targets_positions=None, inputs_segmentation=None, targets_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. Args: From 05bff916dee7de6852afc6d95e2564ad57aa77ef Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 13 Jun 2025 20:45:13 +0000 Subject: [PATCH 26/39] fix to model_fn default dropout value --- algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- algoperf/workloads/fastmri/fastmri_jax/workload.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 2 +- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index dcb7b9a57..cb7e8cf9f 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -115,7 +115,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index bf0acfc8d..acdf077e1 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -52,7 +52,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index b8a870de5..08a8f4eb1 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -57,7 +57,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 042dba7f4..8d966ef87 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -109,7 +109,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None, - dropout_rate: 
Optional[float] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2213f189e..2bb119439 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -57,7 +57,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[bool] = None + dropout_rate: Optional[bool] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index aaa5b4064..e03252ed9 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -53,7 +53,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: Optional[float]) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 9e109dc86..9548f5b7e 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -250,7 +250,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: [float] = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm From f7d99a62670e8c525eef295f423b47f2026f5a38 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 17 Jun 2025 21:17:19 +0000 Subject: [PATCH 27/39] fixes --- algoperf/jax_utils.py | 8 ++++---- .../librispeech_conformer/librispeech_jax/models.py | 9 ++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 3ca3f1bfc..c4904dc75 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -14,8 +14,8 @@ class Dropout(Module): """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: - 1) adding rate argument to the __call__ method - 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code + 1) adding rate argument to the __call__ method. + 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code. .. 
note:: When using :meth:`Module.apply() `, make sure @@ -82,8 +82,8 @@ def __call__( deterministic = merge_param("deterministic", self.deterministic, deterministic) # Override self.rate if rate is passed to __call__ - if not (self.rate is not None and rate is not None): - rate = merge_param("rate", self.rate, rate) + if rate is None: + rate = self.rate if self.legacy: if rate == 0.0: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index f7beed914..0de6b1449 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -221,12 +221,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ inputs) inputs = inputs * padding_mask - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - inputs = Dropout(rate=feed_forward_residual_dropout_rate)( + inputs = Dropout(rate=dropout_rate)( inputs, deterministic=not train) return inputs @@ -401,7 +396,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): use_bias=True, broadcast_dropout=False, attention_fn=attention_fn, - dropout_rate=config.attention_dropout_rate, + dropout_rate=dropout_rate, deterministic=not train)( inputs_q=inputs, mask=attention_mask) From 7c430227152b8951fb21454960f71169ab00eb09 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:12:45 +0000 Subject: [PATCH 28/39] fixes to ogbg and fastmri --- algoperf/workloads/fastmri/fastmri_jax/workload.py | 3 ++- algoperf/workloads/ogbg/ogbg_jax/models.py | 11 +++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index acdf077e1..b8067cbad 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload @@ -52,7 +53,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 59d989284..d51ca2f20 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -18,7 +18,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout, activation_fn): +def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -28,7 +28,7 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = dropout(x) + x = Dropout(rate=dropout_rate, deterministic=not train)(x, 
rate=dropout_rate) return x return make_fn @@ -47,7 +47,6 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): - dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) @@ -70,11 +69,11 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) graph = net(graph) From 894f4fb50f5bfdf0e4d2e197cf090e507a05fc15 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:29:37 +0000 Subject: [PATCH 29/39] fixes to fastmri and deepspeech --- algoperf/workloads/fastmri/fastmri_jax/workload.py | 11 +++++------ .../librispeech_conformer/librispeech_jax/models.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index b8067cbad..ccf9c6bad 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -58,12 +58,11 @@ def model_fn( del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - if train: - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + logits = self._model.apply({'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 0de6b1449..366e42195 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -642,7 +642,7 @@ def __call__(self, train, update_batch_norm: Optional[bool] = None, use_running_average_bn: Optional[bool] = None, - dropout_rate: float = DROPOUT_RATE: + dropout_rate: float = DROPOUT_RATE): config = self.config outputs = inputs From 0bcf484282777f39d01b64e41ebba773aef1c913 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:36:49 +0000 Subject: [PATCH 30/39] fixes to conformer vit --- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 4 ++-- .../librispeech_conformer/librispeech_jax/models.py | 3 --- .../librispeech_conformer/librispeech_jax/workload.py | 3 --- .../librispeech_deepspeech/librispeech_jax/models.py | 6 +----- algoperf/workloads/wmt/wmt_jax/models.py | 2 +- 5 files changed, 4 insertions(+), 14 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7c5d7bd26..091a3473e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ 
b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -125,14 +125,14 @@ class Encoder(nn.Module): depth: int mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 - dropout_rate: float = 0.0 use_glu: bool = False use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, - train: bool = True) -> spec.Tensor: + train: bool = True, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 366e42195..bf0eb813e 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -38,7 +38,6 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - dropout_rate: float = DROPOUT_RATE convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -191,8 +190,6 @@ class FeedForwardModule(nn.Module): @nn.compact def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 8d966ef87..eec707e5f 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -71,9 +71,6 @@ def init_model_fn( else: activation_function_name = 'swish' model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 84ba58ee2..b47b1359a 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -105,12 +105,8 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): kernel_init=nn.initializers.xavier_uniform())( outputs) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs = Dropout( - rate=input_dropout_rate, deterministic=not train)( + rate=dropout_rate, deterministic=not train)( outputs, rate=dropout_rate) return outputs, output_paddings diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index e262214ac..38f76db80 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -574,7 +574,7 @@ def decode( targets_positions=targets_positions, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=droput_rate) + dropout_rate=dropout_rate) return logits.astype(self.config.dtype) def __call__(self, From 73c2276cb1907534f16b76f82e95c95400d04f8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:48:56 +0000 Subject: [PATCH 31/39] conformer 
and vit fix for dropout refactor --- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 3 +-- .../librispeech_conformer/librispeech_jax/models.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 091a3473e..a78a5e791 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -226,8 +226,7 @@ def __call__(self, num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer', - dropout_rate=dropout_rate)( + name='Transformer',)( x, train=not train, dropout_rate=dropout_rate) if self.use_map: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index bf0eb813e..1c2d79e15 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -206,8 +206,8 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ 'config.activation_function_name values, recieved ' f'{config.activation_function_name}') inputs = activation_fn(inputs) - inputs = Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) inputs = inputs * padding_mask @@ -665,8 +665,7 @@ def __call__(self, outputs, output_paddings = self.specaug(outputs, output_paddings) outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim, - dropout_rate=dropout_rate)( + encoder_dim=config.encoder_dim,)( outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. From 5ff94d23242a6613dc5d62579a9a4fe44d017eec Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:12:10 +0000 Subject: [PATCH 32/39] wmt fixes --- .../imagenet_vit/imagenet_jax/models.py | 54 +++++++++---------- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index a78a5e791..716bd4239 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -85,7 +85,7 @@ def __call__(self, deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) @@ -121,33 +121,31 @@ def __call__(self, class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - depth: int - mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
- num_heads: int = 12 - use_glu: bool = False - use_post_layer_norm: bool = False - - @nn.compact - def __call__(self, - x: spec.Tensor, - train: bool = True, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: - # Input Encoder - for lyr in range(self.depth): - block = Encoder1DBlock( - name=f'encoderblock_{lyr}', - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate)( - dropout_rate=dropout_rate) - x = block(x, train) - if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) - else: - return x + """Transformer Model Encoder for sequence to sequence translation.""" + + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False + + @nn.compact + def __call__( + self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE + ) -> spec.Tensor: + # Input Encoder + for lyr in range(self.depth): + x = Encoder1DBlock( + name=f"encoderblock_{lyr}", + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + )(x, train=train, dropout_rate=dropout_rate) + if not self.use_post_layer_norm: + return nn.LayerNorm(name="encoder_layernorm")(x) + else: + return x class MAPHead(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 9548f5b7e..24d4852b8 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -278,7 +278,7 @@ def model_fn( inputs_segmentation=inputs_segmentations, targets_segmentation=targets_segmentations, rngs={'dropout': rng}, - dropout_rate=None) + dropout_rate=dropout_rate) return logits_batch, None def _normalize_eval_metrics( From 4e69255642807cbb38c9d3390898f9713e59ece7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:16:42 +0000 Subject: [PATCH 33/39] formatting --- algoperf/jax_utils.py | 104 +++++++++--------- .../workloads/cifar/cifar_jax/workload.py | 4 +- .../criteo1tb/criteo1tb_jax/models.py | 1 + .../criteo1tb/criteo1tb_jax/workload.py | 12 +- .../workloads/fastmri/fastmri_jax/models.py | 1 + .../workloads/fastmri/fastmri_jax/workload.py | 16 +-- .../imagenet_resnet/imagenet_jax/workload.py | 3 +- .../imagenet_vit/imagenet_jax/models.py | 58 +++++----- .../imagenet_vit/imagenet_jax/workload.py | 12 +- .../librispeech_jax/models.py | 15 ++- .../librispeech_jax/workload.py | 3 +- .../librispeech_jax/models.py | 4 +- .../librispeech_jax/workload.py | 9 +- .../workloads/mnist/mnist_jax/workload.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/models.py | 22 +++- algoperf/workloads/ogbg/ogbg_jax/workload.py | 7 +- algoperf/workloads/wmt/wmt_jax/models.py | 20 ++-- algoperf/workloads/wmt/wmt_jax/workload.py | 7 +- 18 files changed, 164 insertions(+), 138 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index c4904dc75..369eb1b1a 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -1,17 +1,19 @@ from collections.abc import Sequence -import jax -import jax.numpy as jnp -from jax import lax, random - import flax.linen as nn -from flax.linen.module import Module, compact, merge_param +from flax.linen.module import compact +from flax.linen.module import merge_param +from flax.linen.module import Module from flax.typing import PRNGKey +import jax +from jax import lax +from jax import random +import jax.numpy as jnp # Custom 
Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -51,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -79,40 +81,44 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", self.deterministic, deterministic) + deterministic = merge_param("deterministic", + self.deterministic, + deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. - if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. 
+ if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + "force_terminal": False, "force_jupyter": False, "width": 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 3f2397f8c..c6cc50fbf 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -79,9 +79,7 @@ def sync_batch_stats( new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 57cb7f2d9..4a91a80b8 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -9,6 +9,7 @@ DROPOUT_RATE = 0.0 + class DLRMResNet(nn.Module): """Define a DLRMResNet model. 
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index cb7e8cf9f..d84d18d5c 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -89,16 +89,17 @@ def init_model_fn( use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) - params_rng, _= jax.random.split(rng) + params_rng, _ = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) - initial_variables = jax.jit(init_fn)( - {'params': params_rng,}, - jnp.ones(input_shape, jnp.float32)) + initial_variables = jax.jit(init_fn)({ + 'params': params_rng, + }, + jnp.ones(input_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -115,7 +116,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 5850defa7..70c7fc4a5 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -23,6 +23,7 @@ DROPOUT_RATE = 0.0 + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index ccf9c6bad..bd0aa1d0b 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -30,12 +30,11 @@ def init_model_fn( num_channels=self.num_channels, use_tanh=self.use_tanh, use_layer_norm=self.use_layer_norm, - ) + ) params_rng, _ = jax.random.split(rng) init_fn = functools.partial(self._model.init, train=False) - variables = jax.jit(init_fn)({'params': params_rng}, - fake_batch) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -53,16 +52,17 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git 
a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 2a255fee4..7896dcd05 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -84,7 +84,8 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState,) -> spec.ModelInitState: + rng: spec.RandomState, + ) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 716bd4239..f33dea723 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -15,6 +15,7 @@ DROPOUT_RATE = 0.0 + def posemb_sincos_2d(h: int, w: int, width: int, @@ -121,31 +122,32 @@ def __call__(self, class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - depth: int - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - use_glu: bool = False - use_post_layer_norm: bool = False - - @nn.compact - def __call__( - self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE - ) -> spec.Tensor: - # Input Encoder - for lyr in range(self.depth): - x = Encoder1DBlock( - name=f"encoderblock_{lyr}", - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - )(x, train=train, dropout_rate=dropout_rate) - if not self.use_post_layer_norm: - return nn.LayerNorm(name="encoder_layernorm")(x) - else: - return x + """Transformer Model Encoder for sequence to sequence translation.""" + + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False + + @nn.compact + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + # Input Encoder + for lyr in range(self.depth): + x = Encoder1DBlock( + name=f"encoderblock_{lyr}", + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + )(x, train=train, dropout_rate=dropout_rate) + if not self.use_post_layer_norm: + return nn.LayerNorm(name="encoder_layernorm")(x) + else: + return x class MAPHead(nn.Module): @@ -182,7 +184,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: [float] = DROPOUT_RATE + dropout_rate: [float] = DROPOUT_RATE reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -224,8 +226,8 @@ def __call__(self, num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer',)( - x, train=not train, dropout_rate=dropout_rate) + name='Transformer', + )(x, train=not train, dropout_rate=dropout_rate) if self.use_map: x = MAPHead( diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 08a8f4eb1..d0fb4fd72 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -24,15 +24,12 @@ def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) params_rng, _ = jax.random.split(key) - variables = jax.jit( - model.init)({'params': params_rng}, - jnp.ones(input_shape)) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape)) model_state, params = pop(variables, "params") return params, model_state - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, @@ -57,7 +54,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 1c2d79e15..b2eee1c37 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -22,14 +22,15 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter -from algoperf.jax_utils import Dropout DROPOUT_RATE = 0.1 + @struct.dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -92,6 +93,7 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. 
""" encoder_dim: int = 0 + @nn.compact def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings @@ -188,7 +190,11 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): + def __call__(self, + inputs, + padding_mask=None, + train=False, + dropout_rate=DROPOUT_RATE): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -218,8 +224,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ inputs) inputs = inputs * padding_mask - inputs = Dropout(rate=dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train) return inputs @@ -583,7 +588,7 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average, + use_running_average, dropout_rate=DROPOUT_RATE): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index eec707e5f..1e1a1d3f8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -83,8 +83,7 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, _ = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng}, - *fake_input_batch) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) model_state, params = pop(variables, "params") diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b47b1359a..1bd998027 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -16,11 +16,11 @@ from jax.experimental import rnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter -from algoperf.jax_utils import Dropout Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -31,7 +31,7 @@ CarryHistory = Any Output = Any -DROPOUT_RATE=0.1 +DROPOUT_RATE = 0.1 @struct.dataclass diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2bb119439..81a56db72 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -15,9 +15,7 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function. 
""" model_config = models.DeepspeechConfig( @@ -36,8 +34,9 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, _ = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng,}, - *fake_input_batch) + variables = model_init_fn({ + 'params': params_rng, + }, *fake_input_batch) model_state = variables[ 'batch_stats'] if not self.layernorm_everywhere else {} diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5f3fdcf78..27bd9ae54 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -32,9 +32,7 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() initial_params = self._model.init({'params': rng}, init_val, diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index d51ca2f20..06eef6187 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -8,7 +8,8 @@ from algoperf.jax_utils import Dropout -DROPOUT_RATE=0.1 +DROPOUT_RATE = 0.1 + def _make_embed(latent_dim, name): @@ -28,7 +29,9 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate) + x = Dropout( + rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate) return x return make_fn @@ -69,11 +72,20 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate)) graph = net(graph) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e03252ed9..04a9bce2e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -17,9 +17,7 @@ class OgbgWorkload(BaseOgbgWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: rng, params_rng = jax.random.split(rng, 2) self._model = models.GNN( self._num_outputs, @@ -53,7 +51,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. 
if model_state is not None: diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 38f76db80..1147eb34b 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -13,9 +13,9 @@ from algoperf.jax_utils import Dropout - DROPOUT_RATE = 0.1 + @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -171,7 +171,8 @@ def __call__(self, inputs, dropout_rate=DROPOUT_RATE): )( inputs) x = x * y - x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)( + x, rate=dropout_rate, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, @@ -223,7 +224,8 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): deterministic=cfg.deterministic, )(cfg.attention_temp * x, x, mask=encoder_mask) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -285,7 +287,8 @@ def __call__( decode=cfg.decode, )(cfg.attention_temp * x, x, mask=decoder_mask) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -304,7 +307,8 @@ def __call__( deterministic=cfg.deterministic, )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -361,7 +365,8 @@ def __call__(self, x = AddPositionEmbs( config=cfg, decode=False, name="posembed_input")( x, inputs_positions=inputs_positions) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x.astype(cfg.dtype) @@ -432,7 +437,8 @@ def __call__( y = AddPositionEmbs( config=cfg, decode=cfg.decode, name="posembed_output")( y, inputs_positions=targets_positions) - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y.astype(cfg.dtype) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 24d4852b8..d402f9d95 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -206,9 +206,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -250,7 +248,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: [float] = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: [float] = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, 
spec.ModelAuxiliaryState]: del model_state del update_batch_norm From badf12453a56b078c3b156c0b850ec5e8158bb81 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:23:03 +0000 Subject: [PATCH 34/39] fix test --- tests/reference_algorithm_tests.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..58a4a5ddc 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -184,12 +184,11 @@ def __init__(self): if 'librispeech' in workload_name: self.tokenizer = _FakeTokenizer() - def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): + def init_model_fn(self, rng): # pylint: disable=line-too-long if not (FLAGS.identical and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): - return super().init_model_fn( - rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate) + return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( f'tests.modeldiffs.{workload_name}.compare') @@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): return (FrozenDict(**jax_utils.replicate(jax_params)), FrozenDict(**jax_utils.replicate(model_state)) if model_state is not None else model_state) - return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0) + return super().init_model_fn([0]) @property def num_eval_train_examples(self): From eff3ea19572d0c3f3497bcf997948d7b3cbe7c69 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:28:18 +0000 Subject: [PATCH 35/39] fix lint errors --- algoperf/workloads/fastmri/fastmri_jax/models.py | 1 - algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +- .../workloads/librispeech_deepspeech/librispeech_jax/models.py | 3 ++- algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 70c7fc4a5..a5fe060b9 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -13,7 +13,6 @@ github.com/facebookresearch/fastMRI/tree/main/fastmri/data """ import functools -from typing import Optional import flax.linen as nn import jax diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index d0fb4fd72..ab9df0f62 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,6 +1,6 @@ """ImageNet workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from flax import jax_utils from flax import linen as nn diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 1bd998027..fab0b3259 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -489,7 +489,8 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): # Subsample input by a factor of 4 by performing strided convolutions. 
outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) + config=config)(outputs, output_paddings, train, + dropout_rate=dropout_rate) # Run the lstm layers. for _ in range(config.num_lstm_layers): diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 06eef6187..8524bb60e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -1,6 +1,6 @@ # Forked from the init2winit implementation here # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from typing import Optional, Tuple +from typing import Tuple from flax import linen as nn import jax.numpy as jnp diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index 04a9bce2e..0535aea83 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -1,6 +1,6 @@ """OGBG workload implemented in Jax.""" import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple from flax import jax_utils import jax From f7fd6c7452ead2771a9ebc6cd1e50ba99d5f3d9a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:42:36 +0000 Subject: [PATCH 36/39] formatting --- algoperf/jax_utils.py | 92 +++++++++---------- .../librispeech_jax/models.py | 2 +- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 369eb1b1a..214a178c6 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,7 +13,7 @@ # Custom Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -53,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -81,44 +81,40 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", - self.deterministic, - deterministic) + deterministic = merge_param("deterministic", self.deterministic, deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. 
- if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={ - "force_terminal": False, "force_jupyter": False, "width": 240 - }, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index fab0b3259..262fc1a95 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -489,7 +489,7 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train, + config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the lstm layers. From 8fc4cc5cd7914d698271bd50583432832e8dc98c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:44:35 +0000 Subject: [PATCH 37/39] fix spacing issues --- algoperf/jax_utils.py | 92 ++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 214a178c6..369eb1b1a 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,7 +13,7 @@ # Custom Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -53,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. 
+ rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -81,40 +81,44 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", self.deterministic, deterministic) + deterministic = merge_param("deterministic", + self.deterministic, + deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. - if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + "force_terminal": False, "force_jupyter": False, "width": 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) From 99c31114af33c6bd2ea5e9fcdc400131dc17bc78 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:53:37 +0000 Subject: [PATCH 38/39] formatting --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..1daa72848 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ pytorch_gpu = [ based_on_style = "yapf" each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true +column_limit = 80 [tool.yapfignore] ignore_patterns = ["algoperf/_version.py"] From c2f4ed0eb0fe5ceebcd9a21c7a857de644654f2e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 22:05:33 +0000 Subject: [PATCH 39/39] formatting --- algoperf/jax_utils.py | 30 ++++++++++++++++++------------ submission_runner.py | 3 +-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 369eb1b1a..467606241 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,11 +13,15 @@ # Custom 
Layers class Dropout(Module): + # pylint: disable=line-too-long """Create a dropout layer. - Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. - The reference dropout implementation is modified support changes to dropout rate during training by: + Forked from + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes + to dropout rate during training by: 1) adding rate argument to the __call__ method. - 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code. + 2) removing the if-else condition to check for edge cases, which + will trigger a recompile for jitted code. .. note:: When using :meth:`Module.apply() `, make sure @@ -47,10 +51,11 @@ class Dropout(Module): Attributes: rate: the dropout probability. (_not_ the keep rate!) broadcast_dims: dimensions that will share the same dropout mask - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and - masked, whereas if true, no mask is applied and the inputs are returned as - is. - rng_collection: the rng collection name to use when requesting an rng key. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rng_collection: the rng collection name to use when requesting an rng + key. """ rate: float | None = None @@ -71,12 +76,13 @@ def __call__( Args: inputs: the inputs that should be randomly masked. - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and - masked, whereas if true, no mask is applied and the inputs are returned - as is. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. rate: the dropout probability. (_not_ the keep rate!) - rng: an optional PRNGKey used as the random key, if not specified, one - will be generated using ``make_rng`` with the ``rng_collection`` name. + rng: an optional PRNGKey used as the random key, if not specified, + one will be generated using ``make_rng`` with the + ``rng_collection`` name. Returns: The masked inputs reweighted to preserve mean. diff --git a/submission_runner.py b/submission_runner.py index d076a1043..221a7c21d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,8 +228,7 @@ def train_once( global_batch_size=global_batch_size) logging.info('Initializing model.') with profiler.profile('Initializing model'): - model_params, model_state = workload.init_model_fn( - model_init_rng) + model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer',
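
A minimal end-to-end sketch of the per-call dropout rate that this patch series introduces, assuming the final algoperf.jax_utils.Dropout API shown in the patches above. The TinyMLP module, the input shapes, and the choice of legacy=False below are illustrative assumptions rather than repository code; legacy=False skips the concrete rate == 0.0 / rate == 1.0 shortcut branches so the rate can be passed as an ordinary traced argument under jax.jit, which is the "changes to dropout rate during training" use case named in the Dropout docstring.

import jax
import jax.numpy as jnp
import flax.linen as nn

from algoperf.jax_utils import Dropout


class TinyMLP(nn.Module):
  """Toy two-layer MLP, used only to exercise the custom Dropout layer."""

  @nn.compact
  def __call__(self, x, train, dropout_rate=0.1):
    x = nn.Dense(16)(x)
    x = nn.relu(x)
    # The rate is supplied at call time. legacy=False skips the concrete
    # rate == 0.0 / rate == 1.0 shortcut branches, so the rate may be a
    # traced value inside jit.
    x = Dropout(deterministic=not train, legacy=False)(x, rate=dropout_rate)
    return nn.Dense(1)(x)


model = TinyMLP()
x = jnp.ones((4, 8))
variables = model.init({'params': jax.random.PRNGKey(0)}, x, train=False)


@jax.jit
def forward(variables, x, rng, dropout_rate):
  # dropout_rate is an ordinary argument of the jitted function, so calling
  # it with a different rate does not change the traced computation.
  return model.apply(
      variables,
      x,
      train=True,
      dropout_rate=dropout_rate,
      rngs={'dropout': rng})


rng = jax.random.PRNGKey(1)
y_small = forward(variables, x, rng, 0.1)
y_large = forward(variables, x, rng, 0.5)  # reuses the compiled function

Calling forward with different float rates reuses the same compiled executable, which is the recompile the docstring says the refactor is meant to avoid; with the default legacy=True, the rate must instead be a concrete Python float, as in the fixed dropout_rate defaults that the workload model_fn signatures above pass through to the models.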