File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed 
checkpoint/orbax/checkpoint/_src/path Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -528,7 +528,11 @@ def _find_all_with_single_host_load_and_broadcast(
528528
529529  def  find_all (self , base_path : epath .PathLike ) ->  Iterator [Metadata ]:
530530    """Returns metadata of all steps matching with name_format attributes.""" 
531-     if  multihost .process_count () >  1  and  self .single_host_load_and_broadcast :
531+     # Note: the order of conjuncts is important here; we should not call 
532+     # `multihost.process_count()` when `single_host_load_and_broadcast` is False 
533+     # as this has the possible side effect of initializing the jax backend. See 
534+     # b/454565916 for details. 
535+     if  self .single_host_load_and_broadcast  and  multihost .process_count () >  1 :
532536      return  self ._find_all_with_single_host_load_and_broadcast (base_path )
533537
534538    # <step_prefix>_?<0 padding>?* 
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments