Conversation
Pull Request Overview
This PR introduces a new generate_with_obs_dist method to both HMM implementations and adds tests to verify its output shapes.
- Implements `generate_with_obs_dist` in `GeneralizedHiddenMarkovModel` for batch-wise sequence generation with returned observation probabilities.
- Adds shape-based tests for `generate_with_obs_dist` in both `test_hidden_markov_model.py` and `test_generalized_hidden_markov_model.py`.
- Extends the existing HMM tests to ensure both state and observation outputs match expected dimensions.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/generative_processes/test_hidden_markov_model.py | Added test_generate_with_obs_dist and related shape assertions |
| tests/generative_processes/test_generalized_hidden_markov_model.py | Parametrized test_generate_with_obs_dist for both model fixtures |
| simplexity/generative_processes/generalized_hidden_markov_model.py | Introduced generate_with_obs_dist method with vmapped vectorization |
Comments suppressed due to low confidence (1)
tests/generative_processes/test_hidden_markov_model.py:127
- Consider adding an assertion to verify that `intermediate_obs_probs` sums to 1 along the vocabulary axis (e.g., `jnp.allclose(intermediate_obs_probs.sum(-1), 1.0)`), ensuring the values form valid probability distributions.
`assert intermediate_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)`
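A minimal, self-contained sketch of that extra check, using a dummy array with the layout asserted above; the names `batch_size`, `sequence_len`, and `vocab_size` mirror the test, but everything else here is illustrative rather than the actual test code:

```python
import jax.numpy as jnp

# Illustrative stand-in for intermediate_obs_probs with the shape asserted above:
# (batch_size, sequence_len, vocab_size), uniform over the vocabulary.
batch_size, sequence_len, vocab_size = 4, 8, 2
intermediate_obs_probs = jnp.full((batch_size, sequence_len, vocab_size), 1.0 / vocab_size)

assert intermediate_obs_probs.shape == (batch_size, sequence_len, vocab_size)
# The suggested extra assertion: each per-step distribution sums to 1 over the vocabulary axis.
assert jnp.allclose(intermediate_obs_probs.sum(axis=-1), 1.0, atol=1e-6)
```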
| """Generate a batch of sequences of observations from the generative process. | ||
|
|
||
| Inputs: | ||
| state: (batch_size, num_states) | ||
| key: (batch_size, 2) | ||
| Returns: tuple of (belief states, observations, observation probabilities) where: |
[nitpick] The docstring describes batch input shapes, but this method is vmapped over a per-sample state and key. Clarify that inputs are single-example (no batch dim) and outputs are vectorized across the batch.
| """Generate a batch of sequences of observations from the generative process. | |
| Inputs: | |
| state: (batch_size, num_states) | |
| key: (batch_size, 2) | |
| Returns: tuple of (belief states, observations, observation probabilities) where: | |
| """Generate sequences of observations from the generative process. | |
| Inputs (per-sample, no batch dimension): | |
| state: (num_states,) | |
| key: (2,) | |
| Returns (vectorized across the batch): |
hrbigelow left a comment
I did a basic wall-clock profiling comparison between `generate_with_obs_dist` and my original implementation, which does not use the simplexity API functions. I get roughly a 2.5x speedup: 18 seconds for my implementation versus 45 seconds for this one.
I believe this is mostly due to redundant calculations in your `transition_states` and `observation_probability_distribution` functions.
As we discussed, I'm okay with keeping this as a reference implementation for the sake of code clarity, but in that case I will probably stick with my side implementation.
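To illustrate the kind of redundancy being described, here is a hedged sketch with made-up shapes and parameters, not the actual simplexity code: if the observation distribution and the belief-state update are computed by separate functions, each may redo the same state-times-transition-matrix products, whereas a fused step can compute those products once and reuse them.

```python
import jax.numpy as jnp

num_states, vocab_size = 3, 2
# Dummy per-symbol transition matrices T[o] and a uniform belief state,
# loosely in the style of a generalized HMM parameterization.
T = jnp.full((vocab_size, num_states, num_states), 0.5 / num_states)
state = jnp.full((num_states,), 1.0 / num_states)

# Separate calls: each one multiplies the state through the transition matrices.
def observation_probability_distribution(state):
    return jnp.einsum("i,oij->o", state, T)            # per-symbol probability mass

def transition_states(state, obs):
    new_state = state @ T[obs]                          # repeats the matvec for the chosen symbol
    return new_state / new_state.sum()

# Fused step: compute the per-symbol matvecs once and reuse them for both outputs.
def fused_step(state, obs):
    propagated = jnp.einsum("i,oij->oj", state, T)      # (vocab_size, num_states)
    obs_probs = propagated.sum(axis=-1)                 # distribution over symbols
    next_state = propagated[obs] / obs_probs[obs]       # normalized belief update
    return next_state, obs_probs

obs = 1
assert jnp.allclose(transition_states(state, obs), fused_step(state, obs)[0])
assert jnp.isclose(fused_step(state, obs)[1].sum(), 1.0)
```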
