Rlssm class make model dist #915

Open
cpaniaguam wants to merge 107 commits into cp-main-sb from rlssm-class-make-model-dist

Conversation

@cpaniaguam
Collaborator

@cpaniaguam cpaniaguam commented Mar 2, 2026

This pull request introduces reinforcement learning sequential sampling model (RLSSM) support to the HSSM package. It adds a new RLSSM class, supporting configuration, likelihood construction, and data validation for RL+SSM models, and refines the configuration workflow to require a fully annotated log-likelihood function. The changes also improve pre-commit configuration and update the package's public API.

Major features and changes:

1. RLSSM Model Integration

  • Added a new RLSSM class in src/hssm/rl/rlssm.py to support models that combine reinforcement learning processes with sequential sampling models. This class builds a differentiable pytensor Op from an annotated JAX log-likelihood function and enforces strict data requirements for balanced panels.
  • Introduced a utility function validate_balanced_panel in src/hssm/rl/utils.py to ensure input data forms a balanced panel, which is required for RLSSM models.

2. Configuration Enhancements

  • Extended RLSSMConfig in src/hssm/config.py to require an ssm_logp_func (an annotated JAX SSM log-likelihood function), replacing the previous loglik/loglik_kind workflow. Added runtime validation to ensure this function is callable and properly annotated.
  • Updated from_rlssm_dict to accept a config dictionary and extract ssm_logp_func and model_name directly from it, simplifying model instantiation.

3. Public API and Package Structure

  • Registered RLSSM and RLSSMConfig in the package's public API via src/hssm/__init__.py and created a new src/hssm/rl/__init__.py for RL-related exports.

4. Developer Experience

  • Updated .pre-commit-config.yaml to exclude the tests/ directory from ruff and mypy checks, streamlining development workflows.

@cpaniaguam cpaniaguam changed the base branch from main to cp-main-sb March 2, 2026 18:41
Contributor

Copilot AI left a comment

Pull request overview

Adds first-class RL + SSM (RLSSM) support to HSSM by introducing a new RLSSM model that builds a differentiable PyTensor Op from an annotated JAX SSM log-likelihood and plugs it into the existing distribution-building pipeline.

Changes:

  • Introduces RLSSM model class plus RL utility validate_balanced_panel.
  • Extends configuration via RLSSMConfig.ssm_logp_func and exposes RLSSM in the public API.
  • Adds test coverage for RLSSM initialization/model build and updates RLSSMConfig validation tests.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/hssm/rl/rlssm.py | New RLSSM model implementation integrating RL likelihood Op into HSSMBase. |
| src/hssm/rl/utils.py | Adds balanced-panel validation helper for RLSSM datasets. |
| src/hssm/rl/__init__.py | RL subpackage exports for RLSSM and utilities. |
| src/hssm/config.py | Adds ssm_logp_func to RLSSMConfig and validates presence. |
| src/hssm/__init__.py | Exposes RLSSM / RLSSMConfig at top-level. |
| tests/test_rlssm.py | New end-to-end-ish RLSSM tests (init, model build, balanced panel, smoke sampling). |
| tests/test_rlssm_config.py | Updates RLSSMConfig tests to include the new required field. |

            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])

Copilot AI Mar 2, 2026

validate_balanced_panel only checks equal trial counts, but the RL likelihood builder reshapes the row order into (n_participants, n_trials, ...) (see make_rl_logp_func), which assumes each participant’s trials are in one contiguous block (and usually in-trial order). With interleaved participants, the panel can be “balanced” yet produce a silently incorrect likelihood. Consider validating contiguity (each participant appears in exactly one run of length n_trials) and/or sorting by participant_col (+ an optional trial_col if present) before returning (n_participants, n_trials).

Suggested change
    return int(len(counts)), int(counts.iloc[0])
    # Ensure that each participant's trials form a single contiguous block
    # of rows of length n_trials. This is required because downstream code
    # reshapes the data into (n_participants, n_trials, ...) based on row
    # order, assuming no interleaving across participants.
    n_trials = int(counts.iloc[0])
    # Identify contiguous "blocks" of identical participant IDs.
    blocks = data[participant_col].ne(data[participant_col].shift()).cumsum()
    block_counts = data.groupby([participant_col, blocks]).size()
    # Each participant must appear in exactly one block, and that block
    # must have length n_trials.
    blocks_per_participant = block_counts.groupby(level=0).size()
    invalid_multi_blocks = blocks_per_participant[blocks_per_participant != 1]
    invalid_block_sizes = block_counts[block_counts != n_trials]
    if not invalid_multi_blocks.empty or not invalid_block_sizes.empty:
        raise ValueError(
            "Data must be ordered so that each participant's trials appear in "
            "a single contiguous block of rows of length n_trials. "
            "Participants with non-contiguous or incorrectly sized blocks "
            f"were found. Consider sorting your data by '{participant_col}' "
            "and, if available, by a trial index column before building the "
            "RL likelihood."
        )
    return int(len(counts)), n_trials
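
To see how the block-counting idea in the suggestion behaves, here is a minimal standalone sketch. The helper name `has_contiguous_blocks` and the toy data frames are assumptions for illustration and are not part of the PR; the sketch only mirrors the shift/cumsum logic shown above.

```python
import pandas as pd


def has_contiguous_blocks(
    data: pd.DataFrame, participant_col: str, n_trials: int
) -> bool:
    """Hypothetical helper: True if each participant occupies exactly one
    run of n_trials consecutive rows."""
    ids = data[participant_col]
    # A new block starts wherever the ID differs from the previous row.
    blocks = ids.ne(ids.shift()).cumsum().rename("block")
    block_counts = data.groupby([participant_col, blocks]).size()
    # Number of separate runs in which each participant appears.
    blocks_per_participant = block_counts.groupby(level=0).size()
    return bool(
        (blocks_per_participant == 1).all() and (block_counts == n_trials).all()
    )


# Both frames are "balanced" (two trials per participant), but only the
# first one keeps each participant's rows together.
contiguous = pd.DataFrame({"participant_id": [1, 1, 2, 2]})
interleaved = pd.DataFrame({"participant_id": [1, 2, 1, 2]})

contiguous_ok = has_contiguous_blocks(contiguous, "participant_id", n_trials=2)
interleaved_ok = has_contiguous_blocks(interleaved, "participant_id", n_trials=2)
```

In this sketch the interleaved frame passes a pure trial-count check yet fails the contiguity check, which is the case the comment is concerned about.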

            "Please provide the correct participant column name via "
            "`participant_col`."
        )


Copilot AI Mar 2, 2026

groupby(participant_col) drops NaN participant IDs by default, which can make n_participants/n_trials incorrect without an explicit error. Consider adding a check like data[participant_col].isna().any() and raising a clear ValueError if participant IDs are missing.

Suggested change
    # Ensure there are no missing participant IDs, since groupby will drop NaNs
    # silently, which would make n_participants / n_trials incorrect.
    if data[participant_col].isna().any():
        raise ValueError(
            f"Column '{participant_col}' contains missing values. "
            "Please fill or remove rows with missing participant IDs before "
            "calling validate_balanced_panel."
        )
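
The NaN-dropping behaviour is easy to reproduce with a small example. The toy data below is an assumption made for illustration; it relies only on the pandas default `dropna=True` for `groupby`.

```python
import numpy as np
import pandas as pd

# Five rows, one of which has a missing participant ID (illustrative data).
data = pd.DataFrame(
    {
        "participant_id": [1.0, 1.0, np.nan, 2.0, 2.0],
        "rt": [0.5, 0.6, 0.7, 0.8, 0.9],
    }
)

# groupby excludes NaN keys by default, so the third row is not counted.
counts = data.groupby("participant_id").size()
n_rows_counted = int(counts.sum())

# The explicit check proposed in the suggestion catches the problem.
has_missing_ids = bool(data["participant_id"].isna().any())
```

Here the counts look like a balanced panel of two participants with two trials each, even though the frame has five rows, so a downstream reshape to `(n_participants, n_trials, ...)` would not match the data length.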


Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.


Comment on lines +242 to +246
        )

        # Rearrange data so missing rows come first (no-op when missing_data=False).
        self.data = _rearrange_data(self.data)

Copilot AI Mar 2, 2026

_rearrange_data(self.data) changes row order, but the RL logp Op reshapes trials purely by row order into (n_participants, n_trials, ...). If any rows are moved (e.g., when missing_data=True and rt == -999), this will break per-participant trial sequences and invalidate the RL learning dynamics. Since missing-data networks are not supported for RLSSM, consider raising an explicit error when missing_data/deadline handling is requested (or implement a participant-wise rearrangement that preserves within-subject order).
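
The following sketch illustrates the concern with a hypothetical stand-in, `move_missing_first`, which simply places rows with `rt == -999` at the top. It is not the actual `_rearrange_data` implementation, and the column names and data are assumptions.

```python
import pandas as pd


def move_missing_first(data: pd.DataFrame, missing_value: float = -999.0) -> pd.DataFrame:
    """Hypothetical stand-in: put rows flagged as missing before all others."""
    is_missing = data["rt"] == missing_value
    return pd.concat([data[is_missing], data[~is_missing]], ignore_index=True)


# Two participants with three trials each; one trial of participant 2 is missing.
data = pd.DataFrame(
    {
        "participant_id": [1, 1, 1, 2, 2, 2],
        "trial": [0, 1, 2, 0, 1, 2],
        "rt": [0.5, 0.6, 0.7, 0.8, -999.0, 0.9],
    }
)

rearranged = move_missing_first(data)
# Row order by participant is now 2, 1, 1, 1, 2, 2, so a reshape to
# (n_participants, n_trials) would mix the two participants.
participant_order = rearranged["participant_id"].tolist()
```

Under these assumptions, a single moved row is enough to break the contiguous per-participant layout that a row-order reshape depends on.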

Comment on lines +49 to +56
    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])

Copilot AI Mar 2, 2026

validate_balanced_panel() only checks equal trial counts via groupby().size(), but it does not validate that rows are ordered/grouped by participant. The RL likelihood builder (make_rl_logp_func) reshapes arrays with .reshape(n_participants, n_trials, -1) based purely on row order, so interleaved participant rows will silently mix subjects/trials and produce an incorrect likelihood. Consider either (a) enforcing contiguous blocks per participant (and optionally stable-sorting by participant_col + a trial index column if available) or (b) returning a sorted copy of the data and using that downstream.
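
Option (b) could look roughly like the sketch below, which stable-sorts by participant so that the within-participant row order is preserved before reshaping. The data and the `trial` column are assumptions for illustration; the PR may handle this differently.

```python
import pandas as pd

# Interleaved but balanced data: two participants, three trials each.
data = pd.DataFrame(
    {
        "participant_id": [1, 2, 1, 2, 1, 2],
        "trial": [0, 0, 1, 1, 2, 2],  # hypothetical trial index column
        "rt": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    }
)
n_participants, n_trials = 2, 3

# Reshaping by raw row order mixes participants within each row of the array.
naive = data["participant_id"].to_numpy().reshape(n_participants, n_trials)

# A stable sort groups rows by participant while keeping each participant's
# original trial order intact.
sorted_data = data.sort_values("participant_id", kind="stable").reset_index(drop=True)
fixed = sorted_data["participant_id"].to_numpy().reshape(n_participants, n_trials)
```

The stable sort matters here: an unstable sort could group participants correctly but scramble the trial sequence that the learning dynamics depend on.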

@cpaniaguam cpaniaguam requested a review from Copilot March 2, 2026 20:17
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.


Comment on lines +249 to +275
        # All RLSSM parameters are treated as trialwise: the Op expects arrays of
        # length n_total_trials for every parameter, and make_distribution.logp
        # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly.
        params_is_trialwise = [
            True for param_name in self.params if param_name != "p_outlier"
        ]

        extra_fields_data = (
            None
            if not self.extra_fields
            else [deepcopy(self.data[field].values) for field in self.extra_fields]
        )

        assert self.list_params is not None, "list_params should be set"
        # self.loglik was set to the pytensor Op built in __init__; cast to
        # narrow the inherited union type so make_distribution's type-checker
        # accepts it without a runtime penalty.
        loglik_op = cast("Callable[..., Any] | Op", self.loglik)
        return make_distribution(
            rv=self.model_name,
            loglik=loglik_op,
            list_params=self.list_params,
            bounds=self.bounds,
            lapse=self.lapse,
            extra_fields=extra_fields_data,
            params_is_trialwise=params_is_trialwise,
        )

Copilot AI Mar 2, 2026

params_is_trialwise is derived from self.params (excluding p_outlier), but it is passed alongside list_params=self.list_params. If self.list_params includes p_outlier (common in HSSMBase), this makes params_is_trialwise shorter and potentially misaligned with list_params, which can cause incorrect broadcasting or length-check failures in make_distribution. Build params_is_trialwise from self.list_params in the same order, marking p_outlier as non-trialwise.
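
A minimal sketch of the alignment issue, using made-up parameter names (the real `list_params` depends on the model and is not shown in this thread):

```python
# Hypothetical parameter list that includes p_outlier, for illustration only.
list_params = ["rl_alpha", "a", "z", "t", "p_outlier"]

# Filtering p_outlier out, as in the reviewed snippet, yields one flag fewer
# than there are parameters.
filtered_flags = [True for name in list_params if name != "p_outlier"]

# Building the flags from list_params keeps one flag per parameter, in the
# same order, with p_outlier marked as non-trialwise.
aligned_flags = [name != "p_outlier" for name in list_params]
```

With the aligned version, `zip(list_params, aligned_flags)` pairs every parameter with its own flag, which is the property the comment asks for.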

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.


Member

@AlexanderFengler AlexanderFengler left a comment

small comments.

        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
        participant_col: str = "participant_id",
        include: list[dict[str, Any] | Any] | None = None,
Member

What would you call it instead here?
We would want to make that change globally, not just for this class, I guess.

Either way, I would do that as a separate PR.

            )
        if deadline is not False:
            raise ValueError(
                "RLSSM does not support `deadline` handling. "
Member

@krishnbera do we actually have a solution for this?

src/hssm/base.py Outdated
        """
        # Start with defaults
        config = cls.config_class.from_defaults(model, loglik_kind)
        # get_config_class is provided by Config/RLSSMConfig mixin through MRO
Member

Why does RLSSMConfig show up here in this file?

Collaborator Author

All this will be cleaned up after #936 and #931 get merged into their respective base branches.

@cpaniaguam cpaniaguam marked this pull request as ready for review March 31, 2026 15:56
…oglik_kind key in RLSSMConfig; update model instantiation parameter name