Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
20ddc4c
Add ssm_logp_func to RLSSMConfig and update validation tests
cpaniaguam Mar 2, 2026
d97dcee
Add RLSSM model and utilities for reinforcement learning integration
cpaniaguam Mar 2, 2026
a6a0238
Refactor RLSSM parameter handling and add custom prefix resolution fo…
cpaniaguam Mar 2, 2026
d880977
Add tests for RLSSM class covering initialization, validation, and mo…
cpaniaguam Mar 2, 2026
bef8d6c
Refactor loglik handling in RLSSM to improve type safety with casting
cpaniaguam Mar 2, 2026
3981ef6
Add NaN value check for participant column in validate_balanced_panel…
cpaniaguam Mar 2, 2026
d84a800
Add validation for ssm_logp_func in RLSSMConfig to ensure it is calla…
cpaniaguam Mar 2, 2026
15ad6e2
Add exclude rules for ruff and mypy hooks to skip tests directory
cpaniaguam Mar 2, 2026
262ec07
Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is…
cpaniaguam Mar 2, 2026
381275a
Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM
cpaniaguam Mar 2, 2026
0e9ba42
Reject missing data and deadline handling in RLSSM initialization to …
cpaniaguam Mar 2, 2026
4f28c68
Add tests to validate error handling for missing data and deadline in…
cpaniaguam Mar 2, 2026
5e9f566
Refactor path handling for loading RLDM fixture dataset in tests
cpaniaguam Mar 2, 2026
67ac2ce
Add fixture to set floatX to float32 for module tests
cpaniaguam Mar 2, 2026
e1c05df
Ensure params_is_trialwise aligns with list_params in RLSSM initializ…
cpaniaguam Mar 2, 2026
564232b
Clarify comments on default_priors in ModelConfig and remove unnecess…
cpaniaguam Mar 2, 2026
bafc037
Update RLSSM to use to_numpy(copy=True) for extra_fields and add test…
cpaniaguam Mar 2, 2026
ba358a4
Refactor parameter name resolution in RLSSM to handle underscores cor…
cpaniaguam Mar 2, 2026
0bfa755
Add test for _get_prefix method in RLSSM to ensure token-based matching
cpaniaguam Mar 2, 2026
5b8a16a
Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter a…
cpaniaguam Mar 2, 2026
f69f2b6
Fix comment in test_rlssm.py to clarify output shape of log-likelihoo…
cpaniaguam Mar 3, 2026
bad943d
Update RLSSMConfig documentation to mark description as required
cpaniaguam Mar 3, 2026
241aad2
Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig ini…
cpaniaguam Mar 3, 2026
ca3816d
Add dummy ssm_logp_func to tests and validate its presence in RLSSMCo…
cpaniaguam Mar 3, 2026
827025c
Remove unused logging import from rlssm.py
cpaniaguam Mar 3, 2026
292d6f0
Remove redundant exclude rule for ruff-format in pre-commit configura…
cpaniaguam Mar 3, 2026
3b1aaf4
Add to_model_config method to RLSSMConfig for ModelConfig conversion
cpaniaguam Mar 4, 2026
0678c45
Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig an…
cpaniaguam Mar 4, 2026
26a9336
Integrate Config and RLSSMConfig into HSSM and RLSSM classes for impr…
cpaniaguam Mar 9, 2026
cd660da
Update choices type from list to tuple for consistency in BaseModelCo…
cpaniaguam Mar 9, 2026
db308c6
Update choices type from list to tuple in test_constructor for consis…
cpaniaguam Mar 9, 2026
9e5e8be
Add deprecation warnings for model_config attributes in HSSMBase
cpaniaguam Mar 11, 2026
74acc67
Refactor HSSMBase to support BaseModelConfig and improve model_config…
cpaniaguam Mar 11, 2026
2acc19d
Add model configuration building methods to BaseModelConfig and Confi…
cpaniaguam Mar 11, 2026
97bf90f
Refactor model configuration handling in HSSMBase and HSSM classes to…
cpaniaguam Mar 11, 2026
f75f5e4
Add properties to BaseModelConfig for parameter and extra field counts
cpaniaguam Mar 11, 2026
b9b4839
Refactor RLSSM attributes to use public naming convention for configu…
cpaniaguam Mar 11, 2026
821392f
Refactor test_rlssm_panel_attrs to use public attributes for particip…
cpaniaguam Mar 11, 2026
07169c9
Refactor HSSMBase to streamline model configuration handling and upda…
cpaniaguam Mar 11, 2026
0395ec2
Refactor BaseModelConfig and RLSSMConfig by removing unused abstract …
cpaniaguam Mar 11, 2026
626301d
Refactor HSSM class to remove Config inheritance and add initializati…
cpaniaguam Mar 11, 2026
6c13443
Refactor RLSSM class to remove RLSSMConfig inheritance and streamline…
cpaniaguam Mar 11, 2026
66296f0
Refactor Config and RLSSMConfig classes to use concrete types in meth…
cpaniaguam Mar 11, 2026
801f235
Update Config class parameter types for choices to improve type safety
cpaniaguam Mar 11, 2026
607874c
Update choices method to accept a tuple for model_config.choices
cpaniaguam Mar 11, 2026
9e25a32
Add tests for model configuration handling and choices logic in Config
cpaniaguam Mar 11, 2026
2b2d66b
Enhance HSSMBase initialization with safe default for constructor arg…
cpaniaguam Mar 11, 2026
d5b9d80
Update model_config validation to check for non-null choices
cpaniaguam Mar 11, 2026
9bf18ea
Refactor HSSM distribution method to use typed model_config attribute…
cpaniaguam Mar 11, 2026
9af3e95
Update test cases to use tuples for choices in model configuration
cpaniaguam Mar 12, 2026
c2e09d9
Refactor RLSSM to utilize model_config for list_params and loglik, en…
cpaniaguam Mar 12, 2026
1aa19f2
Fix typo in comment regarding model_config choices validation
cpaniaguam Mar 12, 2026
0ea0998
Refactor RLSSM tests to access model configuration attributes directl…
cpaniaguam Mar 12, 2026
8f526f4
Update attribute comparison in compare_hssm_class_attributes to use m…
cpaniaguam Mar 12, 2026
5dd68a5
Update test assertions to access model configuration attributes directly
cpaniaguam Mar 12, 2026
7054ccd
Refactor model configuration normalization to streamline choices hand…
cpaniaguam Mar 12, 2026
5e816bc
Refactor choices handling in Config class to improve clarity and logging
cpaniaguam Mar 12, 2026
9f6a7ef
Refactor _normalize_model_config_with_choices to improve input handli…
cpaniaguam Mar 12, 2026
49415ab
Refactor likelihood callable construction to simplify logic and enhan…
cpaniaguam Mar 12, 2026
4452f36
Refactor _make_model_distribution to utilize model_config for loglik …
cpaniaguam Mar 12, 2026
c34e562
Fix formatting in HSSM class for consistency in likelihood callable p…
cpaniaguam Mar 12, 2026
3e86974
Fix formatting in HSSM class for consistency in likelihood callable p…
cpaniaguam Mar 12, 2026
4a5aefc
Refactor HSSM class to use typed model_config attributes directly and…
cpaniaguam Mar 12, 2026
cdc7763
Restore make_model_dist in HSSM
cpaniaguam Mar 13, 2026
e3cbcb7
Remove deprecated properties and methods from HSSMBase class
cpaniaguam Mar 13, 2026
7e481e0
Enhance HSSMBase class to prevent overwriting _init_args if already s…
cpaniaguam Mar 13, 2026
3432bca
Clarify model_config parameter documentation in HSSMBase class to spe…
cpaniaguam Mar 13, 2026
31bd6f1
Enhance HSSMBase class documentation to clarify filtering of internal…
cpaniaguam Mar 13, 2026
296810b
Update model_config parameter documentation in HSSM class to support …
cpaniaguam Mar 13, 2026
95779bc
Add test to validate external model config fallback in _build_model_c…
cpaniaguam Mar 13, 2026
37ea9be
Update sampling parameters in test_rlssm_sample_smoke for speed
cpaniaguam Mar 17, 2026
9a37dd8
Add RLSSM quickstart notebook for model instantiation and sampling de…
cpaniaguam Mar 17, 2026
43ec652
Add RLSSM Quickstart tutorial to navigation and plugins
cpaniaguam Mar 17, 2026
7f1e6ff
Remove redundant next steps and streamline summary in RLSSM quickstar…
cpaniaguam Mar 17, 2026
aec1531
Refactor RLSSMConfig methods to simplify parameter handling and remov…
cpaniaguam Mar 18, 2026
a8cd51d
Fix handling of list_params in HSSMBase to ensure proper conversion f…
cpaniaguam Mar 18, 2026
9c22e26
Refactor RLSSM to inject model configuration directly, removing unnec…
cpaniaguam Mar 18, 2026
5658834
Update TestRLSSMConfigDefaults to reflect None for default parameters…
cpaniaguam Mar 18, 2026
7a294af
Refactor RLSSM to inject loglik and backend directly into a new RLSSM…
cpaniaguam Mar 18, 2026
fd99efb
Add validation for missing bounds in RLSSMConfig parameters
cpaniaguam Mar 18, 2026
bc0f7ca
Fix RLSSM to use model_config for ssm_logp_func and update test cases…
cpaniaguam Mar 18, 2026
b075e4f
Enhance RLSSM tests to align params_is_trialwise with list_params and…
cpaniaguam Mar 18, 2026
27d505e
Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError
cpaniaguam Mar 18, 2026
ce8e187
Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedEr…
cpaniaguam Mar 18, 2026
7c7fd32
Inject JAX backend into RLSSMConfig during initialization
cpaniaguam Mar 18, 2026
e604406
Refactor RLSSM class to use model_config instead of rlssm_config for …
cpaniaguam Mar 18, 2026
582a6fe
Merge branch '930-pass-configs-via-dependency-injection-into-model-cl…
cpaniaguam Mar 19, 2026
a3898d7
Fix merge conflicts with base branch
cpaniaguam Mar 19, 2026
4d99410
Remove commented out lines
cpaniaguam Mar 19, 2026
f04f47e
Remove RLSSMConfig import from __init__.py
cpaniaguam Mar 25, 2026
11115af
Reorganize import statements by moving RLSSMConfig import to the corr…
cpaniaguam Mar 25, 2026
6a9384f
Move RLSSMConfig import to the correct module in test files
cpaniaguam Mar 25, 2026
0285f04
Update docstring in __init__.py and exports
cpaniaguam Mar 25, 2026
5807a71
Remove RLSSMConfig class and its associated methods from config.py
cpaniaguam Mar 25, 2026
4bf67ea
Move RLSSMConfig class hssm.rl module
cpaniaguam Mar 25, 2026
5d74bfe
Refactor config.py to remove RLSSM-specific defaults and unify observ…
cpaniaguam Mar 27, 2026
ef000cd
Fix formatting of error messages in TestRLSSMConfigValidation for con…
cpaniaguam Mar 27, 2026
91b1098
Enhance validation in RLSSMConfig for ssm_logp_func attributes
cpaniaguam Mar 27, 2026
c3a4f52
Add validation test for non-callable values in ssm_logp_func.computed
cpaniaguam Mar 27, 2026
692dc5d
Rename 'learning_process_loglik_kind' to 'learning_process_kind' in R…
cpaniaguam Mar 30, 2026
7cf8bca
Simplify response and list_params assignment in HSSMBase by removing …
cpaniaguam Mar 30, 2026
c46f923
Revert "Simplify response and list_params assignment in HSSMBase by r…
cpaniaguam Mar 30, 2026
fac838e
Refactor RLSSMConfig to dynamically retrieve required fields for vali…
cpaniaguam Mar 31, 2026
3084df4
Update RLSSMConfig to handle field exceptions in from_rlssm_dict method
cpaniaguam Mar 31, 2026
babee92
Merge pull request #936 from lnccbrown/inject-RLSSMConfig-directly-in…
cpaniaguam Mar 31, 2026
b0e2179
Merge pull request #931 from lnccbrown/930-pass-configs-via-dependenc…
cpaniaguam Mar 31, 2026
1233cd7
Fix import path for RLSSM and RLSSMConfig; correct learning_process_l…
cpaniaguam Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
exclude: ^tests/
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.1 # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
exclude: ^tests/
Comment thread
cpaniaguam marked this conversation as resolved.
351 changes: 351 additions & 0 deletions docs/tutorials/rlssm_quickstart.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1b9b429d",
"metadata": {},
"source": [
"# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n",
"\n",
"This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n",
"\n",
"1. **Load** a balanced-panel two-armed bandit dataset\n",
"2. **Define** an annotated learning function and the angle SSM log-likelihood\n",
"3. **Configure** and **instantiate** an `RLSSM` model\n",
"4. **Inspect** the built Bambi / PyMC model\n",
"5. **Run** a minimal 2-draw sampling smoke test\n",
"\n",
"For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n",
"- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n",
"- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "bf38d7f7",
"metadata": {},
"source": [
"## 1. Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d764731",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import hssm\n",
"from hssm.rl import RLSSM, RLSSMConfig\n",
"from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n",
"from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n",
"from hssm.utils import annotate_function\n",
"\n",
"# RLSSM requires float32 throughout (JAX default).\n",
"hssm.set_floatX(\"float32\", update_jax=True)"
]
},
{
"cell_type": "markdown",
"id": "df12303f",
"metadata": {},
"source": [
"## 2. Load the Dataset\n",
"\n",
"We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n",
"It is a **balanced panel**: every participant has the same number of trials. \n",
"Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n",
"\n",
"> **Note:** You can also generate data with\n",
"> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n",
"> See `rlssm_tutorial.ipynb` for an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2ef5f6e",
"metadata": {},
"outputs": [],
"source": [
"# Path relative to docs/tutorials/ when running inside the HSSM repo.\n",
"_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n",
"raw = np.load(_fixture_path, allow_pickle=True).item()\n",
"data = pd.DataFrame(raw[\"data\"])\n",
"\n",
"n_participants = data[\"participant_id\"].nunique()\n",
"n_trials = len(data) // n_participants\n",
"\n",
"print(data.head())\n",
"print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")"
]
},
{
"cell_type": "markdown",
"id": "8c310290",
"metadata": {},
"source": [
"## 3. Define the Learning Process\n",
"\n",
"The RL learning process is a JAX function that, given a subject's trial sequence, computes\n",
"the trial-wise drift rate `v` via a Q-learning update rule. \n",
"\n",
"`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n",
"that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n",
"decision process.\n",
"\n",
"- **inputs** — columns that the function reads (free parameters + data columns)\n",
"- **outputs** — what the function produces (here: `v`, the drift rate)\n",
"\n",
"Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n",
"Rescorla-Wagner Q-learning update for a two-armed bandit task."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbcea122",
"metadata": {},
"outputs": [],
"source": [
"compute_v_annotated = annotate_function(\n",
" inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n",
" outputs=[\"v\"],\n",
")(compute_v_subject_wise)\n",
"\n",
"print(\"Learning function inputs :\", compute_v_annotated.inputs)\n",
"print(\"Learning function outputs:\", compute_v_annotated.outputs)"
]
},
{
"cell_type": "markdown",
"id": "7a03305a",
"metadata": {},
"source": [
"## 4. Define the Decision (SSM) Log-Likelihood\n",
"\n",
"The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n",
"`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n",
"2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n",
"per-trial log-probabilities.\n",
"\n",
"We then annotate that callable so the builder knows:\n",
"- which columns the matrix contains (`inputs`)\n",
"- that `v` itself is *computed* by the learning function (not a free parameter)\n",
"\n",
"The ONNX file is loaded from the local test fixture when running inside the HSSM\n",
"repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60bbc036",
"metadata": {},
"outputs": [],
"source": [
"# Use the local fixture when available; fall back to HuggingFace download.\n",
"_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n",
"_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n",
"\n",
"_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n",
"\n",
"angle_logp_func = annotate_function(\n",
" inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n",
" outputs=[\"logp\"],\n",
" computed={\"v\": compute_v_annotated},\n",
")(_angle_logp_jax)\n",
"\n",
"print(\"SSM logp inputs :\", angle_logp_func.inputs)\n",
"print(\"SSM logp outputs:\", angle_logp_func.outputs)\n",
"print(\"Computed deps :\", list(angle_logp_func.computed.keys()))"
]
},
{
"cell_type": "markdown",
"id": "cf8f5b63",
"metadata": {},
"source": [
"## 5. Configure the Model with `RLSSMConfig`\n",
"\n",
"`RLSSMConfig` collects all the information the RLSSM class needs:\n",
"\n",
"| Field | Purpose |\n",
"|-------|---------|\n",
"| `model_name` | Identifier string for the configuration |\n",
"| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n",
"| `list_params` | Ordered list of *free* parameters to sample |\n",
"| `params_default` | Starting / default values for each parameter |\n",
"| `bounds` | Prior bounds for each parameter |\n",
"| `learning_process` | Dict mapping computed param name → annotated learning function |\n",
"| `extra_fields` | Extra data columns required by the learning function |\n",
"| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4beba1bc",
"metadata": {},
"outputs": [],
"source": [
"rlssm_config = RLSSMConfig(\n",
" model_name=\"rlssm_angle_quickstart\",\n",
" loglik_kind=\"approx_differentiable\",\n",
" decision_process=\"angle\",\n",
" decision_process_loglik_kind=\"approx_differentiable\",\n",
" learning_process_kind=\"blackbox\",\n",
" list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n",
" params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n",
" bounds={\n",
" \"rl_alpha\": (0.0, 1.0),\n",
" \"scaler\": (0.0, 10.0),\n",
" \"a\": (0.1, 3.0),\n",
" \"theta\": (-0.1, 0.1),\n",
" \"t\": (0.001, 1.0),\n",
" \"z\": (0.1, 0.9),\n",
" },\n",
" learning_process={\"v\": compute_v_annotated},\n",
" response=[\"rt\", \"response\"],\n",
" choices=[0, 1],\n",
" extra_fields=[\"feedback\"],\n",
" ssm_logp_func=angle_logp_func,\n",
")\n",
"\n",
"print(\"Model name :\", rlssm_config.model_name)\n",
"print(\"Free params :\", rlssm_config.list_params)"
]
},
{
"cell_type": "markdown",
"id": "924ee4c7",
"metadata": {},
"source": [
"## 6. Instantiate the `RLSSM` Model\n",
"\n",
"Passing `data` and `rlssm_config` to `RLSSM`:\n",
"\n",
"- validates the balanced-panel requirement\n",
"- builds a differentiable PyTensor Op that chains the RL learning step and the\n",
" angle log-likelihood\n",
"- constructs the Bambi / PyMC model internally\n",
"\n",
"Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n",
"the Op by the Q-learning update and therefore does not appear in `model.params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f8da79a",
"metadata": {},
"outputs": [],
"source": [
"model = RLSSM(data=data, model_config=rlssm_config)\n",
"\n",
"assert isinstance(model, RLSSM)\n",
"print(\"Model type :\", type(model).__name__)\n",
"print(\"Participants :\", model.n_participants)\n",
"print(\"Trials/subj :\", model.n_trials)\n",
"print(\"Free parameters :\", list(model.params.keys()))\n",
"assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n",
"assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n",
"model"
]
},
{
"cell_type": "markdown",
"id": "f7f39940",
"metadata": {},
"source": [
"## 7. Inspect the Built Model\n",
"\n",
"After construction, `model.model` exposes the underlying **Bambi model** and\n",
"`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n",
"or customizing priors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0558ad4",
"metadata": {},
"outputs": [],
"source": [
"print(\"=== Bambi model ===\")\n",
"print(model.model)\n",
"\n",
"print(\"\\n=== PyMC model ===\")\n",
"print(model.pymc_model)"
]
},
{
"cell_type": "markdown",
"id": "f4e50110",
"metadata": {},
"source": [
"## 8. Sampling\n",
"\n",
"A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n",
"computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96ce3238",
"metadata": {},
"outputs": [],
"source": [
"trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n",
"\n",
"assert trace is not None\n",
"print(trace)"
]
},
{
"cell_type": "markdown",
"id": "a784a468",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook showed how to:\n",
"\n",
"1. Load a balanced-panel dataset (`rldm_data.npy`)\n",
"2. Annotate a Q-learning function with `annotate_function`\n",
"3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n",
"4. Define an `RLSSMConfig` and pass it to `RLSSM`\n",
"5. Confirm model structure (free params, Bambi / PyMC objects)\n",
"6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hssm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ nav:
- Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb
- Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb
- Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb
- RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb
- Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb
- Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb
- Custom models from onnx files: tutorials/blackbox_contribution_onnx_example.ipynb
Expand Down Expand Up @@ -91,6 +92,7 @@ plugins:
- tutorials/hssm_tutorial_workshop_2.ipynb
- tutorials/add_custom_rlssm_model.ipynb
- tutorials/rlssm_tutorial.ipynb
- tutorials/rlssm_quickstart.ipynb
- tutorials/lapse_prob_and_dist.ipynb
- tutorials/plotting.ipynb
- tutorials/scientific_workflow_hssm.ipynb
Expand Down
2 changes: 2 additions & 0 deletions src/hssm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .param import UserParam as Param
from .prior import Prior
from .register import register_model
from .rl import RLSSM
from .simulator import simulate_data
from .utils import check_data_for_rl, set_floatX

Expand All @@ -31,6 +32,7 @@

__all__ = [
"HSSM",
"RLSSM",
"Link",
"load_data",
"ModelConfig",
Expand Down
Loading