1818from __future__ import division
1919from __future__ import print_function
2020
21- import copy
2221import six
2322
2423import tensorflow .compat .v2 as tf
2524
2625from tensorflow_probability .python .distributions import distribution as tfd
26+ from tensorflow_probability .python .distributions import kullback_leibler
2727from tensorflow_probability .python .internal import nest_util
2828from tensorflow_probability .python .internal import parameter_properties
2929from tensorflow_probability .python .util .deferred_tensor import TensorMetaClass
3030from tensorflow .python .framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
31+ from tensorflow .python .training .tracking import data_structures # pylint: disable=g-direct-tensorflow-import
3132
3233
3334__all__ = [] # We intend nothing public.
3435
36+ _NOT_FOUND = object ()
37+
3538
3639# Define mixin type because Distribution already has its own metaclass.
3740class _DistributionAndTensorCoercibleMeta (type (tfd .Distribution ),
@@ -43,43 +46,123 @@ class _DistributionAndTensorCoercibleMeta(type(tfd.Distribution),
4346class _TensorCoercible (tfd .Distribution ):
4447 """Docstring."""
4548
46- registered_class_list = {}
47-
48- def __new__ (cls , distribution , convert_to_tensor_fn = tfd .Distribution .sample ):
49- if isinstance (distribution , cls ):
50- return distribution
51- if not isinstance (distribution , tfd .Distribution ):
52- raise TypeError ('`distribution` argument must be a '
53- '`tfd.Distribution` instance; '
54- 'saw "{}" of type "{}".' .format (
55- distribution , type (distribution )))
56- self = copy .copy (distribution )
57- distcls = distribution .__class__
58- self_class = _TensorCoercible .registered_class_list .get (distcls )
59- if not self_class :
60- self_class = type (distcls .__name__ , (cls , distcls ), {})
61- _TensorCoercible .registered_class_list [distcls ] = self_class
62- self .__class__ = self_class
63- return self
64-
6549 def __init__ (self ,
6650 distribution ,
6751 convert_to_tensor_fn = tfd .Distribution .sample ):
6852 self ._concrete_value = None # pylint: disable=protected-access
6953 self ._convert_to_tensor_fn = convert_to_tensor_fn # pylint: disable=protected-access
54+ self .tensor_distribution = distribution
55+ super (_TensorCoercible , self ).__init__ (
56+ dtype = distribution .dtype ,
57+ reparameterization_type = distribution .reparameterization_type ,
58+ validate_args = distribution .validate_args ,
59+ allow_nan_stats = distribution .allow_nan_stats ,
60+ parameters = distribution .parameters )
61+
62+ def __setattr__ (self , name , value ):
63+ """Support self.foo = trackable syntax.
64+
65+ Redefined from `tensorflow/python/training/tracking/tracking.py` to avoid
66+ calling `getattr`, which causes an infinite loop.
67+
68+ Args:
69+ name: str, name of the attribute to be set.
70+ value: value to be set.
71+ """
72+ if vars (self ).get (name , _NOT_FOUND ) is value :
73+ return
74+
75+ if vars (self ).get ('_self_setattr_tracking' , True ):
76+ value = data_structures .sticky_attribute_assignment (
77+ trackable = self , value = value , name = name )
78+ object .__setattr__ (self , name , value )
79+
80+ def __getattr__ (self , name ):
81+ # If the attribute is set in the _TensorCoercible object, return it. This
82+ # ensures that direct calls to `getattr` behave as expected.
83+ if name in vars (self ):
84+ return vars (self )[name ]
85+ # Look for the attribute in `tensor_distribution`, unless it's a `_tracking`
86+ # attribute accessed directly by `getattr` in the `Trackable` base class, in
87+ # which case the default passed to `getattr` should be returned.
88+ if 'tensor_distribution' in vars (self ) and '_tracking' not in name :
89+ return getattr (vars (self )['tensor_distribution' ], name )
90+ # Otherwise invoke `__getattribute__`, which will return the default passed
91+ # to `getattr` if the attribute was not found.
92+ return self .__getattribute__ (name )
7093
7194 @classmethod
7295 def _parameter_properties (cls , dtype , num_classes = None ):
7396 return dict (distribution = parameter_properties .BatchedComponentProperties ())
7497
98+ # pylint: disable=protected-access
7599 def _batch_shape_tensor (self , ** parameter_kwargs ):
76- # Any parameter kwargs are for the inner distribution, so pass them
77- # to its `_batch_shape_tensor` method instead of handling them directly.
78- return self .parameters ['distribution' ]._batch_shape_tensor ( # pylint: disable=protected-access
79- ** parameter_kwargs )
100+ return self .tensor_distribution ._batch_shape_tensor (** parameter_kwargs )
101+
102+ def _batch_shape (self ):
103+ return self .tensor_distribution ._batch_shape ()
104+
105+ def _event_shape_tensor (self ):
106+ return self .tensor_distribution ._event_shape_tensor ()
107+
108+ def _event_shape (self ):
109+ return self .tensor_distribution ._event_shape ()
110+
111+ def sample (self , sample_shape = (), seed = None , name = 'sample' , ** kwargs ):
112+ return self .tensor_distribution .sample (
113+ sample_shape = sample_shape , seed = seed , name = name , ** kwargs )
114+
115+ def _log_prob (self , value , ** kwargs ):
116+ return self .tensor_distribution ._log_prob (value , ** kwargs )
117+
118+ def _prob (self , value , ** kwargs ):
119+ return self .tensor_distribution ._prob (value , ** kwargs )
120+
121+ def _log_cdf (self , value , ** kwargs ):
122+ return self .tensor_distribution ._log_cdf (value , ** kwargs )
123+
124+ def _cdf (self , value , ** kwargs ):
125+ return self .tensor_distribution ._cdf (value , ** kwargs )
126+
127+ def _log_survival_function (self , value , ** kwargs ):
128+ return self .tensor_distribution ._log_survival_function (value , ** kwargs )
129+
130+ def _survival_function (self , value , ** kwargs ):
131+ return self .tensor_distribution ._survival_function (value , ** kwargs )
132+
133+ def _entropy (self , ** kwargs ):
134+ return self .tensor_distribution ._entropy (** kwargs )
135+
136+ def _mean (self , ** kwargs ):
137+ return self .tensor_distribution ._mean (** kwargs )
138+
139+ def _quantile (self , value , ** kwargs ):
140+ return self .tensor_distribution ._quantile (value , ** kwargs )
141+
142+ def _variance (self , ** kwargs ):
143+ return self .tensor_distribution ._variance (** kwargs )
144+
145+ def _stddev (self , ** kwargs ):
146+ return self .tensor_distribution ._stddev (** kwargs )
147+
148+ def _covariance (self , ** kwargs ):
149+ return self .tensor_distribution ._covariance (** kwargs )
150+
151+ def _mode (self , ** kwargs ):
152+ return self .tensor_distribution ._mode (** kwargs )
153+
154+ def _default_event_space_bijector (self , * args , ** kwargs ):
155+ return self .tensor_distribution ._default_event_space_bijector (
156+ * args , ** kwargs )
157+
158+ def _parameter_control_dependencies (self , is_init ):
159+ return self .tensor_distribution ._parameter_control_dependencies (is_init )
80160
81161 @property
82162 def shape (self ):
163+ return self ._shape
164+
165+ def _shape (self ):
83166 return (tf .TensorShape (None ) if self ._concrete_value is None
84167 else self ._concrete_value .shape )
85168
@@ -130,15 +213,26 @@ def _value(self, dtype=None, name=None, as_ref=False):
130213 ' results in `tf.convert_to_tensor(x)` being identical to '
131214 '`x.mean()`.' .format (type (self ), self ))
132215 with self ._name_and_control_scope ('value' ):
133- self ._concrete_value = (self ._convert_to_tensor_fn (self )
134- if callable (self ._convert_to_tensor_fn )
135- else self ._convert_to_tensor_fn )
216+ self ._concrete_value = (
217+ self ._convert_to_tensor_fn (self .tensor_distribution )
218+ if callable (self ._convert_to_tensor_fn )
219+ else self ._convert_to_tensor_fn )
136220 if (not tf .is_tensor (self ._concrete_value ) and
137221 not isinstance (self ._concrete_value ,
138222 composite_tensor .CompositeTensor )):
139223 self ._concrete_value = nest_util .convert_to_nested_tensor ( # pylint: disable=protected-access
140224 self ._concrete_value ,
141225 name = name or 'concrete_value' ,
142226 dtype = dtype ,
143- dtype_hint = self .dtype )
227+ dtype_hint = self .tensor_distribution . dtype )
144228 return self ._concrete_value
229+
230+
231+ @kullback_leibler .RegisterKL (_TensorCoercible , tfd .Distribution )
232+ def _kl_tensor_coercible_distribution (a , b , name = None ):
233+ return kullback_leibler .kl_divergence (a .tensor_distribution , b , name = name )
234+
235+
236+ @kullback_leibler .RegisterKL (tfd .Distribution , _TensorCoercible )
237+ def _kl_distribution_tensor_coercible (a , b , name = None ):
238+ return kullback_leibler .kl_divergence (a , b .tensor_distribution , name = name )
0 commit comments