From d3ee9c33e4e7db77550b2118d9baa0d49276a170 Mon Sep 17 00:00:00 2001
From: Igor Morozov
Date: Tue, 31 Mar 2026 11:00:28 +0300
Subject: [PATCH] fix: resolve NameError when MultiHeadAttention is called with
 w_init=None

`w_init_scale` was referenced but never defined, causing a NameError
whenever MultiHeadAttention is instantiated with the default w_init=None.
Fix replaces the undefined variable with the literal 1.0, which matches
the upstream haiku VarianceScaling default.

Adds two focused regression tests:
- test_w_init_none_does_not_raise: exercises the formerly-broken code path
- test_w_init_explicit_still_works: confirms explicit w_init is unaffected

Fixes deepmodeling/CrystalFormer#11
---
 crystalformer/src/attention.py |  2 +-
 tests/test_attention.py        | 55 ++++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_attention.py

diff --git a/crystalformer/src/attention.py b/crystalformer/src/attention.py
index bc07b30..9c18d8a 100644
--- a/crystalformer/src/attention.py
+++ b/crystalformer/src/attention.py
@@ -84,7 +84,7 @@ def __init__(
         self.dropout_rate = dropout_rate
 
         if w_init is None:
-            w_init = hk.initializers.VarianceScaling(w_init_scale)
+            w_init = hk.initializers.VarianceScaling(1.0)
         self.w_init = w_init
         self.with_bias = with_bias
         self.b_init = b_init
diff --git a/tests/test_attention.py b/tests/test_attention.py
new file mode 100644
index 0000000..64f8856
--- /dev/null
+++ b/tests/test_attention.py
@@ -0,0 +1,55 @@
+"""Tests for MultiHeadAttention -- focuses on the w_init=None fix.
+
+Before the fix, calling MultiHeadAttention with w_init=None (the default)
+raised NameError: name 'w_init_scale' is not defined.
+"""
+from config import *
+
+from crystalformer.src.attention import MultiHeadAttention
+
+
+def test_w_init_none_does_not_raise():
+    """w_init=None (the default) must not raise NameError.
+
+    Regression test for: NameError: name 'w_init_scale' is not defined.
+    The fix replaces the undefined variable with the literal value 1.0,
+    matching the upstream haiku default for VarianceScaling.
+    """
+    def fn(q, k, v):
+        mha = MultiHeadAttention(
+            num_heads=2,
+            key_size=8,
+            model_size=16,
+            w_init=None,  # default -- was broken before fix
+        )
+        return mha(q, k, v)
+
+    f = hk.without_apply_rng(hk.transform(fn))
+    key = jax.random.PRNGKey(0)
+    x = jax.random.normal(key, (4, 16))
+    params = f.init(key, x, x, x)
+    out = f.apply(params, x, x, x)
+
+    assert out.shape == (4, 16)
+    assert jnp.isfinite(out).all(), "w_init=None path produces NaN/Inf"
+
+
+def test_w_init_explicit_still_works():
+    """Explicit w_init continues to work after the fix."""
+    def fn(q, k, v):
+        mha = MultiHeadAttention(
+            num_heads=2,
+            key_size=8,
+            model_size=16,
+            w_init=hk.initializers.VarianceScaling(1.0),
+        )
+        return mha(q, k, v)
+
+    f = hk.without_apply_rng(hk.transform(fn))
+    key = jax.random.PRNGKey(1)
+    x = jax.random.normal(key, (4, 16))
+    params = f.init(key, x, x, x)
+    out = f.apply(params, x, x, x)
+
+    assert out.shape == (4, 16)
+    assert jnp.isfinite(out).all()