@@ -1164,6 +1164,163 @@ def test_gpu_gradient_sparsity():
11641164 "Total gradient should equal number of output elements"
11651165
11661166
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_maxplus():
+    """Verify GPU MaxPlus gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    a = (torch.randn(m, k, device="cuda") * 3).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda") * 3).requires_grad_(True)
+
+    c = tropical_maxplus_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
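+    # Central difference check: dL/da[i,j] ≈ (L(a+eps) - L(a-eps)) / (2*eps).
+    # Assuming the usual max-plus product c[i,n] = max_j(a[i,j] + b[j,n]), the
+    # analytical subgradient counts the outputs whose max is attained at j; the
+    # two agree except when eps flips an argmax, hence the loose atol below.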
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_maxplus_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_maxplus_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MaxPlus numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_minplus():
+    """Verify GPU MinPlus gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    a = (torch.randn(m, k, device="cuda") * 3).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda") * 3).requires_grad_(True)
+
+    c = tropical_minplus_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_minplus_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_minplus_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MinPlus numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_numerical_gradient_maxmul():
+    """Verify GPU MaxMul gradients using finite differences."""
+    torch.manual_seed(42)
+
+    m, k, n = 3, 4, 3
+    # Use positive values for MaxMul
+    a = (torch.randn(m, k, device="cuda").abs() * 2 + 0.5).requires_grad_(True)
+    b = (torch.randn(k, n, device="cuda").abs() * 2 + 0.5).requires_grad_(True)
+
+    c = tropical_maxmul_matmul(a, b)
+    loss = c.sum()
+    loss.backward()
+
+    analytical_grad_a = a.grad.clone()
+
+    eps = 1e-4
+    numerical_grad_a = torch.zeros_like(a)
+
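+    # Assuming MaxMul computes c[i,n] = max_j(a[i,j] * b[j,n]), the analytical
+    # gradient w.r.t. a[i,j] sums b[j,n] over the outputs whose max is attained
+    # at j; the central difference below should reproduce that away from ties.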
+    for i in range(m):
+        for j in range(k):
+            a_plus = a.detach().clone()
+            a_plus[i, j] += eps
+            a_minus = a.detach().clone()
+            a_minus[i, j] -= eps
+
+            c_plus = tropical_maxmul_matmul(a_plus, b.detach()).sum()
+            c_minus = tropical_maxmul_matmul(a_minus, b.detach()).sum()
+
+            numerical_grad_a[i, j] = (c_plus - c_minus) / (2 * eps)
+
+    assert torch.allclose(analytical_grad_a, numerical_grad_a, atol=0.1), \
+        "GPU MaxMul numerical gradient mismatch"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_minplus_optimization():
+    """Test that GPU MinPlus gradients enable optimization to converge."""
+    torch.manual_seed(42)
+
+    a = torch.randn(8, 6, device="cuda", requires_grad=True)
+    b = torch.randn(6, 10, device="cuda")
+    target = torch.randn(8, 10, device="cuda")
+
+    optimizer = torch.optim.SGD([a], lr=0.1)
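+    # MinPlus subgradients are sparse: only the entries of `a` attaining a
+    # per-output min receive a nonzero gradient, so each SGD step updates only
+    # those entries. That should still be enough for the MSE loss to decrease.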
+
+    initial_loss = None
+    for step in range(10):
+        optimizer.zero_grad()
+        c = tropical_minplus_matmul(a, b)
+        loss = torch.nn.functional.mse_loss(c, target)
+        if initial_loss is None:
+            initial_loss = loss.item()
+        loss.backward()
+        optimizer.step()
+
+    assert loss.item() < initial_loss, \
+        f"GPU MinPlus loss should decrease: {initial_loss} -> {loss.item()}"
+
+
+@pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
+def test_gpu_maxmul_optimization():
+    """Test that GPU MaxMul gradients enable optimization to converge."""
+    torch.manual_seed(42)
+
+    # Use positive values for MaxMul
+    a = (torch.randn(8, 6, device="cuda").abs() + 0.1).requires_grad_(True)
+    b = torch.randn(6, 10, device="cuda").abs() + 0.1
+    target = torch.randn(8, 10, device="cuda").abs() + 1.0
+
+    optimizer = torch.optim.SGD([a], lr=0.01)
+
+    initial_loss = None
+    for step in range(10):
+        optimizer.zero_grad()
+        c = tropical_maxmul_matmul(a, b)
+        loss = torch.nn.functional.mse_loss(c, target)
+        if initial_loss is None:
+            initial_loss = loss.item()
+        loss.backward()
+        optimizer.step()
+        # Keep values positive for MaxMul
+        with torch.no_grad():
+            a.clamp_(min=0.01)
+
+    assert loss.item() < initial_loss, \
+        f"GPU MaxMul loss should decrease: {initial_loss} -> {loss.item()}"
+
+
# ============================================================================
# Batched GPU Gradient Tests
# ============================================================================