Skip to content

Commit 6fcb590

Browse files
authored
Update deepspeed precision test (#12727)
1 parent 3f76c14 commit 6fcb590

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/strategies/test_deepspeed_strategy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,13 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
168168

169169

170170
@RunIf(deepspeed=True)
171+
@mock.patch("torch.cuda.device_count", return_value=1)
171172
@pytest.mark.parametrize("precision", [16, "mixed"])
172173
@pytest.mark.parametrize(
173174
"amp_backend",
174175
["native", pytest.param("apex", marks=RunIf(amp_apex=True))],
175176
)
176-
def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
177+
def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
177178
"""Test to ensure precision plugin is also correctly chosen.
178179
179180
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin

0 commit comments

Comments
 (0)