Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,14 +1211,14 @@ def _calc_extra_inps(num_consts, params):
def _reap_pjit_rule(trace, *tracers, **params):
"""Reap pjit rule."""
if params['in_shardings'] and not any(
sharding_impls.is_unspecified(i) for i in params['in_shardings']
isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']
):
raise ValueError(
'oryx only supports pjit which has no in_axis_resources '
f'specified. Got {params["in_shardings"]}'
)
if params['out_shardings'] and not any(
sharding_impls.is_unspecified(o) for o in params['out_shardings']
isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']
):
raise ValueError(
'oryx only supports pjit which has no out_axis_resources '
Expand Down Expand Up @@ -1648,14 +1648,14 @@ def _plant_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse,
def _plant_pjit_rule(trace, *tracers, **params):
"""Plant pjit rule."""
if params['in_shardings'] and not any(
sharding_impls.is_unspecified(i) for i in params['in_shardings']
isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']
):
raise ValueError(
'oryx only supports pjit which has no in_axis_resources '
f'specified. Got {params["in_shardings"]}'
)
if params['out_shardings'] and not any(
sharding_impls.is_unspecified(o) for o in params['out_shardings']
isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']
):
raise ValueError(
'oryx only supports pjit which has no out_axis_resources '
Expand Down
5 changes: 2 additions & 3 deletions oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from jax._src import sharding_impls
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla

from oryx.core import pytree
from oryx.core import trace_util
Expand Down Expand Up @@ -367,10 +366,10 @@ def _pjit_propagate_rule(incells, outcells, **params):
"""Propagate rule for pjit primitive."""
# TODO(https://github.com/jax-ml/oryx/issues/29): Fix this rule so that it # pylint: disable=g-bad-todo
# works correct for in_sharding, out_shardings and donated_invars.
if not any(pxla._is_unspecified(i) for i in params['in_shardings']): # pylint: disable=protected-access
if not any(isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']): # pylint: disable=protected-access
raise ValueError('oryx only supports pjit which has no in_axis_resources '
'specified.')
if not any(pxla._is_unspecified(o) for o in params['out_shardings']): # pylint: disable=protected-access
if not any(isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']): # pylint: disable=protected-access
raise ValueError('oryx only supports pjit which has no out_axis_resources '
'specified.')

Expand Down