Skip to content

Commit 4d01301

Browse files
committed
skip unimplemented error; update workflow
1 parent 1e0aeaa commit 4d01301

14 files changed

Lines changed: 56 additions & 7 deletions

File tree

.github/workflows/cpu-inference.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,7 @@ jobs:
6363
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
6464
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
6565
cd tests
66-
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
66+
pytest -v -s unit/autotuning/ unit/checkpoint/ unit/comm/ unit/compression/ unit/elasticity/ unit/launcher/ unit/profiling/ unit/ops
67+
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
68+
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' unit/
69+
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 -m 'inference' unit/

accelerator/abstract_accelerator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def is_bf16_supported(self):
156156
def is_fp16_supported(self):
157157
...
158158

159+
@abc.abstractmethod
160+
def supported_dtypes(self):
161+
...
162+
159163
# Misc
160164
@abc.abstractmethod
161165
def amp(self):

accelerator/cpu_accelerator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@ def is_bf16_supported(self):
183183
return True
184184

185185
def is_fp16_supported(self):
186-
return True
186+
return False
187+
188+
def supported_dtypes(self):
189+
return [torch.float, torch.bfloat16]
187190

188191
# Tensor operations
189192

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def op_enabled(op_name):
153153
for op_name, builder in ALL_OPS.items():
154154
op_compatible = builder.is_compatible()
155155
compatible_ops[op_name] = op_compatible
156+
compatible_ops["deepspeed_not_implemented"] = False
156157

157158
# If op is requested but not available, throw an error.
158159
if op_enabled(op_name) and not op_compatible:

tests/unit/checkpoint/test_latest_checkpoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55

66
import deepspeed
77

8+
import pytest
89
from unit.common import DistributedTest
910
from unit.simple_model import *
1011

1112
from unit.checkpoint.common import checkpoint_correctness_verification
13+
from deepspeed.ops.op_builder import FusedAdamBuilder
14+
15+
if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
16+
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
1217

1318

1419
class TestLatestCheckpoint(DistributedTest):

tests/unit/inference/test_inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
from deepspeed.model_implementations import DeepSpeedTransformerInference
2121
from torch import nn
2222
from deepspeed.accelerator import get_accelerator
23+
from deepspeed.ops.op_builder import InferenceBuilder
24+
25+
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
26+
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
2327

2428
rocm_version = OpBuilder.installed_rocm_version()
2529
if rocm_version != (0, 0):

tests/unit/inference/test_model_profiling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from transformers import pipeline
1212
from unit.common import DistributedTest
1313
from deepspeed.accelerator import get_accelerator
14+
from deepspeed.ops.op_builder import InferenceBuilder
15+
16+
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
17+
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
1418

1519

1620
@pytest.fixture

tests/unit/ops/accelerators/test_accelerator_backward.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#pytest.skip(
2020
# "transformer kernels are temporarily disabled because of unexplained failures",
2121
# allow_module_level=True)
22+
if torch.half not in get_accelerator().supported_dtypes():
23+
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
2224

2325

2426
def check_equal(first, second, atol=1e-2, verbose=False):

tests/unit/ops/accelerators/test_accelerator_forward.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from deepspeed.accelerator import get_accelerator
1616
from unit.common import DistributedTest
1717

18+
if torch.half not in get_accelerator().supported_dtypes():
19+
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
20+
1821

1922
def check_equal(first, second, atol=1e-2, verbose=False):
2023
if verbose:

tests/unit/ops/adam/test_adamw.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from deepspeed.ops.adam import DeepSpeedCPUAdam
1212
from unit.common import DistributedTest
1313
from unit.simple_model import SimpleModel
14+
from deepspeed.accelerator import get_accelerator
1415

16+
if torch.half not in get_accelerator().supported_dtypes():
17+
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
1518
# yapf: disable
1619
#'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
1720
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],

0 commit comments

Comments (0)