Conversation
Adds multi-head attention pooling as an alternative to mean pooling for combining context embeddings in the deep set path. Controlled via the `attention_pooling`, `num_heads`, and `dropout_rate` config options. Includes dropout support with proper RNG plumbing through the trainer.
Tests the forward-pass output shape, output-shape equivalence between mean and attention pooling modes, and deterministic vs. stochastic dropout behavior.
jannisborn
left a comment:
Nice job @DhruvaRajwade, the attention pooling implementation looks correct!
Before merging, it would be great to expand the unit tests to:
- also test the case of overlapping context bonds, as this is the setting we use in the paper's ablation study
- also test a case with 3 modalities
It should be easy to use `pytest.parametrize` for `CONTEXT_BONDS` rather than hard-coding it as a class attribute.
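The suggested parametrization could look roughly like this (a sketch: the bond values, test body, and test name below are hypothetical placeholders — the real ones live in the existing test class):

```python
import pytest

# Hypothetical bond sets for illustration; the real CONTEXT_BONDS values
# come from the existing test class.
DISJOINT_BONDS = [(0, 1), (2, 3)]
OVERLAPPING_BONDS = [(0, 1), (1, 2)]  # bond index 1 is shared

@pytest.mark.parametrize(
    "context_bonds",
    [DISJOINT_BONDS, OVERLAPPING_BONDS],
    ids=["disjoint", "overlapping"],
)
@pytest.mark.parametrize("n_modalities", [2, 3])
def test_forward_pass(context_bonds, n_modalities):
    # Build the model for `n_modalities` and `context_bonds` here,
    # then assert on the output shape as the existing tests do.
    assert len(context_bonds) == 2
    assert n_modalities in (2, 3)
```

Each combination then runs as its own test case, so the overlapping-bond and three-modality settings get separate pass/fail reports.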
Also, the black formatter complains; you should be able to fix it via `poetry run black .`
I'm realizing that poetry is a bit outdated; I will refactor to uv in a subsequent PR if I find some time.
Adds tests for overlapping context bonds and a case with three modalities. `context_bonds` and `dim_cond` are now `pytest.parametrize` arguments instead of hard-coded class attributes.
Now isort is failing as well @DhruvaRajwade. Just run the CI commands locally before pushing, see here.
Ah, my bad, I realized I had a Vim autocmd that reformatted the buffers on save using …
jannisborn
left a comment:
Nice job! Before merging, let's just increase the package version in `pyproject.toml` to 0.1.3.
Summary
Adds multi-head attention pooling as an alternative to mean pooling for combining context embeddings in the deep set (`embed_cond_equal=True`) path of `ConditionalPerturbationNetwork`.

- New config options: `attention_pooling` (bool), `num_heads` (int), `dropout_rate` (float)
- `attention_pooling` defaults to `False`, so existing configs and checkpoints are unaffected
Usage
In your model config, under `mlp`:
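A sketch of what that config could look like (key names are taken from the option list above; the nesting and surrounding keys are illustrative, not copied from a real cmonge config):

```yaml
model:
  mlp:
    embed_cond_equal: true      # use the deep set path
    attention_pooling: true     # default: false (mean pooling)
    num_heads: 4
    dropout_rate: 0.1
```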
Files changed

- `deterministic` param and dropout RNG in `create_train_state`
- RNG plumbing through `step_fn` and `loss_fn`
- Tests for forward-pass output shape, output-shape equivalence between pooling modes, and deterministic vs. stochastic dropout behavior
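The RNG plumbing can be sketched in plain JAX as follows. This is an illustration of the pattern, not the cmonge code; the real trainer presumably passes the key to Flax via `rngs={"dropout": ...}` in `apply`, with `create_train_state` owning the initial key.

```python
import jax
import jax.numpy as jnp

def dropout(x, rate, deterministic, rng=None):
    # Eval mode (deterministic=True) is the identity. Train mode needs an
    # explicit PRNG key, which is why the trainer must thread a dropout
    # rng through step_fn down to the model call.
    if deterministic or rate == 0.0:
        return x
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def train_step(dropout_rng, x):
    # Split so each step consumes a fresh dropout key and passes the
    # remainder forward, keeping training reproducible from one seed.
    new_rng, step_key = jax.random.split(dropout_rng)
    out = dropout(x, rate=0.5, deterministic=False, rng=step_key)
    return new_rng, out
```

This mirrors the behavior the tests below check: eval-mode calls are identical, while train-mode calls with different keys draw different masks.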
Tests
Three new tests in `cmonge/tests/models/test_attention_pooling.py`:

- `test_attention_pooling_forward_pass`: Instantiates `ConditionalPerturbationNetwork` with `attention_pooling=True`, runs a forward pass, and verifies that the output shape is `(batch_size, dim_data)` and the output is non-zero (confirming the residual connection works)
- `test_both_pooling_modes_same_output_shape`: Instantiates two models, one with mean pooling (`attention_pooling=False`) and one with attention pooling (`attention_pooling=True`), and verifies that both produce the same output shape, ensuring they are drop-in replacements for each other
- `test_dropout_deterministic_vs_stochastic`: Uses a high dropout rate (0.5) to verify that:
  - with `deterministic=True` (eval mode), two calls produce identical outputs
  - with `deterministic=False` (train mode), two calls with different dropout keys produce different outputs