Skip to content

Commit 00749de

Browse files
Strengthen Case 2 unit test: verify full probability distribution
Replace test_top_prediction_on_tail with test_tail_high_probability. Now checks tail (atoms 0-9) high probability AND non-tail low probability. Before bugfix (d858ff2): tail_sum=0.3755, FAILS (< 0.5 threshold) After bugfix (5918489): tail_sum=0.7433, PASSES
1 parent 5918489 commit 00749de

1 file changed

Lines changed: 20 additions & 8 deletions

File tree

tests/test_real_cases.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,29 @@ def test_generate_probabilities(self):
136136
self.assertAlmostEqual(np.sum(scores), 1.0, places=5)
137137
self.assertTrue(np.all(scores >= 0))
138138

139-
def test_top_prediction_on_tail(self):
139+
def test_tail_high_probability(self):
140140
"""The modification is on the alkyl tail (atoms 0-9).
141-
The top predicted atom should be in that region."""
141+
Tail atoms should have high probability and non-tail atoms should have low probability."""
142142
scores = self.mf.generate_probabilities()
143-
predicted_site = int(np.argmax(scores))
144143
tail_atoms = list(range(0, 10))
145-
self.assertIn(
146-
predicted_site,
147-
tail_atoms,
148-
f"Predicted site {predicted_site} is not on the alkyl tail (atoms 0-9)",
149-
)
144+
non_tail_atoms = list(range(10, len(scores)))
145+
146+
tail_sum = sum(scores[i] for i in tail_atoms)
147+
tail_mean = np.mean([scores[i] for i in tail_atoms])
148+
non_tail_mean = np.mean([scores[i] for i in non_tail_atoms])
149+
150+
# Tail should hold the majority of probability mass
151+
self.assertGreater(tail_sum, 0.5,
152+
f"Tail probability sum {tail_sum:.4f} should be > 0.5")
153+
154+
# Tail mean should be significantly higher than non-tail mean
155+
self.assertGreater(tail_mean, non_tail_mean * 3,
156+
f"Tail mean {tail_mean:.4f} should be much higher than non-tail mean {non_tail_mean:.4f}")
157+
158+
# Every non-tail atom should have low probability
159+
for i in non_tail_atoms:
160+
self.assertLess(scores[i], 0.03,
161+
f"Non-tail atom {i} has unexpectedly high probability {scores[i]:.4f}")
150162

151163
def test_get_edge_detail_no_side_effect(self):
152164
"""Previously, calling get_edge_detail should not mutate the original edge matches.

0 commit comments

Comments
 (0)