@@ -23,16 +23,15 @@
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
-from pymc.aesaraf import change_rv_size, take_along_axis
+from pymc.aesaraf import change_rv_size
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Discrete, Distribution, SymbolicDistribution
 from pymc.distributions.logprob import logp
 from pymc.distributions.shape_utils import to_tuple
-from pymc.math import logsumexp
 from pymc.util import check_dist_not_registered
 
-__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]
+__all__ = ["Mixture", "NormalMixture"]
 
 
 def all_discrete(comp_dists):
@@ -468,235 +467,3 @@ def dist(cls, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
         return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
-
-
-class MixtureSameFamily(Distribution):
-    R"""
-    Mixture Same Family log-likelihood
-    This distribution handles mixtures of multivariate distributions in a vectorized
-    manner. It is used over Mixture distribution when the mixture components are not
-    present on the last axis of components' distribution.
-
-    .. math::f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)\textrm{ Along mixture\_axis}
-
-    ========  ============================================
-    Support   :math:`\textrm{support}(f)`
-    Mean      :math:`w\mu`
-    ========  ============================================
-
-    Parameters
-    ----------
-    w: array of floats
-        w >= 0 and w <= 1
-        the mixture weights
-    comp_dists: PyMC distribution (e.g. `pm.Multinomial.dist(...)`)
-        The `comp_dists` can be scalar or multidimensional distribution.
-        Assuming its shape to be - (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N),
-        the `mixture_axis` is consumed resulting in the shape of mixture as -
-        (i_0, ..., i_n, i_n+1, ..., i_N).
-    mixture_axis: int, default = -1
-        Axis representing the mixture components to be reduced in the mixture.
-
-    Notes
-    -----
-    The default behaviour resembles Mixture distribution wherein the last axis of component
-    distribution is reduced.
-    """
-
-    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
-        self.w = at.as_tensor_variable(w)
-        if not isinstance(comp_dists, Distribution):
-            raise TypeError(
-                "The MixtureSameFamily distribution only accepts Distribution "
-                f"instances as its components. Got {type(comp_dists)} instead."
-            )
-        self.comp_dists = comp_dists
-        if mixture_axis < 0:
-            mixture_axis = len(comp_dists.shape) + mixture_axis
-            if mixture_axis < 0:
-                raise ValueError(
-                    "`mixture_axis` is supposed to be in shape of components' distribution. "
-                    f"Got {mixture_axis + len(comp_dists.shape)} axis instead out of the bounds."
-                )
-        comp_shape = to_tuple(comp_dists.shape)
-        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
-        self.mixture_axis = mixture_axis
-        kwargs.setdefault("dtype", self.comp_dists.dtype)
-
-        # Compute the mode so we don't always have to pass a initval
-        defaults = kwargs.pop("defaults", [])
-        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
-        _w = at.shape_padleft(
-            at.shape_padright(w, len(event_shape)),
-            len(self.comp_dists.shape) - w.ndim - len(event_shape),
-        )
-        mode = take_along_axis(
-            self.comp_dists.mode,
-            at.argmax(_w, keepdims=True),
-            axis=mixture_axis,
-        )
-        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
-
-        if not all_discrete(comp_dists):
-            mean = at.as_tensor_variable(self.comp_dists.mean)
-            self.mean = (_w * mean).sum(axis=mixture_axis)
-            if "mean" not in defaults:
-                defaults.append("mean")
-        defaults.append("mode")
-
-        super().__init__(defaults=defaults, *args, **kwargs)
-
-    def logp(self, value):
-        """
-        Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value.
-
-        Parameters
-        ----------
-        value : numeric
-            Value(s) for which log-probability is calculated. If the log probabilities for multiple
-            values are desired the values must be provided in a numpy array or Aesara tensor
-
-        Returns
-        -------
-        TensorVariable
-        """
-
-        comp_dists = self.comp_dists
-        w = self.w
-        mixture_axis = self.mixture_axis
-
-        event_shape = comp_dists.shape[mixture_axis + 1 :]
-
-        # To be able to broadcast the comp_dists.logp with w and value
-        # We first have to pad the shape of w to the right with ones
-        # so that it can broadcast with the event_shape.
-
-        w = at.shape_padright(w, len(event_shape))
-
-        # Second, we have to add the mixture_axis to the value tensor
-        # To insert the mixture axis at the correct location, we use the
-        # negative number index. This way, we can also handle situations
-        # in which, value is an observed value with more batch dimensions
-        # than the ones present in the comp_dists.
-        comp_dists_ndim = len(comp_dists.shape)
-
-        value = at.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)
-
-        comp_logp = comp_dists.logp(value)
-        return check_parameters(
-            logsumexp(at.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
-            w >= 0,
-            w <= 1,
-            at.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
-            broadcast_conditions=False,
-        )
-
-    def random(self, point=None, size=None):
-        """
-        Draw random values from defined ``MixtureSameFamily`` distribution.
-
-        Parameters
-        ----------
-        point : dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size : int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # sample_shape = to_tuple(size)
-        # mixture_axis = self.mixture_axis
-        #
-        # # First we draw values for the mixture component weights
-        # (w,) = draw_values([self.w], point=point, size=size)
-        #
-        # # We now draw random choices from those weights.
-        # # However, we have to ensure that the number of choices has the
-        # # sample_shape present.
-        # w_shape = w.shape
-        # batch_shape = self.comp_dists.shape[: mixture_axis + 1]
-        # param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
-        # event_shape = self.comp_dists.shape[mixture_axis + 1 :]
-        #
-        # if np.asarray(self.shape).size != 0:
-        #     comp_dists_ndim = len(self.comp_dists.shape)
-        #
-        #     # If event_shape of both comp_dists and supplied shape matches,
-        #     # broadcast only batch_shape
-        #     # else broadcast the entire given shape with batch_shape.
-        #     if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
-        #         dist_shape = np.broadcast(
-        #             np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
-        #         ).shape
-        #     else:
-        #         dist_shape = np.broadcast(
-        #             np.empty(self.shape), np.empty(param_shape[:mixture_axis])
-        #         ).shape
-        # else:
-        #     dist_shape = param_shape[:mixture_axis]
-        #
-        # # Try to determine the size that must be used to get the mixture
-        # # components (i.e. get random choices using w).
-        # # 1. There must be size independent choices based on w.
-        # # 2. There must also be independent draws for each non singleton axis
-        # #    of w.
-        # # 3. There must also be independent draws for each dimension added by
-        # #    self.shape with respect to the w.ndim. These usually correspond to
-        # #    observed variables with batch shapes
-        # wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
-        # psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
-        # w_sample_size = []
-        # # Loop through the dist_shape to get the conditions 2 and 3 first
-        # for i in range(len(dist_shape)):
-        #     if dist_shape[i] != psh[i] and wsh[i] == 1:
-        #         # self.shape[i] is a non singleton dimension (usually caused by
-        #         # observed data)
-        #         sh = dist_shape[i]
-        #     else:
-        #         sh = wsh[i]
-        #     w_sample_size.append(sh)
-        #
-        # if sample_shape is not None and w_sample_size[: len(sample_shape)] != sample_shape:
-        #     w_sample_size = sample_shape + tuple(w_sample_size)
-        #
-        # choices = random_choice(p=w, size=w_sample_size)
-        #
-        # # We now draw samples from the mixture components random method
-        # comp_samples = self.comp_dists.random(point=point, size=size)
-        # if comp_samples.shape[: len(sample_shape)] != sample_shape:
-        #     comp_samples = np.broadcast_to(
-        #         comp_samples,
-        #         shape=sample_shape + comp_samples.shape,
-        #     )
-        #
-        # # At this point the shapes of the arrays involved are:
-        # # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
-        # # choices.shape = (sample_shape, batch_shape)
-        # #
-        # # To be able to take the choices along the mixture_axis of the
-        # # comp_samples, we have to add in dimensions to the right of the
-        # # choices array.
-        # # We also need to make sure that the batch_shapes of both the comp_samples
-        # # and choices broadcast with each other.
-        #
-        # choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))
-        #
-        # choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)
-        #
-        # # We now take the choices of the mixture components along the mixture_axis
-        # # but we use the negative index representation to be able to handle the
-        # # sample_shape
-        # samples = np.take_along_axis(
-        #     comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
-        # )
-        #
-        # # The `samples` array still has the `mixture_axis`, so we must remove it:
-        # output = samples[(..., 0) + (slice(None),) * len(event_shape)]
-        # return output
-
-    def _distr_parameters_for_repr(self):
-        return []
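The removed `logp` marginalizes the mixture axis through the identity log p(x) = logsumexp_i(log w_i + log f_i(x)), after padding `w` and `value` so that both broadcast against the component axis. A minimal NumPy sketch of that reduction, assuming scalar Gaussian components; the weights, locations, and shapes below are illustrative only, not part of the PyMC API:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

w = np.array([0.2, 0.3, 0.5])    # mixture weights, a simplex
mu = np.array([-1.0, 0.0, 1.0])  # one location per component

# Pad the observations with a trailing mixture axis so they broadcast
# against the components, mirroring at.shape_padaxis(value, ...) above.
value = np.array([[0.1], [2.0]])        # two observations, shape (2, 1)
comp_logp = norm.logpdf(value, loc=mu)  # per-component logp, shape (2, 3)

# Marginalize the mixture axis in log space.
mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)  # shape (2,)
```

The symbolic version wraps the same reduction in `check_parameters` so that invalid weights (negative, above one, or not summing to one) raise an error instead of returning a silently wrong density.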
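The commented-out `random` body performs ancestral sampling: draw a component index per sample from `w`, draw from every component, then select the chosen component along the mixture axis with `take_along_axis` and drop the now-singleton axis. A sketch of that selection under the simplest assumptions (a single batch of scalar Gaussian components; all names and shapes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
w = np.array([0.2, 0.3, 0.5])
mu = np.array([-1.0, 0.0, 1.0])
n_draws = 1000

comp_samples = rng.normal(loc=mu, size=(n_draws, 3))  # one draw per component
choices = rng.choice(3, p=w, size=(n_draws, 1))       # component index per draw

# Pick each draw's component along the mixture axis, then remove that axis,
# as the `samples[(..., 0) + ...]` indexing does above.
samples = np.take_along_axis(comp_samples, choices, axis=-1)[..., 0]
```

Most of the shape bookkeeping in the commented-out version exists to make this same selection work when `w`, the batch dimensions, and `size` all have to broadcast against one another.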