from tensorflow_probability.python.internal import cache_util
from tensorflow_probability.python.internal import samplers
def build_highway_flow_layer(width,
                             residual_fraction_initial_value=0.5,
                             activation_fn=False,
                             seed=None):
  # TODO: add control that residual_fraction_initial_value is between 0 and 1.
  residual_fraction_initial_value = tf.convert_to_tensor(
      residual_fraction_initial_value,
      dtype_hint=tf.float32,
      name='residual_fraction_initial_value')
  dtype = residual_fraction_initial_value.dtype

  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
      seed, n=4)
  return HighwayFlow(
      residual_fraction=util.TransformedVariable(
          initial_value=residual_fraction_initial_value,
          bijector=tfb.Sigmoid(),
          dtype=dtype),
      activation_fn=activation_fn,
      bias=tf.Variable(
          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
          dtype=dtype),
      upper_diagonal_weights_matrix=util.TransformedVariable(
          initial_value=tf.experimental.numpy.tril(
              samplers.normal((width, width), mean=0., stddev=1.,
                              seed=upper_seed),
              k=-1) + tf.linalg.diag(
                  samplers.uniform((width,), minval=0., maxval=1.,
                                   seed=diagonal_seed)),
          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
                                     diag_shift=None),
          dtype=dtype),
      lower_diagonal_weights_matrix=util.TransformedVariable(
          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
                                        seed=lower_seed),
          bijector=tfb.Chain(
              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
               tfb.Pad(paddings=[(1, 0), (0, 1)]),
               tfb.FillTriangular()]),
          dtype=dtype))
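# Descriptive note and minimal usage sketch (illustrative values, not part of
# the module API): the `TransformedVariable`s above keep the trainable
# parameters in their valid sets during optimization. `residual_fraction` is
# stored unconstrained and squashed through a Sigmoid, so it always lies in
# (0, 1); `FillScaleTriL` keeps the upper weights triangular with a
# softplus-positive diagonal; and the `Chain` bijector constrains the lower
# weights to a lower triangular matrix with ones on the diagonal.
#
#   layer = build_highway_flow_layer(width=4, activation_fn=True, seed=42)
#   y = layer.forward(tf.ones([2, 4]))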


class HighwayFlow(tfb.Bijector):
  """Implements a Highway Flow bijector [1], which interpolates the input
  `X` with the transformations at each step of the bijector.

  The Highway Flow can be used as a building block for a Cascading Flow [1]
  or as a generic normalizing flow.

  The transformation consists of a convex update between the input `X` and a
  linear transformation of `X` followed by an activation, with the form
  `g(A @ X + b)`, where `g(.)` is a differentiable non-decreasing activation
  function and `A` and `b` are trainable weights.

  The convex update is regulated by a trainable residual fraction `l`
  constrained between 0 and 1, and can be formalized as:
  `Y = l * X + (1 - l) * g(A @ X + b)`.

  To make this transformation invertible, the bijector is split into three
  convex updates:
  - `Y1 = l * X + (1 - l) * L @ X`, with `L` a lower triangular matrix with
    ones on the diagonal;
  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` an upper triangular
    matrix with a positive diagonal;
  - `Y = l * Y2 + (1 - l) * g(Y2)`.

  The function `build_highway_flow_layer` helps initialize the bijector with
  variables respecting the various constraints.

  For more details on Highway Flow and Cascading Flows see [1].

  #### Usage example:

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  dim = 4  # Last input dimension.

  bijector = build_highway_flow_layer(dim, activation_fn=True)
  y = bijector.forward(x)  # Forward mapping.
  x = bijector.inverse(y)  # Inverse mapping.
  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution.
  transformed_distribution = tfd.TransformedDistribution(base, bijector)
  ```

  #### References

  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
       "Automatic variational inference with cascading flows." arXiv preprint
       arXiv:2102.04801 (2021).
  """
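  # How the docstring's three convex updates map onto the code below
  # (descriptive note): `_convex_update` forms `l * I + (1 - l) * W` for the
  # two triangular weight matrices, while the activation step and its
  # Newton-iteration inverse are handled in `_augmented_forward` and
  # `_inverse_of_sigmoid`.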

  # Highway Flow simultaneously computes `forward` and `fldj`
  # (and `inverse`/`ildj`), so we override the bijector cache to update the
  # LDJ entries of attrs on forward/inverse calls (instead of updating them
  # only when the LDJ methods themselves are called).
  _cache = cache_util.BijectorCacheWithGreedyAttrs(
      forward_name='_augmented_forward',
      inverse_name='_augmented_inverse')

  def __init__(self, residual_fraction, activation_fn, bias,
               upper_diagonal_weights_matrix,
               lower_diagonal_weights_matrix, validate_args=False,
               name='highway_flow'):
    """Initializes the HighwayFlow bijector.

    Args:
      residual_fraction: Scalar `Tensor` used for the convex update; must be
        between 0 and 1.
      activation_fn: Bool deciding whether to use a softplus activation
        (True) or no activation (False).
      bias: Bias vector.
      upper_diagonal_weights_matrix: Lower triangular matrix of size
        (width, width) with positive diagonal (it is transposed to an upper
        triangular matrix within the bijector).
      lower_diagonal_weights_matrix: Lower triangular matrix with ones on the
        main diagonal.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name for ops created by this object.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._width = tf.shape(bias)[-1]
      self._bias = bias
      self._residual_fraction = residual_fraction
      # The upper weights matrix is stored as lower triangular; it is
      # transposed inside `_forward` and `_inverse` (within `matvec` /
      # `triangular_solve`).
      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
      self._activation_fn = activation_fn

      super(HighwayFlow, self).__init__(
          validate_args=validate_args,
          forward_min_event_ndims=1,
          parameters=parameters,
          name=name)

  @property
  def bias(self):
    return self._bias

  @property
  def width(self):
    return self._width

  @property
  def residual_fraction(self):
    return self._residual_fraction

  @property
  def upper_diagonal_weights_matrix(self):
    return self._upper_diagonal_weights_matrix

  @property
  def lower_diagonal_weights_matrix(self):
    return self._lower_diagonal_weights_matrix

  @property
  def activation_fn(self):
    return self._activation_fn

  def _derivative_of_sigmoid(self, x):
    # Derivative of the activation step `l * x + (1 - l) * softplus(x)`
    # (the derivative of softplus is the sigmoid, hence the name).
    return self.residual_fraction + (
        1. - self.residual_fraction) * tf.math.sigmoid(x)

  def _convex_update(self, weights_matrix):
    # Convex combination `l * I + (1 - l) * W`, which is also the Jacobian of
    # the linear step `y = l * x + (1 - l) * W @ x`.
    return self.residual_fraction * tf.eye(self.width) + (
        1. - self.residual_fraction) * weights_matrix

  def _inverse_of_sigmoid(self, y, n=20):
    # Inverse of the activation step with softplus: solves, elementwise,
    # `l * x + (1 - l) * softplus(x) = y` for `x` using Newton iteration.
    x = tf.ones(y.shape)
    for _ in range(n):
      x = x - (self.residual_fraction * x + (
          1. - self.residual_fraction) * tf.math.softplus(x) - y) / (
              self._derivative_of_sigmoid(x))
    return x

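  # Note on the log-determinant terms accumulated below: the Jacobians of the
  # two linear steps are `_convex_update(L)` and `_convex_update(U)^T`. Both
  # are triangular, so their log-determinants reduce to the sum of the logs of
  # their diagonals; `L` has a unit diagonal, so only the `U` step and the
  # elementwise activation step contribute.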
  def _augmented_forward(self, x):
    # Log-determinant term from the upper matrix. Note that the log
    # determinant of the lower matrix is zero.
    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
        tf.math.log(self.residual_fraction + (
            1. - self.residual_fraction) * tf.linalg.diag_part(
                self.upper_diagonal_weights_matrix)))
    x = tf.linalg.matvec(
        self._convex_update(self.lower_diagonal_weights_matrix), x)
    x = tf.linalg.matvec(
        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
        x) + (1. - self.residual_fraction) * self.bias
    if self.activation_fn:
      fldj += tf.reduce_sum(
          tf.math.log(self._derivative_of_sigmoid(x)), -1)
      # `activation_fn` is a bool flag, so apply the softplus activation
      # directly rather than calling the flag.
      x = self.residual_fraction * x + (
          1. - self.residual_fraction) * tf.math.softplus(x)
    return x, {'ildj': -fldj, 'fldj': fldj}

  def _augmented_inverse(self, y):
    # Log-determinant term from the upper matrix (negated for the inverse).
    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
        tf.math.log(self.residual_fraction + (
            1. - self.residual_fraction) * tf.linalg.diag_part(
                self.upper_diagonal_weights_matrix)))
    if self.activation_fn:
      y = self._inverse_of_sigmoid(y)
      ildj -= tf.reduce_sum(
          tf.math.log(self._derivative_of_sigmoid(y)), -1)

    y = tf.linalg.triangular_solve(
        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
        tf.linalg.matrix_transpose(
            y - (1. - self.residual_fraction) * self.bias),
        lower=False)
    y = tf.linalg.triangular_solve(
        self._convex_update(self.lower_diagonal_weights_matrix), y)
    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}

  def _forward(self, x):
    y, _ = self._augmented_forward(x)
    return y

  def _inverse(self, y):
    x, _ = self._augmented_inverse(y)
    return x

  def _forward_log_det_jacobian(self, x):
    cached = self._cache.forward_attributes(x)
    # If LDJ isn't in the cache, call forward once.
    if 'fldj' not in cached:
      _, attrs = self._augmented_forward(x)
      cached.update(attrs)
    return cached['fldj']

  def _inverse_log_det_jacobian(self, y):
    cached = self._cache.inverse_attributes(y)
    # If LDJ isn't in the cache, call inverse once.
    if 'ildj' not in cached:
      _, attrs = self._augmented_inverse(y)
      cached.update(attrs)
    return cached['ildj']
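

# A quick round-trip self-check sketch (assumes eager execution; the names
# below are illustrative and not part of the module):
#
#   layer = build_highway_flow_layer(width=4, activation_fn=True, seed=42)
#   x = tf.random.normal([2, 4])
#   y = layer.forward(x)
#   tf.debugging.assert_near(layer.inverse(y), x, atol=1e-4)
#   fldj = layer.forward_log_det_jacobian(x, event_ndims=1)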