Skip to content

Commit 9ad6000

Browse files
committed
fix: validate fp16.loss_scale before coercion
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
1 parent 225ab4e commit 9ad6000

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

deepspeed/runtime/precision_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ class DeepSpeedFP16Config(DeepSpeedConfigModel):
114114
"""
115115
Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale.
116116
"""
117-
118-
@field_validator("loss_scale")
117+
@field_validator("loss_scale", mode="before")
119118
@classmethod
120119
def _validate_loss_scale(cls, v):
121120
# Prevent True/False from being treated as 1/0
121+
# (must run before Pydantic coerces bool -> float)
122122
if isinstance(v, bool):
123123
raise ValueError("fp16.loss_scale must be a number, not bool")
124124

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import math
2+
3+
import pytest
4+
from pydantic import ValidationError
5+
6+
from deepspeed.runtime.precision_config import DeepSpeedFP16Config
7+
8+
9+
@pytest.mark.parametrize("loss_scale", [-1, float("inf"), float("nan"), True])
def test_fp16_loss_scale_rejects_invalid_values(loss_scale):
    """Negative, non-finite, and boolean loss scales must all be rejected.

    The bool case is the regression target: without a mode="before"
    validator, Pydantic would coerce True -> 1.0 instead of raising.
    """
    with pytest.raises(ValidationError):
        _ = DeepSpeedFP16Config(loss_scale=loss_scale)
13+
14+
15+
@pytest.mark.parametrize("loss_scale", [0, 1, 2.0, "3"])
def test_fp16_loss_scale_accepts_valid_values(loss_scale):
    """Ints, floats, and numeric strings coerce to a usable loss scale.

    After construction the stored value must be a finite, non-negative
    number (0 means dynamic loss scaling per the config's docstring).
    """
    config = DeepSpeedFP16Config(loss_scale=loss_scale)
    scale = config.loss_scale
    assert math.isfinite(scale)
    assert scale >= 0

0 commit comments

Comments
 (0)