Rlssm class make model dist #915
Open
cpaniaguam wants to merge 107 commits into cp-main-sb from rlssm-class-make-model-dist
Commits (107)
20ddc4c
Add ssm_logp_func to RLSSMConfig and update validation tests
cpaniaguam d97dcee
Add RLSSM model and utilities for reinforcement learning integration
cpaniaguam a6a0238
Refactor RLSSM parameter handling and add custom prefix resolution fo…
cpaniaguam d880977
Add tests for RLSSM class covering initialization, validation, and mo…
cpaniaguam bef8d6c
Refactor loglik handling in RLSSM to improve type safety with casting
cpaniaguam 3981ef6
Add NaN value check for participant column in validate_balanced_panel…
cpaniaguam d84a800
Add validation for ssm_logp_func in RLSSMConfig to ensure it is calla…
cpaniaguam 15ad6e2
Add exclude rules for ruff and mypy hooks to skip tests directory
cpaniaguam 262ec07
Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is…
cpaniaguam 381275a
Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM
cpaniaguam 0e9ba42
Reject missing data and deadline handling in RLSSM initialization to …
cpaniaguam 4f28c68
Add tests to validate error handling for missing data and deadline in…
cpaniaguam 5e9f566
Refactor path handling for loading RLDM fixture dataset in tests
cpaniaguam 67ac2ce
Add fixture to set floatX to float32 for module tests
cpaniaguam e1c05df
Ensure params_is_trialwise aligns with list_params in RLSSM initializ…
cpaniaguam 564232b
Clarify comments on default_priors in ModelConfig and remove unnecess…
cpaniaguam bafc037
Update RLSSM to use to_numpy(copy=True) for extra_fields and add test…
cpaniaguam ba358a4
Refactor parameter name resolution in RLSSM to handle underscores cor…
cpaniaguam 0bfa755
Add test for _get_prefix method in RLSSM to ensure token-based matching
cpaniaguam 5b8a16a
Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter a…
cpaniaguam f69f2b6
Fix comment in test_rlssm.py to clarify output shape of log-likelihoo…
cpaniaguam bad943d
Update RLSSMConfig documentation to mark description as required
cpaniaguam 241aad2
Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig ini…
cpaniaguam ca3816d
Add dummy ssm_logp_func to tests and validate its presence in RLSSMCo…
cpaniaguam 827025c
Remove unused logging import from rlssm.py
cpaniaguam 292d6f0
Remove redundant exclude rule for ruff-format in pre-commit configura…
cpaniaguam 3b1aaf4
Add to_model_config method to RLSSMConfig for ModelConfig conversion
cpaniaguam 0678c45
Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig an…
cpaniaguam 26a9336
Integrate Config and RLSSMConfig into HSSM and RLSSM classes for impr…
cpaniaguam cd660da
Update choices type from list to tuple for consistency in BaseModelCo…
cpaniaguam db308c6
Update choices type from list to tuple in test_constructor for consis…
cpaniaguam 9e5e8be
Add deprecation warnings for model_config attributes in HSSMBase
cpaniaguam 74acc67
Refactor HSSMBase to support BaseModelConfig and improve model_config…
cpaniaguam 2acc19d
Add model configuration building methods to BaseModelConfig and Confi…
cpaniaguam 97bf90f
Refactor model configuration handling in HSSMBase and HSSM classes to…
cpaniaguam f75f5e4
Add properties to BaseModelConfig for parameter and extra field counts
cpaniaguam b9b4839
Refactor RLSSM attributes to use public naming convention for configu…
cpaniaguam 821392f
Refactor test_rlssm_panel_attrs to use public attributes for particip…
cpaniaguam 07169c9
Refactor HSSMBase to streamline model configuration handling and upda…
cpaniaguam 0395ec2
Refactor BaseModelConfig and RLSSMConfig by removing unused abstract …
cpaniaguam 626301d
Refactor HSSM class to remove Config inheritance and add initializati…
cpaniaguam 6c13443
Refactor RLSSM class to remove RLSSMConfig inheritance and streamline…
cpaniaguam 66296f0
Refactor Config and RLSSMConfig classes to use concrete types in meth…
cpaniaguam 801f235
Update Config class parameter types for choices to improve type safety
cpaniaguam 607874c
Update choices method to accept a tuple for model_config.choices
cpaniaguam 9e25a32
Add tests for model configuration handling and choices logic in Config
cpaniaguam 2b2d66b
Enhance HSSMBase initialization with safe default for constructor arg…
cpaniaguam d5b9d80
Update model_config validation to check for non-null choices
cpaniaguam 9bf18ea
Refactor HSSM distribution method to use typed model_config attribute…
cpaniaguam 9af3e95
Update test cases to use tuples for choices in model configuration
cpaniaguam c2e09d9
Refactor RLSSM to utilize model_config for list_params and loglik, en…
cpaniaguam 1aa19f2
Fix typo in comment regarding model_config choices validation
cpaniaguam 0ea0998
Refactor RLSSM tests to access model configuration attributes directl…
cpaniaguam 8f526f4
Update attribute comparison in compare_hssm_class_attributes to use m…
cpaniaguam 5dd68a5
Update test assertions to access model configuration attributes directly
cpaniaguam 7054ccd
Refactor model configuration normalization to streamline choices hand…
cpaniaguam 5e816bc
Refactor choices handling in Config class to improve clarity and logging
cpaniaguam 9f6a7ef
Refactor _normalize_model_config_with_choices to improve input handli…
cpaniaguam 49415ab
Refactor likelihood callable construction to simplify logic and enhan…
cpaniaguam 4452f36
Refactor _make_model_distribution to utilize model_config for loglik …
cpaniaguam c34e562
Fix formatting in HSSM class for consistency in likelihood callable p…
cpaniaguam 3e86974
Fix formatting in HSSM class for consistency in likelihood callable p…
cpaniaguam 4a5aefc
Refactor HSSM class to use typed model_config attributes directly and…
cpaniaguam cdc7763
Restore make_model_dist in HSSM
cpaniaguam e3cbcb7
Remove deprecated properties and methods from HSSMBase class
cpaniaguam 7e481e0
Enhance HSSMBase class to prevent overwriting _init_args if already s…
cpaniaguam 3432bca
Clarify model_config parameter documentation in HSSMBase class to spe…
cpaniaguam 31bd6f1
Enhance HSSMBase class documentation to clarify filtering of internal…
cpaniaguam 296810b
Update model_config parameter documentation in HSSM class to support …
cpaniaguam 95779bc
Add test to validate external model config fallback in _build_model_c…
cpaniaguam 37ea9be
Update sampling parameters in test_rlssm_sample_smoke for speed
cpaniaguam 9a37dd8
Add RLSSM quickstart notebook for model instantiation and sampling de…
cpaniaguam 43ec652
Add RLSSM Quickstart tutorial to navigation and plugins
cpaniaguam 7f1e6ff
Remove redundant next steps and streamline summary in RLSSM quickstar…
cpaniaguam aec1531
Refactor RLSSMConfig methods to simplify parameter handling and remov…
cpaniaguam a8cd51d
Fix handling of list_params in HSSMBase to ensure proper conversion f…
cpaniaguam 9c22e26
Refactor RLSSM to inject model configuration directly, removing unnec…
cpaniaguam 5658834
Update TestRLSSMConfigDefaults to reflect None for default parameters…
cpaniaguam 7a294af
Refactor RLSSM to inject loglik and backend directly into a new RLSSM…
cpaniaguam fd99efb
Add validation for missing bounds in RLSSMConfig parameters
cpaniaguam bc0f7ca
Fix RLSSM to use model_config for ssm_logp_func and update test cases…
cpaniaguam b075e4f
Enhance RLSSM tests to align params_is_trialwise with list_params and…
cpaniaguam 27d505e
Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError
cpaniaguam ce8e187
Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedEr…
cpaniaguam 7c7fd32
Inject JAX backend into RLSSMConfig during initialization
cpaniaguam e604406
Refactor RLSSM class to use model_config instead of rlssm_config for …
cpaniaguam 582a6fe
Merge branch '930-pass-configs-via-dependency-injection-into-model-cl…
cpaniaguam a3898d7
Fix merge conflicts with base branch
cpaniaguam 4d99410
Remove commented out lines
cpaniaguam f04f47e
Remove RLSSMConfig import from __init__.py
cpaniaguam 11115af
Reorganize import statements by moving RLSSMConfig import to the corr…
cpaniaguam 6a9384f
Move RLSSMConfig import to the correct module in test files
cpaniaguam 0285f04
Update docstring in __init__.py and exports
cpaniaguam 5807a71
Remove RLSSMConfig class and its associated methods from config.py
cpaniaguam 4bf67ea
Move RLSSMConfig class to hssm.rl module
cpaniaguam 5d74bfe
Refactor config.py to remove RLSSM-specific defaults and unify observ…
cpaniaguam ef000cd
Fix formatting of error messages in TestRLSSMConfigValidation for con…
cpaniaguam 91b1098
Enhance validation in RLSSMConfig for ssm_logp_func attributes
cpaniaguam c3a4f52
Add validation test for non-callable values in ssm_logp_func.computed
cpaniaguam 692dc5d
Rename 'learning_process_loglik_kind' to 'learning_process_kind' in R…
cpaniaguam 7cf8bca
Simplify response and list_params assignment in HSSMBase by removing …
cpaniaguam c46f923
Revert "Simplify response and list_params assignment in HSSMBase by r…
cpaniaguam fac838e
Refactor RLSSMConfig to dynamically retrieve required fields for vali…
cpaniaguam 3084df4
Update RLSSMConfig to handle field exceptions in from_rlssm_dict method
cpaniaguam babee92
Merge pull request #936 from lnccbrown/inject-RLSSMConfig-directly-in…
cpaniaguam b0e2179
Merge pull request #931 from lnccbrown/930-pass-configs-via-dependenc…
cpaniaguam 1233cd7
Fix import path for RLSSM and RLSSMConfig; correct learning_process_l…
cpaniaguam
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,351 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "1b9b429d", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n", | ||
| "\n", | ||
| "This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n", | ||
| "\n", | ||
| "1. **Load** a balanced-panel two-armed bandit dataset\n", | ||
| "2. **Define** an annotated learning function and the angle SSM log-likelihood\n", | ||
| "3. **Configure** and **instantiate** an `RLSSM` model\n", | ||
| "4. **Inspect** the built Bambi / PyMC model\n", | ||
| "5. **Run** a minimal 2-draw sampling smoke test\n", | ||
| "\n", | ||
| "For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n", | ||
| "- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n", | ||
| "- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "bf38d7f7", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 1. Imports and Setup" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "6d764731", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from pathlib import Path\n", | ||
| "\n", | ||
| "import jax.numpy as jnp\n", | ||
| "import numpy as np\n", | ||
| "import pandas as pd\n", | ||
| "\n", | ||
| "import hssm\n", | ||
| "from hssm.rl import RLSSM, RLSSMConfig\n", | ||
| "from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n", | ||
| "from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n", | ||
| "from hssm.utils import annotate_function\n", | ||
| "\n", | ||
| "# RLSSM requires float32 throughout (JAX default).\n", | ||
| "hssm.set_floatX(\"float32\", update_jax=True)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "df12303f", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 2. Load the Dataset\n", | ||
| "\n", | ||
| "We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n", | ||
| "It is a **balanced panel**: every participant has the same number of trials. \n", | ||
| "Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n", | ||
| "\n", | ||
| "> **Note:** You can also generate data with\n", | ||
| "> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n", | ||
| "> See `rlssm_tutorial.ipynb` for an example." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "c2ef5f6e", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Path relative to docs/tutorials/ when running inside the HSSM repo.\n", | ||
| "_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n", | ||
| "raw = np.load(_fixture_path, allow_pickle=True).item()\n", | ||
| "data = pd.DataFrame(raw[\"data\"])\n", | ||
| "\n", | ||
| "n_participants = data[\"participant_id\"].nunique()\n", | ||
| "n_trials = len(data) // n_participants\n", | ||
| "\n", | ||
| "print(data.head())\n", | ||
| "print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "8c310290", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 3. Define the Learning Process\n", | ||
| "\n", | ||
| "The RL learning process is a JAX function that, given a subject's trial sequence, computes\n", | ||
| "the trial-wise drift rate `v` via a Q-learning update rule. \n", | ||
| "\n", | ||
| "`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n", | ||
| "that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n", | ||
| "decision process.\n", | ||
| "\n", | ||
| "- **inputs** — columns that the function reads (free parameters + data columns)\n", | ||
| "- **outputs** — what the function produces (here: `v`, the drift rate)\n", | ||
| "\n", | ||
| "Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n", | ||
| "Rescorla-Wagner Q-learning update for a two-armed bandit task." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "bbcea122", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "compute_v_annotated = annotate_function(\n", | ||
| "    inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n", | ||
| "    outputs=[\"v\"],\n", | ||
| ")(compute_v_subject_wise)\n", | ||
| "\n", | ||
| "print(\"Learning function inputs :\", compute_v_annotated.inputs)\n", | ||
| "print(\"Learning function outputs:\", compute_v_annotated.outputs)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "7a03305a", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 4. Define the Decision (SSM) Log-Likelihood\n", | ||
| "\n", | ||
| "The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n", | ||
| "`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n", | ||
| "2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n", | ||
| "per-trial log-probabilities.\n", | ||
| "\n", | ||
| "We then annotate that callable so the builder knows:\n", | ||
| "- which columns the matrix contains (`inputs`)\n", | ||
| "- that `v` is *computed* by the learning function rather than sampled as a free parameter (`computed`)\n", | ||
| "\n", | ||
| "The ONNX file is loaded from the local test fixture when running inside the HSSM\n", | ||
| "repository; otherwise it is downloaded from the Hugging Face Hub (`franklab/HSSM`)." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "60bbc036", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Use the local fixture when available; otherwise fall back to a Hugging Face download.\n", | ||
| "_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n", | ||
| "_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n", | ||
| "\n", | ||
| "_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n", | ||
| "\n", | ||
| "angle_logp_func = annotate_function(\n", | ||
| "    inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n", | ||
| "    outputs=[\"logp\"],\n", | ||
| "    computed={\"v\": compute_v_annotated},\n", | ||
| ")(_angle_logp_jax)\n", | ||
| "\n", | ||
| "print(\"SSM logp inputs :\", angle_logp_func.inputs)\n", | ||
| "print(\"SSM logp outputs:\", angle_logp_func.outputs)\n", | ||
| "print(\"Computed deps :\", list(angle_logp_func.computed.keys()))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "cf8f5b63", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 5. Configure the Model with `RLSSMConfig`\n", | ||
| "\n", | ||
| "`RLSSMConfig` collects the information the `RLSSM` class needs. The main fields are:\n", | ||
| "\n", | ||
| "| Field | Purpose |\n", | ||
| "|-------|---------|\n", | ||
| "| `model_name` | Identifier string for the configuration |\n", | ||
| "| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n", | ||
| "| `list_params` | Ordered list of *free* parameters to sample |\n", | ||
| "| `params_default` | Starting / default values for each parameter |\n", | ||
| "| `bounds` | Prior bounds for each parameter |\n", | ||
| "| `learning_process` | Dict mapping computed param name → annotated learning function |\n", | ||
| "| `extra_fields` | Extra data columns required by the learning function |\n", | ||
| "| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "4beba1bc", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "rlssm_config = RLSSMConfig(\n", | ||
| "    model_name=\"rlssm_angle_quickstart\",\n", | ||
| "    loglik_kind=\"approx_differentiable\",\n", | ||
| "    decision_process=\"angle\",\n", | ||
| "    decision_process_loglik_kind=\"approx_differentiable\",\n", | ||
| "    learning_process_kind=\"blackbox\",\n", | ||
| "    list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n", | ||
| "    params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n", | ||
| "    bounds={\n", | ||
| "        \"rl_alpha\": (0.0, 1.0),\n", | ||
| "        \"scaler\": (0.0, 10.0),\n", | ||
| "        \"a\": (0.1, 3.0),\n", | ||
| "        \"theta\": (-0.1, 0.1),\n", | ||
| "        \"t\": (0.001, 1.0),\n", | ||
| "        \"z\": (0.1, 0.9),\n", | ||
| "    },\n", | ||
| "    learning_process={\"v\": compute_v_annotated},\n", | ||
| "    response=[\"rt\", \"response\"],\n", | ||
| "    choices=[0, 1],\n", | ||
| "    extra_fields=[\"feedback\"],\n", | ||
| "    ssm_logp_func=angle_logp_func,\n", | ||
| ")\n", | ||
| "\n", | ||
| "print(\"Model name :\", rlssm_config.model_name)\n", | ||
| "print(\"Free params :\", rlssm_config.list_params)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "924ee4c7", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 6. Instantiate the `RLSSM` Model\n", | ||
| "\n", | ||
| "Passing `data` and `model_config=rlssm_config` to `RLSSM`:\n", | ||
| "\n", | ||
| "- validates the balanced-panel requirement\n", | ||
| "- builds a differentiable PyTensor Op that chains the RL learning step and the\n", | ||
| " angle log-likelihood\n", | ||
| "- constructs the Bambi / PyMC model internally\n", | ||
| "\n", | ||
| "Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n", | ||
| "the Op by the Q-learning update and therefore does not appear in `model.params`." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "1f8da79a", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "model = RLSSM(data=data, model_config=rlssm_config)\n", | ||
| "\n", | ||
| "assert isinstance(model, RLSSM)\n", | ||
| "print(\"Model type :\", type(model).__name__)\n", | ||
| "print(\"Participants :\", model.n_participants)\n", | ||
| "print(\"Trials/subj :\", model.n_trials)\n", | ||
| "print(\"Free parameters :\", list(model.params.keys()))\n", | ||
| "assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n", | ||
| "assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n", | ||
| "model" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "f7f39940", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 7. Inspect the Built Model\n", | ||
| "\n", | ||
| "After construction, `model.model` exposes the underlying **Bambi model** and\n", | ||
| "`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n", | ||
| "or customizing priors." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b0558ad4", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "print(\"=== Bambi model ===\")\n", | ||
| "print(model.model)\n", | ||
| "\n", | ||
| "print(\"\\n=== PyMC model ===\")\n", | ||
| "print(model.pymc_model)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "f4e50110", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 8. Sampling\n", | ||
| "\n", | ||
| "A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n", | ||
| "computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "96ce3238", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n", | ||
| "\n", | ||
| "assert trace is not None\n", | ||
| "print(trace)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "a784a468", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Summary\n", | ||
| "\n", | ||
| "This notebook showed how to:\n", | ||
| "\n", | ||
| "1. Load a balanced-panel dataset (`rldm_data.npy`)\n", | ||
| "2. Annotate a Q-learning function with `annotate_function`\n", | ||
| "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n", | ||
| "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n", | ||
| "5. Confirm model structure (free params, Bambi / PyMC objects)\n", | ||
| "6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "hssm", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.13.1" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
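
Section 3 of the notebook describes the learning process as a Rescorla-Wagner Q-learning update that turns one participant's trial sequence into a trial-wise drift rate `v`. The following is a minimal pure-Python sketch of that idea, not the implementation of `compute_v_subject_wise`. The function name `compute_v_sketch`, the initial Q-value of 0.5, and the choice of `v = scaler * (Q[1] - Q[0])` are assumptions made for illustration; the real function is written in JAX and may differ in these details.

```python
from typing import Sequence


def compute_v_sketch(
    rl_alpha: float,
    scaler: float,
    responses: Sequence[int],
    feedback: Sequence[float],
    q_init: float = 0.5,  # assumed starting Q-value for both arms
) -> list[float]:
    """Return illustrative trial-wise drift rates for one participant."""
    q = [q_init, q_init]  # one Q-value per arm of the two-armed bandit
    v = []
    for choice, reward in zip(responses, feedback):
        # The drift rate uses the Q-values held *before* this trial's feedback.
        v.append(scaler * (q[1] - q[0]))
        # Rescorla-Wagner update, applied to the chosen arm only.
        q[choice] += rl_alpha * (reward - q[choice])
    return v


v = compute_v_sketch(
    rl_alpha=0.5,
    scaler=2.0,
    responses=[1, 1, 0],
    feedback=[1.0, 1.0, 0.0],
)
print(v)
```

In this sketch `rl_alpha` and `scaler` play the role of the free parameters listed in `list_params`, while `responses` and `feedback` correspond to the data columns named in the `inputs` annotation. The output has one value per trial, which is why `v` can fill a column of the decision-process input matrix instead of being sampled.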