Add option to normalize NS input in double precision #238
Signed-off-by: Hao Wu <skyw@nvidia.com>
Greptile Summary
This PR adds a
Confidence Score: 5/5
Safe to merge; the double-precision normalization logic is correct and the in-place clamp guard in the distributed path works as intended. The normalization math is sound for both the distributed and non-distributed paths. The only notable issue is the unthrottled warning call, which does not affect correctness. The logging.warning call in muon_utils.py (non-distributed normalize_in_double branch) should be rate-limited before merging to avoid log spam during training.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[newton_schulz called] --> B{tp_group?}
B -- yes --> C[distributed_normalize_p2]
C --> D{normalize_in_double?}
D -- yes --> E["x_sq = x.double()\nsum in float64\nall_reduce\nnorm = sqrt.to(float32)\nx / norm"]
D -- no --> F["x_sq = x\nsum in float32\nall_reduce\nnorm = sqrt.to(float32)\nnorm.clamp_min_(eps)\nx / norm"]
B -- no --> G{normalize_in_double?}
G -- no --> H["F.normalize(x, p=2,\ndim=(-2,-1), eps=eps)"]
G -- yes --> I["logging.warning (every call)\nnorm = vector_norm(x,\ndtype=float64).to(float32)\nx / norm"]
E --> J[NS iterations]
F --> J
H --> J
I --> J
J --> K[return X]
Reviews (3): Last reviewed commit: "bug fix"
/ok to test 854dfd9
"if" clause on calculated norm would trigger device to host sync, trying to do it on device will take multiple path, in which case none of them are better than just use double. A custom kernel can do it but is out of scope.
There is no zero-division guard on the fp64 path, since we don't expect LLM training to be done in fp64. The square of an fp32 value won't underflow in fp64.
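The underflow claim can be checked with a small standalone sketch. This is not code from the PR: it uses only the Python standard library, and the helper `to_fp32` is a hypothetical name introduced here to emulate fp32 rounding, since Python floats are fp64. It squares the smallest positive fp32 value and compares the result kept in fp64 with the result rounded back to fp32.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value (illustrative helper)."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Smallest positive fp32 value (a subnormal): 2**-149.
tiny = to_fp32(2.0 ** -149)

# Squaring in fp64 keeps the value: 2**-298 is far above the fp64 minimum.
sq_fp64 = tiny * tiny

# Rounding the same square back to fp32 flushes it to zero.
sq_fp32 = to_fp32(sq_fp64)

# The fp64 square root recovers the original magnitude before the cast to fp32.
norm = to_fp32(math.sqrt(sq_fp64))

print(sq_fp32, sq_fp64 > 0.0, norm == tiny)
```

Under these assumptions, a nonzero fp32 input always has a nonzero sum of squares in fp64, so the resulting norm is nonzero. An all-zero input still yields a zero norm in either precision, which this sketch does not cover.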