Description
I don't think the current `hmm_fit_sgd` function is using the length of the time series as we hoped. At least with the default loss function, the length just scales the loss. Really, we need to change `marginal_log_prob` to only compute the log probability of the observations up to the specified length.
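To pin down the target: with latent states $z_t \in \{1, \dots, K\}$, observations $y_t$, and a specified length $T'$ (notation assumed here, not taken from the code), the quantity we want is the usual forward-algorithm factorization truncated at $T'$:

$$
\log p(y_{1:T'}) = \sum_{t=1}^{T'} \log p(y_t \mid y_{1:t-1}), \qquad
p(y_t \mid y_{1:t-1}) = \sum_{k=1}^{K} p(z_t = k \mid y_{1:t-1})\, p(y_t \mid z_t = k)
$$

If the sequence is padded to some $T > T'$ and the emission log likelihoods $\log p(y_t \mid z_t = k)$ are set to $0$ for every $t > T'$ and every $k$, each padded predictive term becomes $\sum_k p(z_t = k \mid y_{1:t-1}) = 1$. Those steps then add $\log 1 = 0$, so a filter run over all $T$ steps still returns $\log p(y_{1:T'})$.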
Following up on our Slack conversation, I see two ways of doing that:
- Pad the time series with NaNs and modify `_conditional_logliks` to put zeros wherever the emission is NaN. That way `hmm_filter` will still compute the marginal log prob of just the observed data. I think this trick should also leave the `hmm_smoother` computations unchanged. A cool added benefit is that it would let us interpolate over chunks of missing data.
It would look like this:
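```python
import jax.numpy as jnp
from jax import vmap


def _conditional_logliks(self, emissions):
    # Perform a nested vmap over timesteps and states
    f = lambda emission: vmap(
        lambda state: self.emission_distribution(state).log_prob(emission)
    )(jnp.arange(self.num_states))
    lls = vmap(f)(emissions)  # shape (num_timesteps, num_states)
    # NaN emissions give NaN log likelihoods; treat them as uninformative (log 1 = 0)
    return jnp.where(jnp.isnan(lls), 0.0, lls)
```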
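I tested this out, and the only problem is that we can't take gradients through this function with respect to the model parameters. They come out NaN because one branch of the `where` is NaN: the backward pass multiplies that branch's zero cotangent by a NaN derivative, and 0 * NaN is NaN. See jax-ml/jax#1052.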
There's a somewhat clunky fix: find the NaNs first, replace them with a default emission value, compute the log likelihoods, and then zero out the entries that were originally NaN. That would look something like this:
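```python
# Revised body of _conditional_logliks, reusing f from above.
# Flag timesteps where any emission dimension is NaN (emissions has shape (T, D))
bad = jnp.any(jnp.isnan(emissions), axis=1)
# Swap in a placeholder value (0.0) so log_prob and its gradient stay finite
tmp = jnp.where(bad[:, None], 0.0, emissions)
lls = vmap(f)(tmp)
# Zero out the log likelihoods at the originally missing timesteps
return jnp.where(bad[:, None], 0.0, lls)
```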
It's not the prettiest, but it works.
- Alternatively, we could pass the length of the time series to the underlying inference functions like `hmm_filter`. Then those functions would need to use a while loop to dynamically stop the message passing once the length has been reached. (I tried implementing this by calling the filter on a dynamic slice of the data, but JAX barfed on that...) This approach is totally doable, but it would add a lot of extra logic to the inference code.
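One caveat for approach 2: `jax.lax.while_loop` (and `fori_loop` with a traced trip count) does not support reverse-mode differentiation, which SGD fitting needs if it goes through `jax.grad`. A common workaround, swapped in here for the while loop, is a fixed-length `lax.scan` over the padded sequence that freezes the filter state once `t >= length`. This is a minimal sketch assuming scalar emissions, unit-variance Gaussian emission densities, and a log-space forward filter; `masked_marginal_log_prob` and `gaussian_lls` are hypothetical names, not the library's `hmm_filter` API.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp


def gaussian_lls(emissions, mus):
    # (T, K) log N(y_t | mu_k, 1) for scalar emissions (hypothetical helper)
    return -0.5 * (emissions[:, None] - mus) ** 2 - 0.5 * jnp.log(2 * jnp.pi)


def masked_marginal_log_prob(log_pi, log_A, lls, length):
    # Forward algorithm over a padded (T, K) array of emission log likelihoods.
    # Steps with t >= length leave the carry untouched, so padding never counts.
    def step(carry, inputs):
        log_pred, total = carry          # log p(z_t | y_{1:t-1}), running log prob
        t, ll = inputs
        log_joint = log_pred + ll        # log p(z_t, y_t | y_{1:t-1})
        log_norm = logsumexp(log_joint)  # log p(y_t | y_{1:t-1})
        log_filt = log_joint - log_norm  # log p(z_t | y_{1:t})
        # log p(z_{t+1} = j | y_{1:t}) = logsumexp_i [log_filt_i + log_A_ij]
        log_next = logsumexp(log_filt[:, None] + log_A, axis=0)
        keep = t < length
        new_carry = (jnp.where(keep, log_next, log_pred),
                     total + jnp.where(keep, log_norm, 0.0))
        return new_carry, None

    steps = jnp.arange(lls.shape[0])
    (_, total), _ = lax.scan(step, (log_pi, jnp.array(0.0)), (steps, lls))
    return total


# A 2-state HMM and a length-6 sequence zero-padded to 10 steps
log_pi = jnp.log(jnp.array([0.6, 0.4]))
log_A = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
mus = jnp.array([-1.0, 1.0])
emissions = jnp.concatenate([jnp.array([0.5, -0.3, 1.2, -1.1, 0.9, 0.0]), jnp.zeros(4)])

lp_padded = masked_marginal_log_prob(log_pi, log_A, gaussian_lls(emissions, mus), 6)
lp_truncated = masked_marginal_log_prob(log_pi, log_A, gaussian_lls(emissions[:6], mus), 6)
grad_mus = jax.grad(
    lambda m: masked_marginal_log_prob(log_pi, log_A, gaussian_lls(emissions, m), 6)
)(mus)
```

The padded and truncated calls should agree, and because `length` only enters through `jnp.where`, it can be a traced value (for example, vmapped over a batch of sequences with different lengths). The trade-off is that every call still runs all `T` steps.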
I'm working on a demo of approach 1 right now. Will keep you posted!
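In the meantime, a minimal, self-contained sketch of the gradient issue in approach 1 and the double-`where` fix. It assumes scalar emissions and unit-variance Gaussian emission densities with per-state means; `gaussian_lls`, `lls_naive`, and `lls_safe` are hypothetical stand-ins for `_conditional_logliks`, not the library's code. The first version reproduces the NaN gradient; the second gives finite gradients that only count the observed timesteps.

```python
import jax
import jax.numpy as jnp


def gaussian_lls(emissions, mus):
    # (T, K) log N(y_t | mu_k, 1) for scalar emissions (hypothetical helper)
    return -0.5 * (emissions[:, None] - mus) ** 2 - 0.5 * jnp.log(2 * jnp.pi)


def lls_naive(mus, emissions):
    # Zero out NaN log likelihoods after the fact: the value is fine, the gradient is NaN
    lls = gaussian_lls(emissions, mus)
    return jnp.where(jnp.isnan(lls), 0.0, lls)


def lls_safe(mus, emissions):
    # Replace NaNs before evaluating the density, then zero out those rows
    bad = jnp.isnan(emissions)
    tmp = jnp.where(bad, 0.0, emissions)
    lls = gaussian_lls(tmp, mus)
    return jnp.where(bad[:, None], 0.0, lls)


# Two observed steps followed by two NaN-padded steps
emissions = jnp.array([0.3, -1.2, jnp.nan, jnp.nan])
mus = jnp.array([-1.0, 1.0])

grad_naive = jax.grad(lambda m: lls_naive(m, emissions).sum())(mus)
grad_safe = jax.grad(lambda m: lls_safe(m, emissions).sum())(mus)
print(grad_naive)  # contains NaN
print(grad_safe)   # ≈ [1.1, -2.9], i.e. sum_t (y_t - mu_k) over the observed steps only
```

Both versions return the same log likelihood values; only the backward pass differs, because in `lls_safe` the masked branch is evaluated at a finite placeholder.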