It seems to me that there might be a mistake in the way the noised state is computed in the current implementation. Specifically, in `sam/sam_jax/training_utils`:

In line 537, `forward_and_loss`, which includes the L2 regularization, is used to compute `grad`. This `grad` is then passed to `dual_vector(grad)` in line 546.
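As I read it, the current update is roughly the following. Here $L_{\text{CE}}$ is the cross-entropy, $\rho$ the neighborhood size, and the L2 term is written as $\tfrac{\lambda}{2}\lVert w\rVert_2^2$; that scaling is an assumption for illustration and is not taken from the code:

$$
\hat\epsilon(w) = \rho\,\frac{\nabla_w\big(L_{\text{CE}}(w) + \tfrac{\lambda}{2}\lVert w\rVert_2^2\big)}{\big\lVert \nabla_w\big(L_{\text{CE}}(w) + \tfrac{\lambda}{2}\lVert w\rVert_2^2\big)\big\rVert_2},
\qquad
g = \nabla_w\big(L_{\text{CE}}(w) + \tfrac{\lambda}{2}\lVert w\rVert_2^2\big)\Big|_{w+\hat\epsilon(w)}
$$

so the L2 term both shapes the perturbation $\hat\epsilon$ and is differentiated at the noised weights.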
I don't think this matches the original SAM paper. The L2 regularization shouldn't contribute to noising the state, as it does now; only the cross-entropy loss should. The gradient of the L2 term should instead be computed separately at the clean state and added to the SAM gradient.
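To make the difference concrete, here is a minimal JAX sketch on a hypothetical toy linear classifier. It is not the repo's code: `cross_entropy_loss`, `l2_loss`, `ascent_direction`, `sam_grad_combined` and `sam_grad_separate_l2` are all illustrative names, `ascent_direction` is only assumed to play the role of `dual_vector` (the p=2 case, $\rho\,g/\lVert g\rVert_2$), and the L2 term is applied to every parameter for simplicity. `sam_grad_combined` mimics the behavior described above; `sam_grad_separate_l2` is the proposed fix.

```python
import jax
import jax.numpy as jnp


def cross_entropy_loss(params, x, y):
    # Hypothetical toy model: logits = x @ w + b.
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def l2_loss(params, weight_decay):
    # (weight_decay / 2) * ||w||^2 over all leaves (assumed scaling).
    return 0.5 * weight_decay * sum(
        jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))


def ascent_direction(grad, rho):
    # Assumed stand-in for dual_vector: rho * g / ||g||_2 (global norm).
    gnorm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grad)))
    return jax.tree_util.tree_map(lambda g: rho * g / (gnorm + 1e-12), grad)


def add_trees(a, b):
    return jax.tree_util.tree_map(jnp.add, a, b)


def total_loss(params, x, y, weight_decay):
    return cross_entropy_loss(params, x, y) + l2_loss(params, weight_decay)


def sam_grad_combined(params, x, y, rho, weight_decay):
    # Behavior described in the issue: the L2 term is part of the loss used
    # both for the perturbation and for the gradient at the noised weights.
    grad = jax.grad(total_loss)(params, x, y, weight_decay)
    noised = add_trees(params, ascent_direction(grad, rho))
    return jax.grad(total_loss)(noised, x, y, weight_decay)


def sam_grad_separate_l2(params, x, y, rho, weight_decay):
    # Proposed: perturb using the cross-entropy gradient only ...
    ce_grad = jax.grad(cross_entropy_loss)(params, x, y)
    noised = add_trees(params, ascent_direction(ce_grad, rho))
    sam_grad = jax.grad(cross_entropy_loss)(noised, x, y)
    # ... and add the L2 gradient evaluated at the clean weights.
    l2_grad = jax.grad(l2_loss)(params, weight_decay)
    return add_trees(sam_grad, l2_grad)
```

With `weight_decay=0` the two functions agree, and with `rho=0` the proposed version reduces to the plain gradient of the cross-entropy plus `weight_decay * params`; with both nonzero they generally differ, which is the discrepancy I'm asking about.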
Is there something that I'm missing?