From 4b383d3c452cbd1ffd064cff43682cb9e17420c9 Mon Sep 17 00:00:00 2001
From: ay340 <68298508+ay340@users.noreply.github.com>
Date: Mon, 15 Sep 2025 22:40:36 +0200
Subject: [PATCH] Add general initial state and transition input format checks; added CatHMM and ARHMM emissions checks

---
Review notes (this section is ignored by `git am` and is not part of the commit message):

* Each `initialize` touched here now validates the parameters it is about to
  package, via a private helper: `_check_emissions_format` (ARHMM and
  categorical emissions), `_check_initialization_format` (initial state) and
  `_check_transitions_format` (transitions). The helpers check array shapes
  and, for probability arrays, non-negativity and normalisation.
* In categorical_hmm.py the checks used to run only for caller-supplied
  `emission_probs` (the removed `else:` branch); they now run unconditionally.
* `_check_initialization_format` passes a plain string as its third assert
  message (it was a `ValueError(...)` instance), matching the sibling checks,
  and placeholder-free f-strings lost their `f` prefix. The message texts are
  unchanged; the `index` blob ids of initial.py and transitions.py predate
  these two tweaks.
* NOTE(review): line breaks and the indentation of context lines were
  reconstructed from a whitespace-collapsed copy of this patch. Verify them
  against the repository, or apply with `git apply --ignore-whitespace`.
* NOTE(review): the value checks are `assert`s over `jnp` results, so they are
  skipped under `python -O` and presumably require concrete (non-traced)
  arrays -- confirm that `initialize` is never called under `jax.jit`.

 dynamax/hidden_markov_model/models/arhmm.py       |  8 ++++++++
 .../hidden_markov_model/models/categorical_hmm.py | 12 +++++++-----
 dynamax/hidden_markov_model/models/initial.py     |  6 ++++++
 dynamax/hidden_markov_model/models/transitions.py |  8 +++++++-
 4 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py
index 222f4f31..6a961ffc 100644
--- a/dynamax/hidden_markov_model/models/arhmm.py
+++ b/dynamax/hidden_markov_model/models/arhmm.py
@@ -40,6 +40,11 @@ def __init__(self,
         input_dim = num_lags * emission_dim
         super().__init__(num_states, input_dim, emission_dim)
 
+    def _check_emissions_format(self, emission_weights, emission_biases, emission_covariances):
+        assert emission_weights.shape == (self.num_states, self.emission_dim, self.input_dim), f"'emission_weights' must have shape (num_states, emission_dim, input_dim)={(self.num_states, self.emission_dim, self.input_dim)} but {emission_weights.shape} provided."
+        assert emission_biases.shape == (self.num_states, self.emission_dim), f"'emission_biases' must have shape (num_states, emission_dim)={(self.num_states, self.emission_dim)} but {emission_biases.shape} provided."
+        assert emission_covariances.shape == (self.num_states, self.emission_dim, self.emission_dim), f"'emission_covariances' must have shape (num_states, emission_dim, emission_dim)={(self.num_states, self.emission_dim, self.emission_dim)} but {emission_covariances.shape} provided."
+
     def initialize(self,
                    key: Array=jr.PRNGKey(0),
                    method: str="prior",
@@ -93,6 +98,9 @@ def initialize(self,
             weights=ParameterProperties(),
             biases=ParameterProperties(),
             covs=ParameterProperties(constrainer=RealToPSDBijector()))
+
+        self._check_emissions_format(emission_weights=params.weights, emission_biases=params.biases, emission_covariances=params.covs)
+
         return params, props
 
 
diff --git a/dynamax/hidden_markov_model/models/categorical_hmm.py b/dynamax/hidden_markov_model/models/categorical_hmm.py
index 7d3b2ff2..45e4b10b 100644
--- a/dynamax/hidden_markov_model/models/categorical_hmm.py
+++ b/dynamax/hidden_markov_model/models/categorical_hmm.py
@@ -58,6 +58,11 @@ def log_prior(self, params: ParamsCategoricalHMMEmissions) -> Scalar:
         """Return the log prior probability of the emission parameters."""
         return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()
 
+    def _check_emissions_format(self, emission_probs):
+        assert emission_probs.shape == (self.num_states, self.emission_dim, self.num_classes), f"'emission_probs' must have shape (num_states, emission_dim, num_classes)={(self.num_states, self.emission_dim, self.num_classes)} but {emission_probs.shape} provided."
+        assert jnp.all(emission_probs >= 0), "All entries in 'emission_probs' must be non-negative."
+        assert jnp.allclose(emission_probs.sum(axis=2), 1.0), "Each row of 'emission_probs' must sum to 1."
+
     def initialize(self,
                    key:Optional[Array]=jr.PRNGKey(0),
                    method="prior",
@@ -91,11 +96,8 @@ def initialize(self,
                 raise NotImplementedError("kmeans initialization is not yet implemented!")
             else:
                 raise Exception("invalid initialization method: {}".format(method))
-        else:
-            assert emission_probs.shape == (self.num_states, self.emission_dim, self.num_classes)
-            assert jnp.all(emission_probs >= 0)
-            assert jnp.allclose(emission_probs.sum(axis=2), 1.0)
-
+        
+        self._check_emissions_format(emission_probs=emission_probs)
         # Add parameters to the dictionary
         params = ParamsCategoricalHMMEmissions(probs=emission_probs)
         props = ParamsCategoricalHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
diff --git a/dynamax/hidden_markov_model/models/initial.py b/dynamax/hidden_markov_model/models/initial.py
index 1e733196..b5e76f3b 100644
--- a/dynamax/hidden_markov_model/models/initial.py
+++ b/dynamax/hidden_markov_model/models/initial.py
@@ -35,6 +35,11 @@ def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tf
         """Return the distribution object of the initial distribution."""
         return tfd.Categorical(probs=params.probs)
 
+    def _check_initialization_format(self, initial_probs: Float[Array, " num_states"]) -> None:
+        assert initial_probs.shape == (self.num_states,), f"'initial_probs' must have shape (num_states,)={(self.num_states,)} but {initial_probs.shape} provided."
+        assert jnp.all(initial_probs >= 0.0), "All entries in 'initial_probs' must be non-negative."
+        assert jnp.isclose(initial_probs.sum(), 1.0), "'initial_probs' must sum to 1.0."
+
     def initialize(
             self,
             key: Optional[Array]=None,
@@ -59,6 +64,7 @@ def initialize(
                 this_key, key = jr.split(key)
                 initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
 
+        self._check_initialization_format(initial_probs=initial_probs)
         # Package the results into dictionaries
         params = ParamsStandardHMMInitialState(probs=initial_probs)
         props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
diff --git a/dynamax/hidden_markov_model/models/transitions.py b/dynamax/hidden_markov_model/models/transitions.py
index c01bfa63..d9029ae6 100644
--- a/dynamax/hidden_markov_model/models/transitions.py
+++ b/dynamax/hidden_markov_model/models/transitions.py
@@ -51,6 +51,11 @@ def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, i
         """Return the distribution over the next state given the current state."""
         return tfd.Categorical(probs=params.transition_matrix[state])
 
+    def _check_transitions_format(self, transition_matrix: Float[Array, "num_states num_states"]):
+        assert transition_matrix.shape == (self.num_states, self.num_states), f"'transition_matrix' must have shape (num_states, num_states)={(self.num_states, self.num_states)} but {transition_matrix.shape} provided."
+        assert jnp.all(transition_matrix >= 0.0), "All entries in 'transition_matrix' must be non-negative."
+        assert jnp.isclose(transition_matrix.sum(axis=1), 1.0).all(), "Each row of 'transition_matrix' must sum to 1.0."
+
     def initialize(
             self,
             key: Optional[Array]=None,
@@ -73,7 +78,8 @@ def initialize(
             else:
                 transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key)
                 transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample)
-
+        
+        self._check_transitions_format(transition_matrix=transition_matrix)
         # Package the results into dictionaries
         params = ParamsStandardHMMTransitions(transition_matrix=transition_matrix)
         props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))