Dropout JAX -> dropout_support #864

Open · wants to merge 28 commits into base: dropout_support

Changes from all commits · 28 commits
bf61255
Merge pull request #825 from mlcommons/dev
priyakasimbeg Feb 5, 2025
9653f18
Merge pull request #843 from mlcommons/dev
priyakasimbeg Feb 11, 2025
e212af3
add basic summary of benchmark to beginning of readme
qpwo Feb 18, 2025
258a13c
Merge branch 'dev' into documentation_update
priyakasimbeg Apr 8, 2025
e8ed95b
updates to documentation
priyakasimbeg Apr 8, 2025
7638497
Merge pull request #861 from mlcommons/documentation_update
priyakasimbeg Apr 8, 2025
d672b84
add jit-friendly dropout w rate in call
priyakasimbeg Apr 17, 2025
aa25e20
remove nan_to_num convertion
priyakasimbeg May 5, 2025
85a3578
update models with custom dropout layer
priyakasimbeg May 5, 2025
9354079
add functional dropout for criteo, fastmri, and vit
priyakasimbeg May 5, 2025
feb9cc5
add functional dropout for ogbg
priyakasimbeg May 5, 2025
9bba078
modify wmt model for dropout passing
priyakasimbeg May 15, 2025
31f6019
modify wmt model for dropout passing
priyakasimbeg May 15, 2025
e36d294
reformatting and dropout fixes to fastmri and vit
priyakasimbeg May 29, 2025
363da8a
dropout fix for criteo1tb jax
priyakasimbeg May 29, 2025
341bf89
dropout fix for criteo1tb jax
priyakasimbeg May 29, 2025
f0c385b
remove aux dropout option from conformer and from init_model_fn signa…
priyakasimbeg May 29, 2025
7af5c94
add dropout piping for conformer and deepspeech
priyakasimbeg May 31, 2025
cbd065b
pipe dropout through model_fn
priyakasimbeg May 31, 2025
31babfd
fix syntax
priyakasimbeg Jun 4, 2025
95d67db
dropout changes wmt jax
priyakasimbeg Jun 4, 2025
2c96b88
modify dockerfile
priyakasimbeg Jun 4, 2025
54786a6
modify docker build script
priyakasimbeg Jun 4, 2025
246d68e
fsmall fixes
priyakasimbeg Jun 5, 2025
0c8dd14
change docker base image to 12.1.1
priyakasimbeg Jun 6, 2025
a78fa66
update base image
priyakasimbeg Jun 6, 2025
f0019ac
small fix
priyakasimbeg Jun 7, 2025
3cb012e
remove aux_dropout from submission_runner.py
priyakasimbeg Jun 7, 2025
6 changes: 6 additions & 0 deletions README.md
@@ -22,6 +22,12 @@

---

Unlike benchmarks that focus on model architecture or hardware, the AlgoPerf benchmark isolates the training algorithm itself, measuring how quickly it can achieve target performance levels on a fixed set of representative deep learning tasks. These tasks span various domains, including image classification, speech recognition, machine translation, and more, all running on standardized hardware (8x NVIDIA V100 GPUs). The benchmark includes 8 fully specified base workloads. In addition, it defines "randomized" workloads, variations of the fixed workloads designed to discourage overfitting. These randomized workloads were used for scoring the AlgoPerf competition but will not be used for future scoring.

Submissions are evaluated based on their "time-to-result", i.e., the wall-clock time it takes to reach predefined validation and test set performance targets on each workload. Submissions are scored under one of two tuning rulesets. The [external tuning ruleset](https://github.com/mlcommons/algorithmic-efficiency/blob/main/docs/DOCUMENTATION.md#external-tuning-ruleset) allows a limited amount of hyperparameter tuning (20 quasirandom trials) per workload. The [self-tuning ruleset](https://github.com/mlcommons/algorithmic-efficiency/blob/main/docs/DOCUMENTATION.md#self-tuning-ruleset) allows no external tuning, so any tuning is done "on-the-clock". For each submission, a single, overall benchmark score is computed by integrating its "performance profile" across all fixed workloads. The performance profile captures the submission's training time relative to the best submission on each workload, so each submission's score depends on the other submissions in the pool. The higher the benchmark score, the better the submission's overall performance.
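As a rough illustration of how such a score could be computed (a simplified sketch with hypothetical names, not the repository's actual scoring code):

```python
import numpy as np


def benchmark_scores(time_to_target, max_ratio=4.0, num_points=101):
  """Simplified performance-profile scoring (illustrative only).

  `time_to_target[s][w]` is submission s's wall-clock time to reach the
  target on workload w (np.inf if the target was never reached).
  """
  workloads = sorted({w for times in time_to_target.values() for w in times})
  # Fastest time-to-target per workload across the submission pool.
  best = {w: min(t.get(w, np.inf) for t in time_to_target.values())
          for w in workloads}
  taus = np.linspace(1.0, max_ratio, num_points)  # allowed slowdown factors
  scores = {}
  for s, times in time_to_target.items():
    ratios = np.array([times.get(w, np.inf) / best[w] for w in workloads])
    # Performance profile: fraction of workloads trained to target within a
    # factor tau of the fastest submission, on a grid of tau values.
    profile = np.array([(ratios <= tau).mean() for tau in taus])
    # Benchmark score: (approximately) the normalized area under the profile.
    scores[s] = float(profile.mean())
  return scores
```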

---

> This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements.
> It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/).
> This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm).
118 changes: 118 additions & 0 deletions algoperf/jax_utils.py
@@ -0,0 +1,118 @@
from collections.abc import Sequence

import jax
import jax.numpy as jnp
from jax import lax, random

import flax.linen as nn
from flax.linen.module import Module, compact, merge_param
from flax.typing import PRNGKey


# Custom Layers
class Dropout(Module):
  """Create a dropout layer.

  Forked from
  https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
  The reference dropout implementation is modified to support changes to the
  dropout rate during training by:
  1) adding a ``rate`` argument to the ``__call__`` method, and
  2) removing the if-else conditions that check for edge cases, since they
     would trigger a recompile of jitted code.

  .. note::
    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
    variable initialization.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> class MLP(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x, train):
    ...     x = nn.Dense(4)(x)
    ...     x = nn.Dropout(0.5, deterministic=not train)(x)
    ...     return x

    >>> model = MLP()
    >>> x = jnp.ones((1, 3))
    >>> variables = model.init(jax.random.key(0), x, train=False)  # don't use dropout
    >>> model.apply(variables, x, train=False)  # don't use dropout
    Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
    >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)})  # use dropout
    Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)

  Attributes:
    rate: the dropout probability. (_not_ the keep rate!)
    broadcast_dims: dimensions that will share the same dropout mask
    deterministic: if false, the inputs are scaled by ``1 / (1 - rate)`` and
      masked, whereas if true, no mask is applied and the inputs are returned
      as is.
    rng_collection: the rng collection name to use when requesting an rng key.
  """

  rate: float | None = None
  broadcast_dims: Sequence[int] = ()
  deterministic: bool | None = None
  rng_collection: str = "dropout"
  legacy: bool = True

  @compact
  def __call__(
      self,
      inputs,
      deterministic: bool | None = None,
      rate: float | None = None,
      rng: PRNGKey | None = None,
  ):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false, the inputs are scaled by ``1 / (1 - rate)`` and
        masked, whereas if true, no mask is applied and the inputs are returned
        as is.
      rate: the dropout probability. (_not_ the keep rate!)
      rng: an optional PRNGKey used as the random key; if not specified, one
        will be generated using ``make_rng`` with the ``rng_collection`` name.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    deterministic = merge_param(
        "deterministic", self.deterministic, deterministic)

    # A rate passed to __call__ takes precedence over self.rate; otherwise
    # merge the two (at least one must be set).
    if not (self.rate is not None and rate is not None):
      rate = merge_param("rate", self.rate, rate)

    if self.legacy:
      if rate == 0.0:
        return inputs

      # Prevent gradient NaNs in 1.0 edge-case.
      if rate == 1.0:
        return jnp.zeros_like(inputs)

    if deterministic:
      return inputs

    keep_prob = 1.0 - rate
    if rng is None:
      rng = self.make_rng(self.rng_collection)
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    # Scale kept activations by 1 / keep_prob so the expected value is
    # preserved, as documented above.
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
model,
jax.random.PRNGKey(0),
console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240},
)
print(tabulate_fn(fake_inputs, train=False))
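A minimal usage sketch of this layer (a hypothetical snippet, not taken from the repository): because `rate` is an ordinary call argument, the dropout probability can change from one call to the next without re-instantiating the module.

```python
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

x = jnp.ones((2, 4))
layer = Dropout(rate=0.5, deterministic=False)
variables = layer.init({"dropout": jax.random.PRNGKey(0)}, x)  # no params, just RNG plumbing

# Uses the module-level rate (0.5).
y_default = layer.apply(variables, x, rngs={"dropout": jax.random.PRNGKey(1)})
# Same module and variables, with a different rate supplied at call time.
y_low = layer.apply(
    variables, x, rate=0.1, rngs={"dropout": jax.random.PRNGKey(1)})
```

In the workload models below, the same idea appears as `model.apply(..., dropout_rate=...)`, with the module-level `dropout_rate` acting only as a fallback default.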
25 changes: 16 additions & 9 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -1,11 +1,12 @@
"""A JAX implementation of DLRM-Small."""

from typing import Sequence

import flax.linen as nn
from jax import nn as jnn
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
@@ -28,7 +29,10 @@ class DLRMResNet(nn.Module):
embedding_init_multiplier: float = None # Unused

@nn.compact
def __call__(self, x, train):
def __call__(self, x, train, dropout_rate=None):
if dropout_rate is None:
dropout_rate = self.dropout_rate

bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -88,8 +92,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
top_mlp_input)
x = nn.relu(x)
if self.dropout_rate and layer_idx == num_layers_top - 2:
x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
if dropout_rate and layer_idx == num_layers_top - 2:
x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
@@ -151,7 +155,10 @@ class DlrmSmall(nn.Module):
embedding_init_multiplier: float = None

@nn.compact
def __call__(self, x, train):
def __call__(self, x, train, dropout_rate=None):
if dropout_rate is None:
dropout_rate = self.dropout_rate

bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -210,10 +217,10 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
top_mlp_input = nn.Dropout(
rate=self.dropout_rate, deterministic=not train)(
top_mlp_input)
top_mlp_input = Dropout(
dropout_rate, deterministic=not train)(
top_mlp_input, rate=dropout_rate)
logits = top_mlp_input
return logits
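The model changes above follow a single pattern: the constructor-level `dropout_rate` is only a default, and a value passed to `__call__` overrides it for that forward pass. A self-contained toy module illustrating the same pattern (a sketch under that assumption, not the DLRM code itself):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyMLP(nn.Module):
  """Toy model using the same call-time dropout override as DlrmSmall."""
  dropout_rate: float = 0.0  # module-level default

  @nn.compact
  def __call__(self, x, train, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.dropout_rate
    x = nn.relu(nn.Dense(16)(x))
    if dropout_rate:
      x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
    return nn.Dense(1)(x)


x = jnp.ones((4, 8))
model = TinyMLP()
params = model.init(jax.random.PRNGKey(0), x, train=False)
# Same parameters, a different dropout rate per call -- no re-initialization.
out = model.apply(params, x, train=True, dropout_rate=0.1,
                  rngs={"dropout": jax.random.PRNGKey(1)})
```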
35 changes: 23 additions & 12 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -73,24 +73,33 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None,
tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
if self.use_resnet:
model_class = models.DLRMResNet
else:
model_class = models.DlrmSmall
self._model = model_class(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)

if dropout_rate is None:
self._model = model_class(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)
else:
self._model = model_class(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)

params_rng, dropout_rng = jax.random.split(rng)
init_fake_batch_size = 2
@@ -117,14 +126,16 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
train = mode == spec.ForwardPassMode.TRAIN
apply_kwargs = {'train': train}
if train:
apply_kwargs['rngs'] = {'dropout': rng}
apply_kwargs['dropout_rate'] = dropout_rate
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
return logits_batch, None

27 changes: 15 additions & 12 deletions algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -19,6 +19,8 @@
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
@@ -56,15 +58,14 @@ class UNet(nn.Module):
num_channels: int = 32
num_pool_layers: int = 4
out_channels = 1
dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
dropout_rate: float = 0.0
use_tanh: bool = False
use_layer_norm: bool = False

@nn.compact
def __call__(self, x, train=True):
dropout_rate = self.dropout_rate
def __call__(self, x, train=True, dropout_rate=None):
if dropout_rate is None:
dropout_rate = 0.0
dropout_rate = self.dropout_rate

# pylint: disable=invalid-name
_ConvBlock = functools.partial(
@@ -138,12 +139,12 @@ class ConvBlock(nn.Module):
dropout_rate: Dropout probability.
"""
out_channels: int
dropout_rate: float
use_tanh: bool
use_layer_norm: bool
dropout_rate: float = 0.0

@nn.compact
def __call__(self, x, train=True):
def __call__(self, x, train=True, dropout_rate=None):
"""Forward function.
Note: Pytorch is NCHW and jax/flax is NHWC.
Args:
@@ -152,6 +153,8 @@ def __call__(self, x, train=True):
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
if dropout_rate is None:
dropout_rate = self.dropout_rate
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
Expand All @@ -172,9 +175,9 @@ def __call__(self, x, train=True):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)
x = Dropout(
dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x, rate=dropout_rate)
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
Expand All @@ -186,9 +189,9 @@ def __call__(self, x, train=True):
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)
x = Dropout(
dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x, rate=dropout_rate)
return x


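To illustrate the `broadcast_dims=(1, 2)` choice used above (a rough sketch assuming NHWC inputs, as in the comment about replicating PyTorch's `Dropout2d`): broadcasting the mask over the H and W axes keeps or drops each channel as a whole.

```python
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

x = jnp.ones((2, 8, 8, 4))  # NHWC
drop = Dropout(rate=0.5, broadcast_dims=(1, 2), deterministic=False)
y = drop.apply({}, x, rngs={"dropout": jax.random.PRNGKey(0)})

# The sampled mask has shape (2, 1, 1, 4): for every (sample, channel) pair
# the whole 8x8 spatial map is either kept or zeroed out together.
per_channel = y[:, 0, 0, :]
assert bool(jnp.all((y == 0) | (y == per_channel[:, None, None, :])))
```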
38 changes: 25 additions & 13 deletions algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -22,16 +22,24 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
del aux_dropout_rate
fake_batch = jnp.zeros((13, 320, 320))
self._model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
use_tanh=self.use_tanh,
use_layer_norm=self.use_layer_norm,
dropout_rate=dropout_rate)
if dropout_rate is None:
self._model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
use_tanh=self.use_tanh,
use_layer_norm=self.use_layer_norm,
)
else:
self._model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
use_tanh=self.use_tanh,
use_layer_norm=self.use_layer_norm,
dropout_rate=dropout_rate)

params_rng, dropout_rng = jax.random.split(rng)
variables = jax.jit(
self._model.init)({'params': params_rng, 'dropout': dropout_rng},
@@ -52,14 +60,18 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits = self._model.apply({'params': params},
augmented_and_preprocessed_input_batch['inputs'],
rngs={'dropout': rng},
train=train)

if train:
logits = self._model.apply({'params': params},
augmented_and_preprocessed_input_batch['inputs'],
rngs={'dropout': rng},
train=train,
dropout_rate=dropout_rate)
return logits, None

# Does NOT apply regularization, which is left to the submitter to do in