diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py new file mode 100644 index 000000000..467606241 --- /dev/null +++ b/algoperf/jax_utils.py @@ -0,0 +1,130 @@ +from collections.abc import Sequence + +import flax.linen as nn +from flax.linen.module import compact +from flax.linen.module import merge_param +from flax.linen.module import Module +from flax.typing import PRNGKey +import jax +from jax import lax +from jax import random +import jax.numpy as jnp + + +# Custom Layers +class Dropout(Module): + # pylint: disable=line-too-long + """Create a dropout layer. + Forked from + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified to support changes + to the dropout rate during training by: + 1) adding a rate argument to the __call__ method. + 2) removing the if-else conditions that check for edge cases, which + would otherwise trigger a recompile for jitted code. + + .. note:: + When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure + to include an RNG seed named ``'dropout'``. Dropout isn't necessary for + variable initialization. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class MLP(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(4)(x) + ... x = nn.Dropout(0.5, deterministic=not train)(x) + ... return x + + >>> model = MLP() + >>> x = jnp.ones((1, 3)) + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + >>> model.apply(variables, x, train=False) # don't use dropout + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rng_collection: the rng collection name to use when requesting an rng + key. + """ + + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. + + Args: + inputs: the inputs that should be randomly masked. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rate: the dropout probability. (_not_ the keep rate!) + rng: an optional PRNGKey used as the random key, if not specified, + one will be generated using ``make_rng`` with the + ``rng_collection`` name. + + Returns: + The masked inputs reweighted to preserve mean. + """ + deterministic = merge_param("deterministic", + self.deterministic, + deterministic) + + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate + + if self.legacy: + if rate == 0.0: + return inputs + + # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0: + return jnp.zeros_like(inputs) + + if deterministic: + return inputs + + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + +# Utilities for debugging +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + "force_terminal": False, "force_jupyter": False, "width": 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..c6cc50fbf 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -79,14 +79,8 @@ def sync_batch_stats( new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 6d9a489ff..4a91a80b8 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -1,11 +1,14 @@ """A JAX implementation of DLRM-Small.""" - from typing import Sequence import flax.linen as nn from jax import nn as jnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.0 + class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -23,12 +26,13 @@ class DLRMResNet(nn.Module): mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_layer_norm: bool = False # Unused. embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -88,8 +92,8 @@ def scaled_init(key, shape, dtype=jnp.float_): stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( top_mlp_input) x = nn.relu(x) - if self.dropout_rate and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + if dropout_rate and layer_idx == num_layers_top - 2: + x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below.
@@ -151,7 +155,8 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -210,10 +215,10 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.relu(top_mlp_input) if self.use_layer_norm: top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) + top_mlp_input = Dropout( + dropout_rate, deterministic=not train)( + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..d84d18d5c 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -72,36 +72,34 @@ def loss_fn( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate if self.use_resnet: model_class = models.DLRMResNet else: model_class = models.DlrmSmall + self._model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) - params_rng, dropout_rng = jax.random.split(rng) + params_rng, _ = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) - initial_variables = jax.jit(init_fn)( - {'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32)) + initial_variables = jax.jit(init_fn)({ + 'params': params_rng, + }, + jnp.ones(input_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -117,7 +115,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] @@ -125,6 +125,7 @@ def model_fn( apply_kwargs = {'train': train} if train: apply_kwargs['rngs'] = {'dropout': rng} + apply_kwargs['dropout_rate'] = dropout_rate logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 44bff0e21..a5fe060b9 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ 
b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -13,12 +13,15 @@ github.com/facebookresearch/fastMRI/tree/main/fastmri/data """ import functools -from typing import Optional import flax.linen as nn import jax import jax.numpy as jnp +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.0 + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -56,16 +59,12 @@ class UNet(nn.Module): num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: float = DROPOUT_RATE use_tanh: bool = False use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True): - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): # pylint: disable=invalid-name _ConvBlock = functools.partial( ConvBlock, @@ -138,12 +137,12 @@ class ConvBlock(nn.Module): dropout_rate: Dropout probability. """ out_channels: int - dropout_rate: float use_tanh: bool use_layer_norm: bool + dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, train=True): + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -172,9 +171,9 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -186,9 +185,9 @@ def __call__(self, x, train=True): else: x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) return x diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..bd0aa1d0b 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload @@ -21,21 +22,19 @@ class FastMRIWorkload(BaseFastMRIWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) self._model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, use_tanh=self.use_tanh, use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) - params_rng, dropout_rng = jax.random.split(rng) - variables = jax.jit( - self._model.init)({'params': params_rng, 'dropout': dropout_rng}, - fake_batch) + ) + + params_rng, _ = jax.random.split(rng) + init_fn = functools.partial(self._model.init, 
train=False) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -52,14 +51,18 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN + logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..7896dcd05 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -85,11 +85,7 @@ def sync_batch_stats( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + ) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7ce3a0395..f33dea723 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -11,6 +11,9 @@ import jax.numpy as jnp from algoperf import spec +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.0 def posemb_sincos_2d(h: int, @@ -35,10 +38,13 @@ class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
use_glu: bool = False - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=DROPOUT_RATE) -> spec.Tensor: """Applies Transformer MlpBlock module.""" inits = { 'kernel_init': nn.initializers.xavier_uniform(), @@ -53,7 +59,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = nn.Dropout(rate=self.dropout_rate)(x, train) + x = Dropout(dropout_rate)(x, train, rate=dropout_rate) x = nn.Dense(d, **inits)(x) return x @@ -67,7 +73,11 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=DROPOUT_RATE) -> spec.Tensor: + if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( @@ -76,16 +86,14 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y else: y = x @@ -95,7 +103,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -103,9 +111,10 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) + name='MlpBlock_3', + dropout_rate=dropout_rate)( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -114,27 +123,29 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" + depth: int mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12 - dropout_rate: float = 0.0 use_glu: bool = False use_post_layer_norm: bool = False @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): - block = Encoder1DBlock( - name=f'encoderblock_{lyr}', + x = Encoder1DBlock( + name=f"encoderblock_{lyr}", mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=self.dropout_rate) - x = block(x, train) + )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) + return nn.LayerNorm(name="encoder_layernorm")(x) else: return x @@ -143,9 +154,10 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 + dropout_rate: float = 0.0 @nn.compact - def __call__(self, x): + def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), @@ -158,7 +170,7 @@ def __call__(self, x): kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) return x[:, 0] @@ -172,7 +184,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: float = DROPOUT_RATE reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -186,7 +198,11 @@ def get_posemb(self, return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + *, + train: bool = False, + dropout_rate=DROPOUT_RATE) -> spec.Tensor: # Patch extraction x = nn.Conv( self.width, @@ -202,10 +218,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: # Add posemb before adding extra token.
x = x + self.get_posemb((h, w), c, x.dtype) - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - x = nn.Dropout(rate=dropout_rate)(x, not train) + x = Dropout(dropout_rate)(x, not train, rate=dropout_rate) x = Encoder( depth=self.depth, @@ -213,12 +226,15 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate, - name='Transformer')( - x, train=not train) + name='Transformer', + )(x, train=not train, dropout_rate=dropout_rate) if self.use_map: - x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + x = MAPHead( + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=dropout_rate)( + x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..ab9df0f62 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,6 +1,6 @@ """ImageNet workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from flax import jax_utils from flax import linen as nn @@ -23,21 +23,14 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) - params_rng, dropout_rng = jax.random.split(key) - variables = jax.jit( - model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape)) + params_rng, _ = jax.random.split(key) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape)) model_state, params = pop(variables, "params") return params, model_state - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.ViT( - dropout_rate=dropout_rate, num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, @@ -60,14 +53,17 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _eval_model_on_split(self, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..b2eee1c37 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -22,11 +22,14 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax 
import \ spectrum_augmenter +DROPOUT_RATE = 0.1 + @struct.dataclass class ConformerConfig: @@ -36,14 +39,6 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. - conv_residual_dropout_rate: Optional[float] = 0.0 - feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -53,8 +48,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -100,10 +93,9 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 - input_dropout_rate: float = 0.0 @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) @@ -129,9 +121,9 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( seq_length=outputs.shape[1]) - outputs = nn.Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout( + rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -198,9 +190,12 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False): + def __call__(self, + inputs, + padding_mask=None, + train=False, + dropout_rate=DROPOUT_RATE): config = self.config - inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( @@ -217,8 +212,8 @@ def __call__(self, inputs, padding_mask=None, train=False): 'config.activation_function_name values, recieved ' f'{config.activation_function_name}') inputs = activation_fn(inputs) - inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) inputs = inputs * padding_mask @@ -229,13 +224,7 @@ def __call__(self, inputs, padding_mask=None, train=False): inputs) inputs = inputs * padding_mask - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train) return inputs @@ -389,8 +378,9 @@ class MultiHeadedSelfAttention(nn.Module): config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train): + def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): config = self.config + mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) @@ -408,17 +398,13 @@ def __call__(self, inputs, paddings, train): use_bias=True, broadcast_dropout=False, attention_fn=attention_fn, - 
dropout_rate=config.attention_dropout_rate, + dropout_rate=dropout_rate, deterministic=not train)( inputs_q=inputs, mask=attention_mask) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - result = nn.Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)( - result) + result = Dropout( + rate=dropout_rate, deterministic=not train)( + result, rate=dropout_rate) return result @@ -528,7 +514,8 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average_bn): + use_running_average_bn, + dropout_rate=DROPOUT_RATE): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -574,13 +561,9 @@ def __call__(self, config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( inputs) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - inputs = nn.Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)( - inputs) + inputs = Dropout( + rate=dropout_rate, deterministic=not train)( + inputs, rate=dropout_rate) return inputs @@ -605,26 +588,28 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average): + use_running_average, + dropout_rate=DROPOUT_RATE): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) + inputs, input_paddings, train, dropout_rate=dropout_rate) inputs = inputs + \ ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, - use_running_average + use_running_average, + dropout_rate ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) if config.use_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -658,7 +643,8 @@ def __call__(self, input_paddings, train, update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + use_running_average_bn: Optional[bool] = None, + dropout_rate: float = DROPOUT_RATE): config = self.config outputs = inputs @@ -683,15 +669,9 @@ def __call__(self, if train and config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - # Subsample input by a factor of 4 by performing strided convolutions. - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate)( - outputs, output_paddings, train) + encoder_dim=config.encoder_dim,)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): @@ -699,7 +679,8 @@ def __call__(self, output_paddings, train, update_batch_norm, - use_running_average_bn) + use_running_average_bn, + dropout_rate) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
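The Conformer changes above follow the same pattern as the other workloads: every dropout layer is now the shared algoperf.jax_utils.Dropout, and the rate is supplied at call time instead of being baked into the module config. A minimal sketch of that call-time override (not part of the patch; the array shape and rates below are illustrative only):

import jax
import jax.numpy as jnp
from algoperf.jax_utils import Dropout

# The layer holds no parameters, so apply() is called with an empty variable dict.
layer = Dropout(rate=0.1, deterministic=False)
x = jnp.ones((2, 4))

# The rate passed to __call__ (0.5 here) overrides the construction-time rate,
# which is what lets the workloads thread a tuned dropout_rate through apply().
y = layer.apply({}, x, rate=0.5, rngs={'dropout': jax.random.PRNGKey(0)})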
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..1e1a1d3f8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -60,11 +60,10 @@ def attention_temperature(self) -> float: def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Conformer model init function. - Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as + Here we use dropout_rate for *_residual_dropout_rate and for input_dropout_rate. """ if self.use_gelu: @@ -72,22 +71,19 @@ def init_model_fn( else: activation_function_name = 'swish' model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, activation_function_name=activation_function_name) + self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, - *fake_input_batch) + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) model_state, params = pop(variables, "params") @@ -108,7 +104,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -121,7 +118,8 @@ def model_fn( train=True, rngs={'dropout' : rng}, mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + use_running_average_bn=use_running_average_bn, + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..262fc1a95 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -1,4 +1,4 @@ -r"""Deepspeech. +"""Deepspeech. This model uses a deepspeech2 network to convert speech to text.
paper : https://arxiv.org/abs/1512.02595 @@ -16,6 +16,7 @@ from jax.experimental import rnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ @@ -30,6 +31,8 @@ CarryHistory = Any Output = Any +DROPOUT_RATE = 0.1 + @struct.dataclass class DeepspeechConfig: @@ -51,10 +54,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -72,7 +71,7 @@ class Subsample(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train): + def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config outputs = jnp.expand_dims(inputs, axis=-1) @@ -106,13 +105,9 @@ def __call__(self, inputs, output_paddings, train): kernel_init=nn.initializers.xavier_uniform())( outputs) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - outputs = nn.Dropout( - rate=input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout( + rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -188,7 +183,11 @@ class FeedForwardModule(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): + def __call__(self, + inputs, + input_paddings=None, + train=False, + dropout_rate=DROPOUT_RATE): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config @@ -212,12 +211,8 @@ def __call__(self, inputs, input_paddings=None, train=False): inputs = nn.relu(inputs) inputs *= padding_mask - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = nn.Dropout(rate=feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) return inputs @@ -473,7 +468,7 @@ def setup(self): ) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config outputs = inputs @@ -494,7 +489,8 @@ def __call__(self, inputs, input_paddings, train): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train) + config=config)(outputs, output_paddings, train, + dropout_rate=dropout_rate) # Run the lstm layers. for _ in range(config.num_lstm_layers): @@ -508,9 +504,8 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + FeedForwardModule(config=self.config)( outputs, output_paddings, train) else: - outputs = FeedForwardModule(config=self.config)(outputs, - output_paddings, - train) + outputs = FeedForwardModule(config=self.config)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the decoder which in this case is a trivial projection layer. 
if config.enable_decoder_layer_norm: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..81a56db72 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -15,20 +15,11 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. """ model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, @@ -42,9 +33,10 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, - *fake_input_batch) + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({ + 'params': params_rng, + }, *fake_input_batch) model_state = variables[ 'batch_stats'] if not self.layernorm_everywhere else {} @@ -63,7 +55,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -75,7 +68,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..27bd9ae54 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -32,14 +32,7 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() initial_params = self._model.init({'params': rng}, init_val, diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..8524bb60e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -1,11 +1,15 @@ # Forked from the init2winit implementation here # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
-from typing import Optional, Tuple +from typing import Tuple from flax import linen as nn import jax.numpy as jnp import jraph +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.1 + def _make_embed(latent_dim, name): @@ -15,7 +19,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout, activation_fn): +def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -25,7 +29,9 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = dropout(x) + x = Dropout( + rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate) return x return make_fn @@ -39,18 +45,11 @@ class GNN(nn.Module): num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) - # If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train): - if self.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = self.dropout_rate - dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) + def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) @@ -73,11 +72,20 @@ def __call__(self, graph, train): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate)) graph = net(graph) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..0535aea83 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -1,6 +1,6 @@ """OGBG workload implemented in Jax.""" import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple from flax import jax_utils import jax @@ -17,17 +17,10 @@ class OgbgWorkload(BaseOgbgWorkload): - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate - rng, params_rng, dropout_rng = jax.random.split(rng, 3) + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + rng, params_rng = jax.random.split(rng, 2) self._model = models.GNN( self._num_outputs, - dropout_rate=dropout_rate, activation_fn_name=self.activation_fn_name, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, @@ -41,7 +34,7 @@ def init_model_fn( globals=jnp.zeros((1, self._num_outputs)), senders=jnp.asarray([0]), receivers=jnp.asarray([0])) - params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) + params = init_fn({'params': params_rng}, fake_batch) params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -57,7 +50,9 @@ 
def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: @@ -68,7 +63,8 @@ def model_fn( logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _binary_cross_entropy_with_mask( diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 97fee032f..1147eb34b 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -11,6 +11,10 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.1 + @struct.dataclass class TransformerConfig: @@ -26,10 +30,6 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - #If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 - #If None, defaults to 0.1. - attention_dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -140,79 +140,78 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - out_dim: optionally specify out dimension. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + out_dim: optionally specify out dimension. + """ + config: TransformerConfig out_dim: Optional[int] = None @nn.compact - def __call__(self, inputs): + def __call__(self, inputs, dropout_rate=DROPOUT_RATE): """Applies Transformer MlpBlock module.""" cfg = self.config - actual_out_dim = ( - inputs.shape[-1] if self.out_dim is None else self.out_dim) + + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( cfg.mlp_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) + bias_init=cfg.bias_init, + )( + inputs) x = cfg.activation(x) if cfg.glu: y = nn.Dense( cfg.mlp_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) + bias_init=cfg.bias_init, + )( + inputs) x = x * y - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)( + x, rate=dropout_rate, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - x) - output = nn.Dropout(rate=dropout_rate)( - output, deterministic=cfg.deterministic) + bias_init=cfg.bias_init, + )( + x) + output = Dropout(rate=dropout_rate)( + output, rate=dropout_rate, deterministic=cfg.deterministic) return output class Encoder1DBlock(nn.Module): """Transformer encoder layer. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. 
+ """ + config: TransformerConfig @nn.compact - def __call__(self, inputs, encoder_mask=None): + def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): """Applies Encoder1DBlock module. - Args: - inputs: input data. - encoder_mask: encoder self-attention mask. + Args: + inputs: input data. + encoder_mask: encoder self-attention mask. - Returns: - output after transformer encoder block. - """ + Returns: + output after transformer encoder block. + """ cfg = self.config + pre_ln = cfg.pre_ln # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -221,22 +220,19 @@ def __call__(self, inputs, encoder_mask=None): bias_init=cfg.bias_init, use_bias=False, broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * x, x, mask=encoder_mask) + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) # MLP block. y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) @@ -244,39 +240,40 @@ def __call__(self, inputs, encoder_mask=None): class EncoderDecoder1DBlock(nn.Module): """Transformer encoder-decoder layer. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ + config: TransformerConfig @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies EncoderDecoder1DBlock module. - Args: - targets: input data for decoder - encoded: input data from encoder - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. + Args: + targets: input data for decoder + encoded: input data from encoder + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. - Returns: - output after transformer encoder-decoder block. - """ + Returns: + output after transformer encoder-decoder block. + """ cfg = self.config + pre_ln = cfg.pre_ln # Decoder block. 
assert targets.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -285,15 +282,13 @@ def __call__(self, bias_init=cfg.bias_init, use_bias=False, broadcast_dropout=False, - dropout_rate=attention_dropout_rate, + dropout_rate=dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)( - cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -308,18 +303,19 @@ def __call__(self, bias_init=cfg.bias_init, use_bias=False, broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) # MLP block. z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) @@ -327,26 +323,32 @@ def __call__(self, class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, inputs, inputs_positions=None, encoder_mask=None): + def __call__(self, + inputs, + inputs_positions=None, + encoder_mask=None, + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. - Args: - inputs: input data - inputs_positions: input subsequence positions for packed examples. - encoder_mask: decoder self-attention mask. + Args: + inputs: input data + inputs_positions: input subsequence positions for packed examples. + encoder_mask: decoder self-attention mask. - Returns: - output of a transformer encoder. - """ + Returns: + output of a transformer encoder. 
+ """ cfg = self.config + assert inputs.ndim == 2 # (batch, len) # Input Embedding @@ -354,29 +356,27 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): input_embed = nn.Embed( num_embeddings=cfg.vocab_size, features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: input_embed = self.shared_embedding - x = inputs.astype('int32') + x = inputs.astype("int32") x = input_embed(x) x = AddPositionEmbs( - config=cfg, decode=False, name='posembed_input')( + config=cfg, decode=False, name="posembed_input")( x, inputs_positions=inputs_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x.astype(cfg.dtype) # Input Encoder for lyr in range(cfg.num_layers): x = Encoder1DBlock( - config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) + config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) + nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) if cfg.pre_ln else x) return encoded @@ -385,32 +385,36 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. - Args: - encoded: encoded input data from encoder. - targets: target inputs. - targets_positions: input subsequence positions for packed examples. - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. + Args: + encoded: encoded input data from encoder. + targets: target inputs. + targets_positions: input subsequence positions for packed examples. + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. - Returns: - output of a transformer decoder. - """ + Returns: + output of a transformer decoder. 
+ """ cfg = self.config assert encoded.ndim == 3 # (batch, len, depth) @@ -421,35 +425,35 @@ def __call__(self, output_embed = nn.Embed( num_embeddings=cfg.vocab_size, features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding - y = targets.astype('int32') + y = targets.astype("int32") if not cfg.decode: y = shift_right(y) y = output_embed(y) y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name='posembed_output')( + config=cfg, decode=cfg.decode, name="posembed_output")( y, inputs_positions=targets_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y.astype(cfg.dtype) # Target-Input Decoder for lyr in range(cfg.num_layers): y = EncoderDecoder1DBlock( - config=cfg, name=f'encoderdecoderblock_{lyr}')( + config=cfg, name=f"encoderdecoderblock_{lyr}")( y, encoded, decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) y = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) + nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) if cfg.pre_ln else y) # Use the transpose of embedding matrix for logit transform. @@ -484,7 +488,11 @@ def setup(self): self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): + def encode(self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + dropout_rate=DROPOUT_RATE): """Applies Transformer encoder-branch on the inputs. Args: @@ -509,7 +517,10 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): jnp.equal, dtype=cfg.dtype)) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) + inputs, + inputs_positions=inputs_positions, + encoder_mask=encoder_mask, + dropout_rate=dropout_rate) def decode( self, @@ -518,7 +529,8 @@ def decode( targets, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=DROPOUT_RATE): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -567,7 +579,8 @@ def decode( targets, targets_positions=targets_positions, decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate) return logits.astype(self.config.dtype) def __call__(self, @@ -576,7 +589,8 @@ def __call__(self, inputs_positions=None, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. 
Args: @@ -593,7 +607,8 @@ def __call__(self, encoded = self.encode( inputs, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) return self.decode( encoded, @@ -601,4 +616,5 @@ def __call__(self, targets, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) + targets_segmentation=targets_segmentation, + dropout_rate=dropout_rate) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..d402f9d95 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -206,13 +206,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -225,8 +219,6 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, @@ -234,9 +226,9 @@ def init_model_fn( self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) - params_rng, dropout_rng = jax.random.split(rng) + params_rng, _ = jax.random.split(rng) initial_variables = jax.jit( - self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, + self._eval_model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) @@ -255,7 +247,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm @@ -282,7 +276,8 @@ def model_fn( targets_positions=targets_positions, inputs_segmentation=inputs_segmentations, targets_segmentation=targets_segmentations, - rngs={'dropout': rng}) + rngs={'dropout': rng}, + dropout_rate=dropout_rate) return logits_batch, None def _normalize_eval_metrics( diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..9926b0542 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,27 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b <branch> -f <framework> # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; + p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf" ]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pytorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME .
--build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..1daa72848 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ pytorch_gpu = [ based_on_style = "yapf" each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true +column_limit = 80 [tool.yapfignore] ignore_patterns = ["algoperf/_version.py"] diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..221a7c21d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,14 +228,7 @@ def train_once( global_batch_size=global_batch_size) logging.info('Initializing model.') with profiler.profile('Initializing model'): - dropout_rate = None - aux_dropout_rate = None - if hasattr(hyperparameters, 'dropout_rate'): - dropout_rate = hyperparameters.dropout_rate - if hasattr(hyperparameters, 'aux_dropout_rate'): - aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate, aux_dropout_rate) + model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..58a4a5ddc 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -184,12 +184,11 @@ def __init__(self): if 'librispeech' in workload_name: self.tokenizer = _FakeTokenizer() - def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): + def init_model_fn(self, rng): # pylint: disable=line-too-long if not (FLAGS.identical and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): - return super().init_model_fn( - rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate) + return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( f'tests.modeldiffs.{workload_name}.compare') @@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): return (FrozenDict(**jax_utils.replicate(jax_params)), FrozenDict(**jax_utils.replicate(model_state)) if model_state is not None else model_state) - return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0) + return super().init_model_fn([0]) @property def num_eval_train_examples(self):
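
Editorial note: the sketch below is not part of the patch. It is a minimal, self-contained usage example, assuming the custom Dropout layer added in algoperf/jax_utils.py, and it illustrates the call-time dropout_rate pattern that the Transformer and workload changes above follow. TinyBlock, the layer widths, and the example rates (0.1, 0.2) are hypothetical stand-ins rather than code from this repository; DROPOUT_RATE here is a local stand-in for the per-workload models.DROPOUT_RATE defaults.

# Hypothetical toy module (not part of the patch) showing dropout_rate as a
# call-time argument rather than a constructor-time hyperparameter.
import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

DROPOUT_RATE = 0.1  # illustrative default, analogous to models.DROPOUT_RATE


class TinyBlock(nn.Module):
  """Toy block that accepts dropout_rate at call time instead of at init."""

  @nn.compact
  def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
    x = nn.Dense(16)(x)
    x = nn.relu(x)
    # The custom Dropout accepts `rate` in __call__, so the caller can pick
    # the rate per forward pass without rebuilding the module.
    x = Dropout(rate=dropout_rate, deterministic=not train)(
        x, rate=dropout_rate)
    return nn.Dense(4)(x)


rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 8))
variables = TinyBlock().init({'params': rng}, x, train=False)
# A training step chooses the rate and supplies a 'dropout' RNG, mirroring
# how model_fn now forwards dropout_rate alongside rngs={'dropout': rng}.
out = TinyBlock().apply(
    variables,
    x,
    train=True,
    dropout_rate=0.2,
    rngs={'dropout': jax.random.PRNGKey(1)})

Because the rate is an argument of __call__ (and of model_fn), a tuning algorithm can vary dropout per call without re-initializing parameters, which is why init_model_fn no longer takes dropout_rate or aux_dropout_rate.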