1515"""Gradient-based trajectory length adaptation kernel."""
1616
1717import collections
18+ import functools
1819
1920import tensorflow .compat .v2 as tf
2021
@@ -89,6 +90,16 @@ def _map_structure_up_to_with_axes(structure, fn, *args,
8990 experimental_shard_axis_names )
9091
9192
93+ def _reduce_with_axes (index_op , name_op , x , axis_idx = None , axis_names = None ):
94+ return name_op (index_op (x , axis_idx ), axis_names )
95+
96+
97+ _reduce_sum_with_axes = functools .partial (_reduce_with_axes , tf .reduce_sum ,
98+ distribute_lib .psum )
99+ _reduce_mean_with_axes = functools .partial (_reduce_with_axes , tf .reduce_mean ,
100+ distribute_lib .pmean )
101+
102+
92103def hmc_like_num_leapfrog_steps_getter_fn (kernel_results ):
93104 """Getter for `num_leapfrog_steps` so it can be inspected."""
94105 return unnest .get_innermost (kernel_results , 'num_leapfrog_steps' )
@@ -132,7 +143,8 @@ def chees_criterion(previous_state,
132143 proposed_state ,
133144 accept_prob ,
134145 validate_args = False ,
135- experimental_shard_axis_names = None ):
146+ experimental_shard_axis_names = None ,
147+ experimental_chain_axis_names = None ):
136148 """The ChEES criterion from [1].
137149
138150 ChEES stands for Change in the Estimator of the Expected Square.
@@ -166,6 +178,8 @@ def chees_criterion(previous_state,
166178 validate_args: Whether to perform non-static argument validation.
167179 experimental_shard_axis_names: A structure of string names indicating how
168180 members of the state are sharded.
181+ experimental_chain_axis_names: A string or list of string names indicating
182+ how batches of chains are sharded.
169183
170184 Returns:
171185 chees: The value of the ChEES criterion.
@@ -182,7 +196,13 @@ def chees_criterion(previous_state,
182196 """
183197 batch_ndims = ps .rank (accept_prob )
184198 batch_axes = ps .range (batch_ndims , dtype = tf .int32 )
185- num_chains = ps .size (accept_prob )
199+ experimental_chain_axis_names = distribute_lib .canonicalize_axis_name (
200+ experimental_chain_axis_names )
201+ # Number of total chains is local batch size * distributed axis size
202+ local_axis_size = ps .maximum (ps .size (accept_prob ), 1 )
203+ distributed_axis_size = int (ps .reduce_prod ([
204+ distribute_lib .get_axis_size (a ) for a in experimental_chain_axis_names ]))
205+ num_chains = local_axis_size * distributed_axis_size
186206 num_chains_ = tf .get_static_value (num_chains )
187207 if num_chains_ is not None :
188208 if num_chains_ < 2 :
@@ -199,7 +219,9 @@ def chees_criterion(previous_state,
199219 def _center_previous_state (x ):
200220 # The empirical mean here is a stand-in for the true mean, so we drop the
201221 # gradient that flows through this term.
202- return x - tf .stop_gradient (tf .reduce_mean (x , axis = batch_axes ))
222+ x_mean = _reduce_mean_with_axes (
223+ x , batch_axes , experimental_chain_axis_names )
224+ return x - tf .stop_gradient (x_mean )
203225
204226 def _center_proposed_state (x ):
205227 # The empirical mean here is a stand-in for the true mean, so we drop the
@@ -216,8 +238,10 @@ def _center_proposed_state(x):
216238 # If all accept_prob's are zero, the x_center will have a nonsense value,
217239 # but we'll discard the resultant gradients later on, so it's fine.
218240 x_center = (
219- tf .reduce_sum (expanded_accept_prob * x_safe , axis = batch_axes ) /
220- (tf .reduce_sum (expanded_accept_prob , axis = batch_axes ) + 1e-20 ))
241+ _reduce_sum_with_axes (expanded_accept_prob * x_safe , batch_axes ,
242+ experimental_chain_axis_names ) /
243+ (_reduce_sum_with_axes (expanded_accept_prob , batch_axes ,
244+ experimental_chain_axis_names ) + 1e-20 ))
221245
222246 return x - tf .stop_gradient (x_center )
223247
@@ -358,6 +382,7 @@ def __init__(
358382 proposed_state_getter_fn = hmc_like_proposed_state_getter_fn ,
359383 validate_args = False ,
360384 experimental_shard_axis_names = None ,
385+ experimental_chain_axis_names = None ,
361386 name = None ):
362387 """Creates the trajectory length adaptation kernel.
363388
@@ -414,6 +439,8 @@ def __init__(
414439 outputs.
415440 experimental_shard_axis_names: A structure of string names indicating how
416441 members of the state are sharded.
442+ experimental_chain_axis_names: A string or list of string names indicating
443+ how batches of chains are sharded.
417444 name: Python `str` name prefixed to Ops created by this class. Default:
418445 'simple_step_size_adaptation'.
419446
@@ -452,6 +479,7 @@ class docstring).
452479 proposed_state_getter_fn = hmc_like_proposed_state_getter_fn ,
453480 validate_args = validate_args ,
454481 experimental_shard_axis_names = experimental_shard_axis_names ,
482+ experimental_chain_axis_names = experimental_chain_axis_names ,
455483 name = name ,
456484 )
457485
@@ -468,12 +496,15 @@ def num_adaptation_steps(self):
468496 return self ._parameters ['num_adaptation_steps' ]
469497
470498 def criterion_fn (self , previous_state , proposed_state , accept_prob ):
471- if self .experimental_shard_axis_names is None :
472- return self ._parameters ['criterion_fn' ](previous_state , proposed_state ,
473- accept_prob )
474- return self ._parameters ['criterion_fn' ](
475- previous_state , proposed_state , accept_prob ,
476- experimental_shard_axis_names = self .experimental_shard_axis_names )
499+ kwargs = {}
500+ if self .experimental_chain_axis_names is not None :
501+ kwargs ['experimental_chain_axis_names' ] = (
502+ self .experimental_chain_axis_names )
503+ if self .experimental_shard_axis_names is not None :
504+ kwargs ['experimental_shard_axis_names' ] = (
505+ self .experimental_shard_axis_names )
506+ return self ._parameters ['criterion_fn' ](previous_state , proposed_state ,
507+ accept_prob , ** kwargs )
477508
478509 @property
479510 def max_leapfrog_steps (self ):
@@ -567,7 +598,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
567598 step_size = step_size ,
568599 criterion_fn = self .criterion_fn ,
569600 max_leapfrog_steps = self .max_leapfrog_steps ,
570- experimental_shard_axis_names = self .experimental_shard_axis_names )
601+ experimental_shard_axis_names = self .experimental_shard_axis_names ,
602+ experimental_chain_axis_names = self .experimental_chain_axis_names )
571603
572604 # Undo the effect of adaptation if we're not in the burnin phase. We keep
573605 # the criterion, however, as that's a diagnostic. We also keep the
@@ -623,9 +655,16 @@ def is_calibrated(self):
623655 def experimental_shard_axis_names (self ):
624656 return self ._parameters ['experimental_shard_axis_names' ]
625657
658+ @property
659+ def experimental_chain_axis_names (self ):
660+ return self ._parameters ['experimental_chain_axis_names' ]
661+
626662 def experimental_with_shard_axes (self , shard_axis_names ):
627663 return self .copy (experimental_shard_axis_names = shard_axis_names )
628664
665+ def experimental_with_chain_axes (self , chain_axis_names ):
666+ return self .copy (experimental_chain_axis_names = chain_axis_names )
667+
629668
630669def _forbid_inner_transformed_kernel (inner_kernel ):
631670 """Forbids inner kernel from containing `TransformedTransitionKernel`."""
@@ -669,7 +708,8 @@ def _update_trajectory_grad(previous_kernel_results, previous_state,
669708 proposed_state , proposed_velocity ,
670709 trajectory_jitter , accept_prob , step_size ,
671710 criterion_fn , max_leapfrog_steps ,
672- experimental_shard_axis_names = None ):
711+ experimental_shard_axis_names = None ,
712+ experimental_chain_axis_names = None ):
673713 """Updates the trajectory length."""
674714 # Compute criterion grads.
675715 def leapfrog_action (dt ):
@@ -693,12 +733,16 @@ def adjust_state(x, v, shard_axes=None):
693733 trajectory_grad *= trajectory_jitter
694734
695735 # Weight by acceptance probability.
736+ experimental_chain_axis_names = distribute_lib .canonicalize_axis_name (
737+ experimental_chain_axis_names )
696738 trajectory_grad = tf .where (accept_prob > 1e-4 , trajectory_grad , 0. )
697739 trajectory_grad = tf .where (
698740 tf .math .is_finite (trajectory_grad ), trajectory_grad , 0. )
699741 trajectory_grad = (
700- tf .reduce_sum (trajectory_grad * accept_prob ) /
701- tf .reduce_sum (accept_prob + 1e-20 ))
742+ _reduce_sum_with_axes (trajectory_grad * accept_prob ,
743+ None , experimental_chain_axis_names ) /
744+ _reduce_sum_with_axes (accept_prob + 1e-20 , None ,
745+ experimental_chain_axis_names ))
702746
703747 # Compute Adam/RMSProp step size.
704748 dtype = previous_kernel_results .adaptation_rate .dtype
0 commit comments