Sanitize non-finite history values before forecasting #192

Open
ShinDongWoon wants to merge 4 commits into main from
codex/fix-residual-weights-sum-error-nip2qs

Conversation

@ShinDongWoon
Owner

Summary

  • clone the encoder slice and replace non-finite inputs with zeros before embedding
  • raise a clear error if non-floating tensors contain non-finite values

Testing

  • pytest

https://chatgpt.com/codex/tasks/task_e_68d4eba5571c83288b40cc76e7d8b8ce
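
A rough, self-contained illustration of the zero-fill contract described in the summary above; sanitize_history and the test below are hypothetical names for this sketch only, not functions from this repository:

import torch


def sanitize_history(window: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the described behaviour on a cloned slice."""
    out = window.clone()
    if not out.is_floating_point():
        # Integral histories cannot hold NaN/Inf, so there is nothing to sanitize.
        return out
    # Zero-fill NaN/±Inf so the embedding never sees non-finite values.
    return out.masked_fill(~torch.isfinite(out), 0.0)


def test_non_finite_history_is_zeroed():
    hist = torch.tensor([[1.0, float("nan")], [float("inf"), 2.0]])
    expected = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
    assert torch.equal(sanitize_history(hist), expected)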

@ShinDongWoon
Owner Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting

Comment on lines 1161 to 1168
enc_x_value = x[:, -self.input_len :, :].clone()
if torch.any(~torch.isfinite(enc_x_value)):
    if not enc_x_value.is_floating_point():
        raise RuntimeError(
            "TimesNet input contains non-finite values in non-floating tensor"
        )
    invalid_mask = ~torch.isfinite(enc_x_value)
    enc_x_value = enc_x_value.masked_fill(invalid_mask, 0.0)

[P1] Guard torch.isfinite for non‑floating inputs

The new sanitization in TimesNet.forward invokes torch.isfinite before checking the tensor’s dtype. When x is an integral tensor (or any non‑floating type), torch.isfinite raises RuntimeError: isfinite not implemented for type … before the custom error message in the subsequent if not enc_x_value.is_floating_point() block can run. This means that calling the model with integer history values now always crashes with a low‑level PyTorch error even though integers cannot contain non‑finite values and the code used to defer to later dtype handling. The check should first verify enc_x_value.is_floating_point() (or cast) before applying torch.isfinite so non‑floating inputs either bypass sanitization or raise the intended explanatory error.

Useful? React with 👍 / 👎.
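
A minimal sketch of the reordering this comment asks for, pulled out into a free function for illustration (the function name is made up; the merged change lives inside TimesNet.forward):

import torch


def sanitize_encoder_slice(x: torch.Tensor, input_len: int) -> torch.Tensor:
    # Take the trailing encoder window, as in the snippet above.
    enc_x_value = x[:, -input_len:, :].clone()
    # Check the dtype first: torch.isfinite is only meaningful for floating
    # tensors, and integral histories cannot contain NaN/Inf anyway.
    if enc_x_value.is_floating_point():
        invalid_mask = ~torch.isfinite(enc_x_value)
        if invalid_mask.any():
            enc_x_value = enc_x_value.masked_fill(invalid_mask, 0.0)
    return enc_x_value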

@ShinDongWoon
Owner Author

@codex fix comments

@chatgpt-codex-connector chatgpt-codex-connector bot

Summary

Testing

  • pytest

View task →

@ShinDongWoon
Owner Author

codex review

@ShinDongWoon
Owner Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 399 to 413
softmax_dtype = amp.dtype
if softmax_dtype in (torch.float16, torch.bfloat16):
    amp_for_softmax = amp.to(dtype=torch.float32)
    weights_float = F.softmax(amp_for_softmax, dim=1)
else:
    weights_float = F.softmax(amp, dim=1)
eps = torch.finfo(weights_float.dtype).eps
weight_sum = weights_float.sum(dim=1, keepdim=True)
zero_mask = weight_sum <= eps
if zero_mask.any():
    uniform = torch.full_like(weights_float, 1.0 / weights_float.size(1))
    weights_float = torch.where(zero_mask, uniform, weights_float)
    weight_sum = torch.where(zero_mask, torch.ones_like(weight_sum), weight_sum)
weights_float = weights_float / weight_sum.clamp_min(eps)
weights_flat = weights_float.to(dtype=amp.dtype)

[P1] Prevent NaN residual weights from bypassing validation

The new residual weighting path normalizes weights_float but no longer verifies that the softmax output is finite. When amplitudes contains nan or ±inf (which can happen with mixed‑precision overflow), F.softmax returns nan, weight_sum becomes nan, zero_mask remains false, and the division leaves weights_flat filled with nan. The previous version raised an error when weights did not sum to 1, but now the block quietly propagates NaNs into the forecast. Consider checking torch.isfinite(weights_float) before normalization and raising if invalid values are detected so corrupted inputs cannot silently poison the output.

Useful? React with 👍 / 👎.
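
One possible shape for the guard suggested here, as a standalone sketch (the function name is illustrative; the real code sits inline in the residual weighting path quoted above):

import torch
import torch.nn.functional as F


def residual_weights(amp: torch.Tensor) -> torch.Tensor:
    # Run the softmax in float32 when the amplitudes arrive in half precision.
    if amp.dtype in (torch.float16, torch.bfloat16):
        weights = F.softmax(amp.to(dtype=torch.float32), dim=1)
    else:
        weights = F.softmax(amp, dim=1)
    # Fail loudly if overflowed or NaN amplitudes produced non-finite weights,
    # instead of letting them slip through the later normalization.
    if not torch.isfinite(weights).all():
        raise RuntimeError("Residual weight softmax produced non-finite values")
    return weights.to(dtype=amp.dtype)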

@ShinDongWoon
Owner Author

@codex fix comments

@chatgpt-codex-connector chatgpt-codex-connector bot

Summary

Testing

  • pytest

View task →
