Skip to content

Commit a7c5aef

Browse files
tomcobleyOrbax Authors
authored andcommitted
Add flag option to raise an error if Pathways client is initialized before _gemini_lost_callback is set.
PiperOrigin-RevId: 823495438
1 parent 3a5ef80 commit a7c5aef

File tree

1 file changed

+5
-1
lines changed
  • checkpoint/orbax/checkpoint/_src/path

1 file changed

+5
-1
lines changed

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,11 @@ def _find_all_with_single_host_load_and_broadcast(
528528

529529
def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
530530
"""Returns metadata of all steps matching with name_format attributes."""
531-
if multihost.process_count() > 1 and self.single_host_load_and_broadcast:
531+
# Note: the order of conjuncts is important here; we should not call
532+
# `multihost.process_count()` when `single_host_load_and_broadcast` is False
533+
# as this has the possible side effect of initializing the jax backend. See
534+
# b/454565916 for details.
535+
if self.single_host_load_and_broadcast and multihost.process_count() > 1:
532536
return self._find_all_with_single_host_load_and_broadcast(base_path)
533537

534538
# <step_prefix>_?<0 padding>?*

0 commit comments

Comments
 (0)