Skip to content

Commit 43c7778

Browse files
awaelchli
authored and lexierule committed
Fix ReduceOp type hint in ColossalAI strategy (#15535)
(cherry picked from commit 62d040c)
1 parent e9e461f commit 43c7778

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4848

4949
- Fixed `TensorBoardLogger` not validating the input array type when logging the model graph ([#15323](https://github.com/Lightning-AI/lightning/pull/15323))
5050

51-
-
52-
53-
-
51+
- Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
5452

5553
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
5654

src/pytorch_lightning/strategies/colossalai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
432432
strategy_registry.register("colossalai", cls, description="Default ColossalAI Strategy")
433433

434434
def reduce(
435-
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = ReduceOp.SUM
435+
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "sum"
436436
) -> Tensor:
437437
with _patch_cuda_is_available():
438438
from colossalai.communication.collective import reduce

tests/tests_lite/utilities/test_imports.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import subprocess
15+
import sys
16+
from textwrap import dedent
17+
1418
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
1519
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
1620

@@ -29,3 +33,16 @@ def test_imports():
2933
assert not _FAIRSCALE_AVAILABLE
3034
else:
3135
assert _FAIRSCALE_AVAILABLE
36+
37+
38+
def test_import_lightning_lite_with_torch_dist_unavailable():
39+
"""Test that the package can be imported regardless of whether torch.distributed is available."""
40+
code = dedent(
41+
"""
42+
import torch
43+
torch.distributed.is_available = lambda: False # pretend torch.distributed not available
44+
import lightning_lite
45+
"""
46+
)
47+
# run in complete isolation
48+
assert subprocess.call([sys.executable, "-c", code]) == 0

tests/tests_pytorch/utilities/test_imports.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import importlib
1616
import operator
17+
import subprocess
18+
import sys
19+
from textwrap import dedent
1720
from unittest import mock
1821

1922
import pytest
@@ -141,3 +144,16 @@ def test_import_with_unavailable_dependencies(patch_name, new_fn, to_import, cle
141144
"""
142145
with mock.patch(patch_name, new=new_fn):
143146
importlib.import_module(to_import)
147+
148+
149+
def test_import_pytorch_lightning_with_torch_dist_unavailable():
150+
"""Test that the package can be imported regardless of whether torch.distributed is available."""
151+
code = dedent(
152+
"""
153+
import torch
154+
torch.distributed.is_available = lambda: False # pretend torch.distributed not available
155+
import pytorch_lightning
156+
"""
157+
)
158+
# run in complete isolation
159+
assert subprocess.call([sys.executable, "-c", code]) == 0

0 commit comments

Comments
 (0)