@@ -714,54 +714,24 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
714
714
raise ValueError (f"Can only compute the gradient of continuous types: { var } " )
715
715
716
716
if tempered :
717
- with self :
718
- # Convert random variables into their log-likelihood inputs and
719
- # apply their transforms, if any
720
- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
721
-
722
- free_RVs_logp = at .sum (
723
- [at .sum (logpt (var , self .rvs_to_values .get (var , None ))) for var in self .free_RVs ]
724
- + list (potentials )
725
- )
726
- observed_RVs_logp = at .sum (
727
- [at .sum (logpt (obs , obs .tag .observations )) for obs in self .observed_RVs ]
728
- )
729
-
730
- costs = [free_RVs_logp , observed_RVs_logp ]
717
+ # TODO: Should this differ from self.datalogpt,
718
+ # where the potential terms are added to the observations?
719
+ costs = [self .varlogpt + self .potentiallogpt , self .observedlogpt ]
731
720
else :
732
721
costs = [self .logpt ]
733
722
734
723
input_vars = {i for i in graph_inputs (costs ) if not isinstance (i , Constant )}
735
724
extra_vars = [self .rvs_to_values .get (var , var ) for var in self .free_RVs ]
725
+ ip = self .recompute_initial_point (0 )
736
726
extra_vars_and_values = {
737
- var : self .initial_point [var .name ]
738
- for var in extra_vars
739
- if var in input_vars and var not in grad_vars
727
+ var : ip [var .name ] for var in extra_vars if var in input_vars and var not in grad_vars
740
728
}
741
729
return ValueGradFunction (costs , grad_vars , extra_vars_and_values , ** kwargs )
742
730
743
731
@property
744
732
def logpt (self ):
745
733
"""Aesara scalar of log-probability of the model"""
746
-
747
- rv_values = {}
748
- for var in self .free_RVs :
749
- rv_values [var ] = self .rvs_to_values .get (var , None )
750
- rv_factors = logpt (self .free_RVs , rv_values )
751
-
752
- obs_values = {}
753
- for obs in self .observed_RVs :
754
- obs_values [obs ] = obs .tag .observations
755
- obs_factors = logpt (self .observed_RVs , obs_values )
756
-
757
- # Convert random variables into their log-likelihood inputs and
758
- # apply their transforms, if any
759
- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
760
- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
761
- if rv_factors is not None :
762
- logp_var += rv_factors
763
- if obs_factors is not None :
764
- logp_var += obs_factors
734
+ logp_var = self .varlogpt + self .datalogpt
765
735
766
736
if self .name :
767
737
logp_var .name = f"__logp_{ self .name } "
@@ -777,60 +747,65 @@ def logp_nojact(self):
777
747
Note that if there is no transformed variable in the model, logp_nojact
778
748
will be the same as logpt as there is no need for Jacobian correction.
779
749
"""
780
- with self :
781
- rv_values = {}
782
- for var in self .free_RVs :
783
- rv_values [var ] = getattr (var .tag , "value_var" , None )
784
- rv_factors = logpt (self .free_RVs , rv_values , jacobian = False )
785
-
786
- obs_values = {}
787
- for obs in self .observed_RVs :
788
- obs_values [obs ] = obs .tag .observations
789
- obs_factors = logpt (self .observed_RVs , obs_values , jacobian = False )
790
-
791
- # Convert random variables into their log-likelihood inputs and
792
- # apply their transforms, if any
793
- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
794
- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
795
-
796
- if rv_factors is not None :
797
- logp_var += rv_factors
798
- if obs_factors is not None :
799
- logp_var += obs_factors
800
-
801
- if self .name :
802
- logp_var .name = f"__logp_nojac_{ self .name } "
803
- else :
804
- logp_var .name = "__logp_nojac"
805
- return logp_var
750
+ logp_var = self .varlogp_nojact + self .datalogpt
751
+
752
+ if self .name :
753
+ logp_var .name = f"__logp_nojac_{ self .name } "
754
+ else :
755
+ logp_var .name = "__logp_nojac"
756
+ return logp_var
757
+
758
+ @property
759
+ def datalogpt (self ):
760
+ """Aesara scalar of log-probability of the observed variables and
761
+ potential terms"""
762
+ return self .observedlogpt + self .potentiallogpt
806
763
807
764
@property
808
765
def varlogpt (self ):
809
766
"""Aesara scalar of log-probability of the unobserved random variables
810
767
(excluding deterministic)."""
811
- with self :
812
- rv_values = {}
813
- for var in self .free_RVs :
814
- rv_values [ var ] = getattr ( var . tag , "value_var" , None )
768
+ rv_values = {}
769
+ for var in self . free_RVs :
770
+ rv_values [ var ] = self .rvs_to_values [ var ]
771
+ if rv_values :
815
772
return logpt (self .free_RVs , rv_values )
773
+ else :
774
+ return 0
816
775
817
776
@property
818
- def datalogpt (self ):
819
- with self :
820
- obs_values = {}
821
- for obs in self .observed_RVs :
822
- obs_values [obs ] = obs .tag .observations
823
- obs_factors = logpt (self .observed_RVs , obs_values )
824
-
825
- # Convert random variables into their log-likelihood inputs and
826
- # apply their transforms, if any
827
- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
828
- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
777
+ def varlogp_nojact (self ):
778
+ """Aesara scalar of log-probability of the unobserved random variables
779
+ (excluding deterministic) without jacobian term."""
780
+ rv_values = {}
781
+ for var in self .free_RVs :
782
+ rv_values [var ] = self .rvs_to_values [var ]
783
+ if rv_values :
784
+ return logpt (self .free_RVs , rv_values , jacobian = False )
785
+ else :
786
+ return 0
829
787
830
- if obs_factors is not None :
831
- logp_var += obs_factors
788
+ @property
789
+ def observedlogpt (self ):
790
+ """Aesara scalar of log-probability of the observed variables"""
791
+ obs_values = {}
792
+ for obs in self .observed_RVs :
793
+ obs_values [obs ] = obs .tag .observations
794
+ if obs_values :
795
+ return logpt (self .observed_RVs , obs_values )
796
+ else :
797
+ return 0
832
798
833
- return logp_var
799
+ @property
800
+ def potentiallogpt (self ):
801
+ """Aesara scalar of log-probability of the Potential terms"""
802
+ # Convert random variables in Potential expression into their log-likelihood
803
+ # inputs and apply their transforms, if any
804
+ potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
805
+ if potentials :
806
+ return at .sum ([at .sum (factor ) for factor in potentials ])
807
+ else :
808
+ return 0
834
809
835
810
@property
836
811
def vars (self ):
0 commit comments