diff --git a/mlff/utils/training_utils.py b/mlff/utils/training_utils.py
index c070224..8b4bf14 100644
--- a/mlff/utils/training_utils.py
+++ b/mlff/utils/training_utils.py
@@ -31,7 +31,7 @@ def print_metrics(epoch, eval_metrics):
 
 
 def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_loss=False, robust_loss_alpha=1.99):
-    del batch_segments
+    # del batch_segments
 
     assert y.shape == y_label.shape
 
@@ -42,6 +42,14 @@ def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_los
     )
     denominator = full_mask.sum().astype(y.dtype)
 
+    num_nodes_per_graph = jax.ops.segment_sum(
+        data=jnp.ones_like(batch_segments),
+        segment_ids=batch_segments,
+        num_segments=len(graph_mask)
+    )
+
+    num_nodes_per_graph = jnp.maximum(num_nodes_per_graph, 1)
+
     # jax.debug.print("y_label_graph: {}", y_label)
     # jax.debug.print("y_graph: {}", y)
 
@@ -57,19 +65,31 @@ def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_los
         # loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, robust_scale)) / denominator
 
         # Compute adaptive scale, use scale of 1.0
-        loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)) / denominator
+        #loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)) / denominator
+        per_graph_loss = 2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)
+
 
     else:
         # Regular L2 loss
-        loss = (
-            jnp.sum(
-                2 * scale * optax.l2_loss(
-                    jnp.where(full_mask, y, 0).reshape(-1),
-                    jnp.where(full_mask, y_label, 0).reshape(-1),
-                )
-            )
-            / denominator
+        # loss = (
+        #     jnp.sum(
+        #         2 * scale * optax.l2_loss(
+        #             jnp.where(full_mask, y, 0).reshape(-1),
+        #             jnp.where(full_mask, y_label, 0).reshape(-1),
+        #         )
+        #     )
+        #     / denominator
+        # )
+        per_graph_loss = 2 * scale * optax.l2_loss(
+            jnp.where(full_mask, y, 0.0),
+            jnp.where(full_mask, y_label, 0.0)
         )
+
+    per_graph_loss = per_graph_loss / num_nodes_per_graph
+    per_graph_loss = jnp.where(graph_mask, per_graph_loss, 0.0)
+
+    loss = jnp.sum(per_graph_loss) / jnp.maximum(denominator, 1.)
+
     return loss
 
 
@@ -174,17 +194,39 @@ def graph_mae_loss(y, y_label, batch_segments, graph_mask, scale):
         graph_mask,
         [y_label.ndim - 1 - o for o in range(0, y_label.ndim - 1)]
     )
     denominator = full_mask.sum().astype(y.dtype)
-
+    
+    num_nodes_per_graph = jax.ops.segment_sum(
+        data=jnp.ones_like(batch_segments),
+        segment_ids=batch_segments,
+        num_segments=len(graph_mask)
+    )
+
+    num_nodes_per_graph = jnp.maximum(num_nodes_per_graph, 1)
+
     # Calculate absolute error instead of squared error
-    loss = (
-        jnp.sum(
-            jnp.abs(
-                jnp.where(full_mask, y, 0).reshape(-1) -
-                jnp.where(full_mask, y_label, 0).reshape(-1)
-            )
-        )
-        / denominator
+    # loss = (
+    #     jnp.sum(
+    #         jnp.abs(
+    #             jnp.where(full_mask, y, 0).reshape(-1) -
+    #             jnp.where(full_mask, y_label, 0).reshape(-1)
+    #         )
+    #     )
+    #     / denominator
+    # )
+
+    per_graph_mae = (jnp.abs(
+        jnp.where(full_mask, y, 0).reshape(-1) -
+        jnp.where(full_mask, y_label, 0).reshape(-1)
+    )) / num_nodes_per_graph
+
+    per_graph_mae = jnp.where(
+        full_mask,
+        per_graph_mae,
+        jnp.asarray(0., dtype=per_graph_mae.dtype)
     )
+
+    loss = jnp.sum(per_graph_mae) / jnp.maximum(denominator, 1.)
+
     return loss
 