Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions mlff/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def print_metrics(epoch, eval_metrics):


def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_loss=False, robust_loss_alpha=1.99):
del batch_segments
# del batch_segments

assert y.shape == y_label.shape

Expand All @@ -42,6 +42,14 @@ def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_los
)
denominator = full_mask.sum().astype(y.dtype)

num_nodes_per_graph = jax.ops.segment_sum(
data=jnp.ones_like(batch_segments),
segment_ids=batch_segments,
num_segments=len(graph_mask)
)

num_nodes_per_graph = jnp.maximum(num_nodes_per_graph, 1)

# jax.debug.print("y_label_graph: {}", y_label)
# jax.debug.print("y_graph: {}", y)

Expand All @@ -57,19 +65,31 @@ def graph_mse_loss(y, y_label, batch_segments, graph_mask, scale, use_robust_los

# loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, robust_scale)) / denominator
# Compute adaptive scale, use scale of 1.0
loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)) / denominator
#loss = jnp.sum(2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)) / denominator
per_graph_loss = 2 * scale * ROBUST_LOSS_DIST.nllfun(diff, robust_loss_alpha, 1.0)


else:
# Regular L2 loss
loss = (
jnp.sum(
2 * scale * optax.l2_loss(
jnp.where(full_mask, y, 0).reshape(-1),
jnp.where(full_mask, y_label, 0).reshape(-1),
)
)
/ denominator
# loss = (
# jnp.sum(
# 2 * scale * optax.l2_loss(
# jnp.where(full_mask, y, 0).reshape(-1),
# jnp.where(full_mask, y_label, 0).reshape(-1),
# )
# )
# / denominator
# )
per_graph_loss = 2 * scale * optax.l2_loss(
jnp.where(full_mask, y, 0.0),
jnp.where(full_mask, y_label, 0.0)
)

per_graph_loss = per_graph_loss / num_nodes_per_graph
per_graph_loss = jnp.where(graph_mask, per_graph_loss, 0.0)

loss = jnp.sum(per_graph_loss) / jnp.maximum(denominator, 1.)

return loss


Expand Down Expand Up @@ -174,17 +194,39 @@ def graph_mae_loss(y, y_label, batch_segments, graph_mask, scale):
graph_mask, [y_label.ndim - 1 - o for o in range(0, y_label.ndim - 1)]
)
denominator = full_mask.sum().astype(y.dtype)


num_nodes_per_graph = jax.ops.segment_sum(
data=jnp.ones_like(batch_segments),
segment_ids=batch_segments,
num_segments=len(graph_mask)
)

num_nodes_per_graph = jnp.maximum(num_nodes_per_graph, 1)

# Calculate absolute error instead of squared error
loss = (
jnp.sum(
jnp.abs(
jnp.where(full_mask, y, 0).reshape(-1) -
jnp.where(full_mask, y_label, 0).reshape(-1)
)
)
/ denominator
# loss = (
# jnp.sum(
# jnp.abs(
# jnp.where(full_mask, y, 0).reshape(-1) -
# jnp.where(full_mask, y_label, 0).reshape(-1)
# )
# )
# / denominator
# )

per_graph_mae = (jnp.abs(
jnp.where(full_mask, y, 0).reshape(-1) -
jnp.where(full_mask, y_label, 0).reshape(-1)
)) / num_nodes_per_graph

per_graph_mae = jnp.where(
full_mask,
per_graph_mae,
jnp.asarray(0., dtype=per_graph_mae.dtype)
)

loss = jnp.sum(per_graph_mae) / jnp.maximum(denominator, 1.)

return loss


Expand Down