Sanitize non-finite history values before forecasting #192
ShinDongWoon wants to merge 4 commits into main from
Conversation
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
enc_x_value = x[:, -self.input_len :, :].clone()
if torch.any(~torch.isfinite(enc_x_value)):
    if not enc_x_value.is_floating_point():
        raise RuntimeError(
            "TimesNet input contains non-finite values in non-floating tensor"
        )
    invalid_mask = ~torch.isfinite(enc_x_value)
    enc_x_value = enc_x_value.masked_fill(invalid_mask, 0.0)
```
[P1] Guard torch.isfinite for non-floating inputs
The new sanitization in TimesNet.forward invokes torch.isfinite before checking the tensor's dtype. When x is an integral tensor (or any non-floating type), torch.isfinite raises RuntimeError: isfinite not implemented for type … before the custom error in the subsequent if not enc_x_value.is_floating_point() block can run. As a result, calling the model with integer history values now always crashes with a low-level PyTorch error, even though integers cannot contain non-finite values and the code used to defer to the later dtype handling. The check should verify enc_x_value.is_floating_point() (or cast) before applying torch.isfinite, so non-floating inputs either bypass sanitization or raise the intended explanatory error.
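A minimal sketch of the ordering the reviewer suggests, written as a standalone helper for illustration (the function name is hypothetical; only enc_x_value, torch.isfinite, and masked_fill come from the diff above, and the repo's actual patch may differ):

```python
import torch


def sanitize_history(enc_x_value: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: gate on dtype before calling torch.isfinite."""
    # Only floating tensors can hold NaN/Inf, so integer histories bypass
    # the sanitization entirely instead of tripping torch.isfinite.
    if not enc_x_value.is_floating_point():
        return enc_x_value
    invalid_mask = ~torch.isfinite(enc_x_value)
    if invalid_mask.any():
        enc_x_value = enc_x_value.masked_fill(invalid_mask, 0.0)
    return enc_x_value
```

With this ordering, an integer history such as torch.tensor([[1, 2, 3]]) is returned untouched rather than raising a low-level PyTorch error.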
Summary
- Guarded TimesNet history sanitization so torch.isfinite only runs on floating tensors, while still surfacing the explicit error when non-floating inputs contain invalid values after casting. (src/timesnet_forecast/models/timesnet.py L1150-L1174)
Testing
- ✅ pytest
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
softmax_dtype = amp.dtype
if softmax_dtype in (torch.float16, torch.bfloat16):
    amp_for_softmax = amp.to(dtype=torch.float32)
    weights_float = F.softmax(amp_for_softmax, dim=1)
else:
    weights_float = F.softmax(amp, dim=1)
eps = torch.finfo(weights_float.dtype).eps
weight_sum = weights_float.sum(dim=1, keepdim=True)
zero_mask = weight_sum <= eps
if zero_mask.any():
    uniform = torch.full_like(weights_float, 1.0 / weights_float.size(1))
    weights_float = torch.where(zero_mask, uniform, weights_float)
    weight_sum = torch.where(zero_mask, torch.ones_like(weight_sum), weight_sum)
weights_float = weights_float / weight_sum.clamp_min(eps)
weights_flat = weights_float.to(dtype=amp.dtype)
```
[P1] Prevent NaN residual weights from bypassing validation
The new residual weighting path normalizes weights_float but no longer verifies that the softmax output is finite. When amplitudes contains nan or ±inf (which can happen with mixed‑precision overflow), F.softmax returns nan, weight_sum becomes nan, zero_mask remains false, and the division leaves weights_flat filled with nan. The previous version raised an error when weights did not sum to 1, but now the block quietly propagates NaNs into the forecast. Consider checking torch.isfinite(weights_float) before normalization and raising if invalid values are detected so corrupted inputs cannot silently poison the output.
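A minimal sketch of the check the reviewer asks for, written as a standalone helper (the function name and signature are illustrative assumptions, not the repo's API; the softmax/normalization logic mirrors the diff above):

```python
import torch
import torch.nn.functional as F


def residual_softmax_weights(amp: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: softmax in fp32 for low-precision amplitudes,
    then validate finiteness before normalizing."""
    if amp.dtype in (torch.float16, torch.bfloat16):
        weights_float = F.softmax(amp.to(dtype=torch.float32), dim=1)
    else:
        weights_float = F.softmax(amp, dim=1)

    # Fail loudly instead of letting NaN/Inf amplitudes slip past the
    # zero-sum fallback and poison the forecast downstream.
    if not torch.isfinite(weights_float).all():
        raise RuntimeError("Residual softmax weights contain non-finite values")

    eps = torch.finfo(weights_float.dtype).eps
    weight_sum = weights_float.sum(dim=1, keepdim=True)
    zero_mask = weight_sum <= eps
    if zero_mask.any():
        uniform = torch.full_like(weights_float, 1.0 / weights_float.size(1))
        weights_float = torch.where(zero_mask, uniform, weights_float)
        weight_sum = torch.where(zero_mask, torch.ones_like(weight_sum), weight_sum)
    return (weights_float / weight_sum.clamp_min(eps)).to(dtype=amp.dtype)
```

With this guard, an input such as torch.tensor([[float("nan"), 1.0]]) raises immediately instead of returning NaN weights.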
Summary
- Raised a runtime error when residual softmax weights contain non-finite values so corrupted amplitudes cannot silently propagate into forecasts. (src/timesnet_forecast/models/timesnet.py L396-L418)
Testing
- ✅ pytest
…net.py Sanitize non-finite history and residual weights before forecasting
Summary
Testing
https://chatgpt.com/codex/tasks/task_e_68d4eba5571c83288b40cc76e7d8b8ce