
Feature/attention pooling #33

Merged
jannisborn merged 5 commits into AI4SCR:main from
DhruvaRajwade:feature/attention-pooling
Mar 24, 2026

Conversation

@DhruvaRajwade
Contributor

Summary

  • Adds multi-head attention pooling as an alternative to mean pooling for
    combining context embeddings in the deep set (embed_cond_equal=True) path of
    ConditionalPerturbationNetwork
  • Controlled via three new config options: attention_pooling (bool),
    num_heads (int), dropout_rate (float)
  • Fully backwards compatible — attention_pooling defaults to False, so
    existing configs and checkpoints are unaffected
  • Includes dropout support with proper RNG plumbing through the trainer

Usage

In your model config under mlp:

mlp:
  attention_pooling: true
  num_heads: 4
  dropout_rate: 0.1
  embed_cond_equal: true
  # ... other existing options

Files changed

  • cmonge/models/nn.py — Attention pooling implementation, dropout layers,
    deterministic param, dropout RNG in create_train_state
  • cmonge/trainers/conditional_monge_trainer.py — Dropout key generation and
    plumbing through step_fn and loss_fn
  • cmonge/tests/models/test_attention_pooling.py — Tests for forward pass
    shape, output shape equivalence between pooling modes, and dropout
    deterministic vs stochastic behavior

Tests

Three new tests in cmonge/tests/models/test_attention_pooling.py:

  • test_attention_pooling_forward_pass — Instantiates
    ConditionalPerturbationNetwork with attention_pooling=True, runs a forward
    pass, and verifies the output shape is (batch_size, dim_data) and non-zero
    (confirming the residual connection works)

  • test_both_pooling_modes_same_output_shape — Instantiates two models,
    one with mean pooling (attention_pooling=False) and one with attention
    pooling (attention_pooling=True), and verifies that both produce the same output
    shape, ensuring they are drop-in replacements for each other

  • test_dropout_deterministic_vs_stochastic — Uses a high dropout rate
    (0.5) to verify that:

    • deterministic=True (eval mode): two calls produce identical outputs
    • deterministic=False (train mode): two calls with different dropout keys
      produce different outputs

Collaborator

@jannisborn left a comment


Nice job @DhruvaRajwade, the attention pooling implementation looks correct!
Before merging, it would be great to expand the unit tests to:

  • also test the case of overlapping context bonds, as this is the setting we use in the paper's ablation study
  • also test a case with 3 modalities

It should be easy to apply pytest.parametrize to CONTEXT_BONDS rather than hard-coding it as a class attribute.

Also, the black formatter complains; you should be able to fix it via poetry run black .
Realizing that poetry is a bit outdated, I will refactor to uv in a subsequent PR if I find some time.

Adds tests for overlapping context bonds, and a case with three modalities. context_bonds and dim_cond are now pytest.parametrize arguments instead of hardcoded class attributes.
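The parametrization described above might look roughly like this; the bond names, dim_cond values, and placeholder assertion are illustrative, not the actual test code:

```python
import pytest

# Hypothetical context-bond configurations; the real cases live in
# cmonge/tests/models/test_attention_pooling.py.
CONTEXT_BOND_CASES = [
    (["bond_a", "bond_b"], 16),            # disjoint bonds
    (["bond_a", "bond_a", "bond_b"], 16),  # overlapping bonds
    (["bond_a", "bond_b", "bond_c"], 16),  # three modalities
]

@pytest.mark.parametrize("context_bonds,dim_cond", CONTEXT_BOND_CASES)
def test_pooling_output_shape(context_bonds, dim_cond):
    # Placeholder body: a real test would build the model with these
    # bonds and compare output shapes across both pooling modes.
    assert len(context_bonds) >= 2 and dim_cond > 0
```

Each tuple becomes its own test case, so overlapping bonds and the three-modality setting are covered without duplicating test code.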
@jannisborn
Collaborator

Now isort is failing as well @DhruvaRajwade. Just run the CI commands locally before pushing, see here

run: poetry run isort --profile=black --check-only .

@DhruvaRajwade
Contributor Author

Ah, my bad, I realized I had a Vim autocmd that reformatted buffers on save using prettier. All formatting/linting/import checks should pass now.

Collaborator

@jannisborn left a comment


Nice job! Before merging, let's just bump the package version in pyproject.toml to 0.1.3.

@jannisborn jannisborn merged commit 0b14293 into AI4SCR:main Mar 24, 2026
1 check passed
@DhruvaRajwade DhruvaRajwade deleted the feature/attention-pooling branch March 27, 2026 00:18