fix: resolve NameError when MultiHeadAttention is called with w_init=None#12

Open
exopoiesis wants to merge 1 commit into deepmodeling:main from exopoiesis:fix/deps-modernization

Conversation


@exopoiesis commented Mar 30, 2026

Problem

MultiHeadAttention.__init__ references w_init_scale in the fallback branch:

if w_init is None:
    w_init = hk.initializers.VarianceScaling(w_init_scale)  # NameError

w_init_scale is never defined in this scope, so any call with the default
w_init=None raises NameError: name 'w_init_scale' is not defined.
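The failure mode can be reproduced without haiku. The sketch below is only an illustration: `VarianceScaling` and `BuggyAttention` are hypothetical stand-ins, not the project's classes, and they mirror just the fallback branch quoted above. Because Python resolves names at run time, the class definition succeeds and the error appears only when the default path executes.

```python
class VarianceScaling:
    """Hypothetical stand-in for hk.initializers.VarianceScaling."""

    def __init__(self, scale=1.0):
        self.scale = scale


class BuggyAttention:
    """Hypothetical stand-in that mirrors only the broken fallback branch."""

    def __init__(self, num_heads, key_size, w_init=None):
        if w_init is None:
            # w_init_scale is not a parameter, a local, or a global here,
            # so this line raises NameError when it runs.
            w_init = VarianceScaling(w_init_scale)  # noqa: F821
        self.w_init = w_init


def try_build(**kwargs):
    """Return the NameError message, or None if construction succeeds."""
    try:
        BuggyAttention(num_heads=4, key_size=16, **kwargs)
    except NameError as err:
        return str(err)
    return None


default_error = try_build()                              # default path
explicit_error = try_build(w_init=VarianceScaling(2.0))  # explicit w_init

print(default_error)   # name 'w_init_scale' is not defined
print(explicit_error)  # None
```

Passing an explicit `w_init` skips the broken branch entirely, which would explain why the bug can go unnoticed in code that always supplies an initializer.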

Fix

Replace the undefined variable with the literal 1.0, which matches the
upstream haiku VarianceScaling default:

if w_init is None:
    w_init = hk.initializers.VarianceScaling(1.0)

Tests

Two focused regression tests in tests/test_attention.py:

  • test_w_init_none_does_not_raise — exercises the formerly-broken default path
  • test_w_init_explicit_still_works — confirms explicit w_init is unaffected
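As a rough sketch of what these two tests check, the following uses hypothetical stand-ins (`VarianceScaling`, `FixedAttention`) rather than the real module, whose constructor signature is not shown in this PR. The actual tests in tests/test_attention.py may look different; in particular, a haiku module normally has to be constructed inside `hk.transform`, which this stand-in does not model.

```python
class VarianceScaling:
    """Hypothetical stand-in for hk.initializers.VarianceScaling."""

    def __init__(self, scale=1.0):
        self.scale = scale


class FixedAttention:
    """Hypothetical stand-in that mirrors only the fixed fallback branch."""

    def __init__(self, num_heads, key_size, w_init=None):
        if w_init is None:
            # The fix: use the literal 1.0 instead of the undefined name.
            w_init = VarianceScaling(1.0)
        self.w_init = w_init


def test_w_init_none_does_not_raise():
    # Default path: construction must succeed and fall back to scale 1.0.
    layer = FixedAttention(num_heads=4, key_size=16)
    assert isinstance(layer.w_init, VarianceScaling)
    assert layer.w_init.scale == 1.0


def test_w_init_explicit_still_works():
    # Explicit path: the caller's initializer must be kept as-is.
    custom = VarianceScaling(0.5)
    layer = FixedAttention(num_heads=4, key_size=16, w_init=custom)
    assert layer.w_init is custom


# Run directly so the sketch works without a test runner.
test_w_init_none_does_not_raise()
test_w_init_explicit_still_works()
```

Checking identity (`is custom`) in the second test guards against a fix that accidentally overwrites a caller-supplied initializer.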

Scope

This PR is intentionally minimal: only crystalformer/src/attention.py and
tests/test_attention.py are changed. No tree API or extension module changes.

Fixes #11

@exopoiesis changed the title from "fix: modernize deprecated JAX APIs and relax pinned dependencies" to "fix: deprecated JAX APIs and undefined w_init_scale" Mar 30, 2026
@zdcao121
Collaborator

Hi @exopoiesis, thanks for the PR and the detailed tests.

Two notes from our side:

  1. jax.tree_util.tree_map and jax.tree.map() are equivalent, and jax.tree_util.tree_map also works on older JAX versions, so this replacement is not strictly required for us right now.
  2. The extension module is deprecated, so changes there are lower priority.

We still appreciate the w_init=None fix and the overall maintenance effort.
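The equivalence mentioned in the first note can be checked with a small sketch. It assumes a JAX version recent enough to provide the `jax.tree` namespace; the `params` dictionary and the `double` function are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative parameter pytree; the names and shapes are arbitrary.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}


def double(x):
    """Leaf-wise function applied to every array in the pytree."""
    return 2.0 * x


via_tree_util = jax.tree_util.tree_map(double, params)  # long-standing spelling
via_tree = jax.tree.map(double, params)                 # newer spelling

# Compare the two results leaf by leaf.
same = all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(
        jax.tree_util.tree_leaves(via_tree_util),
        jax.tree_util.tree_leaves(via_tree),
    )
)
print(same)
```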

…None

`w_init_scale` was referenced but never defined, causing a NameError
whenever MultiHeadAttention is instantiated with the default w_init=None.
Fix replaces the undefined variable with the literal 1.0, which matches
the upstream haiku VarianceScaling default.

Adds two focused regression tests:
- test_w_init_none_does_not_raise: exercises the formerly-broken code path
- test_w_init_explicit_still_works: confirms explicit w_init is unaffected

Fixes deepmodeling#11
@exopoiesis force-pushed the fix/deps-modernization branch from d5a8866 to d3ee9c3 on March 31, 2026 08:00
@exopoiesis changed the title from "fix: deprecated JAX APIs and undefined w_init_scale" to "fix: resolve NameError when MultiHeadAttention is called with w_init=None" Mar 31, 2026
@exopoiesis
Author

Slimmed down to just the w_init fix and its regression tests, as suggested. The tree_map and extension-module changes have been dropped.


Development

Successfully merging this pull request may close these issues.

Dependency modernization: deprecated JAX APIs + future Flax NNX migration
