Skip to content

Add block_marginal_sqrt_cov and jit support#177

Merged
SamDuffield merged 3 commits intomainfrom
remove-assertions
Feb 7, 2026
Merged

Add block_marginal_sqrt_cov and jit support#177
SamDuffield merged 3 commits intomainfrom
remove-assertions

Conversation

@SamDuffield
Copy link
Contributor

@SamDuffield SamDuffield requested a review from Sahel13 February 5, 2026 16:31
@Sahel13
Copy link
Collaborator

Sahel13 commented Feb 6, 2026

This is not necessary, the shape of chol_cov is static, and we cannot jit start and end anyway because the output shape depends on it.

This test for the old code (with assertions) runs fine (we should probably add this)

import chex
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from jax import random

from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov


@pytest.fixture(scope="module", autouse=True)
def config():
    jax.config.update("jax_enable_x64", True)
    yield
    jax.config.update("jax_enable_x64", False)


class TestMarginalSqrtCov(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.product(
        seed=[0, 42, 99],
        block=[
            (6, 0, 3),  # top-left block
            (6, 3, 6),  # bottom-right block
            (8, 2, 5),  # middle block
            (10, 1, 9),  # large block
        ],
    )
    def test_marginal_sqrt_cov(self, seed, block):
        n, start, end = block
        key = random.key(seed)

        # Random lower-triangular joint square root
        L = jnp.tril(random.normal(key, (n, n)))

        # Extract marginal square root
        B = self.variant(
            marginal_sqrt_cov,
            static_argnames=("start", "end"),
        )(L, start, end)

        # Expected marginal covariance block
        Sigma = L @ L.T
        Sigma_block = Sigma[start:end, start:end]

        # Check B is lower triangular
        assert jnp.allclose(B, jnp.tril(B))

        # Check B B^T reproduces marginal covariance
        assert jnp.allclose(B @ B.T, Sigma_block)

@SamDuffield
Copy link
Contributor Author

I've pulled changes made in #178 and addressed here the comments. Let's just use this PR for fixing marginal_sqrt_cov. My bad for overcomplicating 😅

@SamDuffield SamDuffield changed the title Remove assertions from marginal_sqrt_cov Add block_marginal_sqrt_cov and jit support Feb 6, 2026
@SamDuffield SamDuffield merged commit 60df149 into main Feb 7, 2026
2 checks passed
@SamDuffield SamDuffield deleted the remove-assertions branch February 7, 2026 13:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants