Commit 0c4e0fb

GiggleLiu and claude authored
test: add comprehensive non-batched GPU gradient tests (#33)
Add numerical gradient verification and optimization tests for GPU:

- test_gpu_numerical_gradient_maxplus
- test_gpu_numerical_gradient_minplus
- test_gpu_numerical_gradient_maxmul
- test_gpu_minplus_optimization
- test_gpu_maxmul_optimization

Total GPU tests: 22 (was 17)
Total tests: 207 (was 202)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a270d5d commit 0c4e0fb

1 file changed: crates/tropical-gemm-python/tests/test_pytorch_gradients.py (157 additions, 0 deletions)
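For convenience, one way to run just the five tests added here from Python is sketched below. The invocation is illustrative and not part of the commit; on a machine without CUDA the skipif guards simply skip these tests.

import pytest

# Hypothetical invocation: select the new GPU gradient tests by name.
pytest.main([
    "crates/tropical-gemm-python/tests/test_pytorch_gradients.py",
    "-v",
    "-k", "gpu_numerical_gradient or gpu_minplus_optimization or gpu_maxmul_optimization",
])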
@@ -1164,6 +1164,163 @@ def test_gpu_gradient_sparsity():
         "Total gradient should equal number of output elements"
 
 
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_maxplus():
+    """Verify GPU MaxPlus gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    a = (torch.randn(m, k, device="cuda") * 3).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda") * 3).requires_grad_(True)
+
+    c = tropical_maxplus_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_maxplus_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_maxplus_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MaxPlus numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_minplus():
+    """Verify GPU MinPlus gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    a = (torch.randn(m, k, device="cuda") * 3).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda") * 3).requires_grad_(True)
+
+    c = tropical_minplus_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_minplus_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_minplus_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MinPlus numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_maxmul():
+    """Verify GPU MaxMul gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    # Use positive values for MaxMul
+    a = (torch.randn(m, k, device="cuda").abs() * 2 + 0.5).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda").abs() * 2 + 0.5).requires_grad_(True)
+
+    c = tropical_maxmul_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_maxmul_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_maxmul_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MaxMul numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_minplus_optimization():
+    """Test that GPU MinPlus gradients enable optimization to converge."""
+    torch.manual_seed(42)
+
+    a = torch.randn(8, 6, device="cuda", requires_grad=True)
+    b = torch.randn(6, 10, device="cuda")
+    target = torch.randn(8, 10, device="cuda")
+
+    optimizer = torch.optim.SGD([a], lr=0.1)
+
+    initial_loss = None
+    for step in range(10):
+        optimizer.zero_grad()
+        c = tropical_minplus_matmul(a, b)
+        loss = torch.nn.functional.mse_loss(c, target)
+        if initial_loss is None:
+            initial_loss = loss.item()
+        loss.backward()
+        optimizer.step()
+
+    assert loss.item() < initial_loss, \
+        f"GPU MinPlus loss should decrease: {initial_loss} -> {loss.item()}"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_maxmul_optimization():
+    """Test that GPU MaxMul gradients enable optimization to converge."""
+    torch.manual_seed(42)
+
+    # Use positive values for MaxMul
+    a = (torch.randn(8, 6, device="cuda").abs() + 0.1).requires_grad_(True)
+    b = (torch.randn(6, 10, device="cuda").abs() + 0.1)
+    target = torch.randn(8, 10, device="cuda").abs() + 1.0
+
+    optimizer = torch.optim.SGD([a], lr=0.01)
+
+    initial_loss = None
+    for step in range(10):
+        optimizer.zero_grad()
+        c = tropical_maxmul_matmul(a, b)
+        loss = torch.nn.functional.mse_loss(c, target)
+        if initial_loss is None:
+            initial_loss = loss.item()
+        loss.backward()
+        optimizer.step()
+        # Keep values positive for MaxMul
+        with torch.no_grad():
+            a.clamp_(min=0.01)
+
+    assert loss.item() < initial_loss, \
+        f"GPU MaxMul loss should decrease: {initial_loss} -> {loss.item()}"
+
+
 # ============================================================================
 # Batched GPU Gradient Tests
 # ============================================================================
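As background for why the finite-difference checks above work with a coarse eps and a loose atol: in the (max, +) semiring, C[i, j] = max_l (A[i, l] + B[l, j]) is piecewise linear in A, so away from ties the derivative of C.sum() with respect to A[i, l] is just the number of output columns whose maximum is attained at index l. A minimal CPU sketch of that structure with plain torch ops (the reference function below is illustrative and is not the library's tropical_maxplus_matmul):

import torch

def maxplus_matmul_reference(a, b):
    # Tropical (max, +) matmul: C[i, j] = max_l (A[i, l] + B[l, j]).
    s = a.unsqueeze(2) + b.unsqueeze(0)  # (m, k, n) of all pairwise sums
    return s.max(dim=1)                  # values (m, n) and argmax indices (m, n)

torch.manual_seed(0)
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 3)

c, argmax = maxplus_matmul_reference(a, b)
c.sum().backward()

# Each output element routes one unit of gradient to its argmax entry of A,
# so a.grad[i, l] counts how many output columns j selected index l.
counts = torch.zeros_like(a)
for i in range(a.shape[0]):
    for j in range(b.shape[1]):
        counts[i, argmax[i, j]] += 1.0
assert torch.allclose(a.grad, counts)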
