Skip to content

Commit a49141a

Browse files
tchatonBorda
authored andcommitted
[App] Introduce Multi Node Component (#15524)
(cherry picked from commit ecc8ac0)
1 parent 92b5341 commit a49141a

File tree

8 files changed

+155
-6
lines changed

8 files changed

+155
-6
lines changed

examples/app_multi_node/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from lightning import LightningApp
1+
import lightning as L
22
from lightning.app.components import LightningTrainingComponent
33
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
44

5-
app = LightningApp(
5+
app = L.LightningApp(
66
LightningTrainingComponent(
77
"train.py",
88
num_nodes=2,

examples/app_multi_node/app_work.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import lightning.app as L
2+
from lightning.app.components import MultiNode
3+
4+
5+
class AnyDistributedComponent(L.LightningWork):
6+
def run(
7+
self,
8+
main_address: str,
9+
main_port: int,
10+
node_rank: int,
11+
):
12+
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
13+
14+
15+
compute = L.CloudCompute("gpu")
16+
app = L.LightningApp(
17+
MultiNode(
18+
AnyDistributedComponent,
19+
nodes=2,
20+
cloud_compute=compute,
21+
)
22+
)

examples/app_multi_node/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from lightning.pytorch import Trainer
2-
from lightning.pytorch.demos.boring_classes import BoringModel
1+
from pytorch_lightning import Trainer
2+
from pytorch_lightning.demos.boring_classes import BoringModel
33

44
if __name__ == "__main__":
55
model = BoringModel()

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ exclude = [
3636
"src/lightning_app/cli/react-ui-template",
3737
"src/lightning_app/cli/app-template",
3838
"src/lightning_app/components/database",
39+
"src/lightning_app/components/multi_node",
3940
"src/lightning_app/frontend/just_py/just_py",
4041
]
4142
install_types = "True"
@@ -58,6 +59,7 @@ warn_no_return = "False"
5859
# the list can be generated with:
5960
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
6061
module = [
62+
"lightning_app.components.multi_node",
6163
"lightning_app.api.http_methods",
6264
"lightning_app.api.request_types",
6365
"lightning_app.cli.commands.app_commands",

src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212

1313
- Added the `start` method to the work ([#15523](https://github.com/Lightning-AI/lightning/pull/15523))
1414

15-
-
15+
- Added a `MultiNode` Component to run with distributed computation with any frameworks ([#15524](https://github.com/Lightning-AI/lightning/pull/15524))
1616

1717
-
1818

src/lightning_app/components/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from lightning_app.components.database.client import DatabaseClient
22
from lightning_app.components.database.server import Database
3+
from lightning_app.components.multi_node import MultiNode
34
from lightning_app.components.python.popen import PopenPythonScript
45
from lightning_app.components.python.tracer import Code, TracerPythonScript
56
from lightning_app.components.serve.gradio import ServeGradio
@@ -16,6 +17,7 @@
1617
"ServeGradio",
1718
"ServeStreamlit",
1819
"ModelInferenceAPI",
20+
"MultiNode",
1921
"LightningTrainingComponent",
2022
"PyTorchLightningScriptRunner",
2123
]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Any, Type
2+
3+
from lightning_app import structures
4+
from lightning_app.core.flow import LightningFlow
5+
from lightning_app.core.work import LightningWork
6+
from lightning_app.utilities.enum import WorkStageStatus
7+
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
8+
9+
10+
class MultiNode(LightningFlow):
11+
def __init__(
12+
self,
13+
work_cls: Type["LightningWork"],
14+
nodes: int,
15+
cloud_compute: "CloudCompute",
16+
*work_args: Any,
17+
**work_kwargs: Any,
18+
) -> None:
19+
"""This component enables performing distributed multi-node multi-device training.
20+
21+
Example::
22+
23+
import torch
24+
25+
import lightning as L
26+
from lightning.components import MultiNode
27+
28+
class AnyDistributedComponent(L.LightningWork):
29+
def run(
30+
self,
31+
main_address: str,
32+
main_port: int,
33+
node_rank: int,
34+
):
35+
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
36+
37+
38+
compute = L.CloudCompute("gpu")
39+
app = L.LightningApp(
40+
MultiNode(
41+
AnyDistributedComponent,
42+
nodes=8,
43+
cloud_compute=compute,
44+
)
45+
)
46+
47+
Arguments:
48+
work_cls: The work to be executed
49+
nodes: Number of nodes.
50+
cloud_compute: The cloud compute object used in the cloud.
51+
work_args: Arguments to be provided to the work on instantiation.
52+
work_kwargs: Keywords arguments to be provided to the work on instantiation.
53+
"""
54+
super().__init__()
55+
self.ws = structures.List()
56+
self._work_cls = work_cls
57+
self.nodes = nodes
58+
self._cloud_compute = cloud_compute
59+
self._work_args = work_args
60+
self._work_kwargs = work_kwargs
61+
self.has_initialized = False
62+
63+
def run(self) -> None:
64+
# 1. Create & start the works
65+
if not self.has_initialized:
66+
for node_rank in range(self.nodes):
67+
self.ws.append(
68+
self._work_cls(
69+
*self._work_args,
70+
cloud_compute=self._cloud_compute,
71+
**self._work_kwargs,
72+
parallel=True,
73+
)
74+
)
75+
# Starting node `node_rank`` ...
76+
self.ws[-1].start()
77+
self.has_initialized = True
78+
79+
# 2. Wait for all machines to be started !
80+
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
81+
return
82+
83+
# Loop over all node machines
84+
for node_rank in range(self.nodes):
85+
86+
# 3. Run the user code in a distributed way !
87+
self.ws[node_rank].run(
88+
main_address=self.ws[0].internal_ip,
89+
main_port=self.ws[0].port,
90+
node_rank=node_rank,
91+
)
92+
93+
# 4. Stop the machine when finished.
94+
if self.ws[node_rank].has_succeeded:
95+
self.ws[node_rank].stop()

tests/tests_app_examples/test_multi_node.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
import pytest
34
from tests_app import _PROJECT_ROOT
45

56
from lightning_app.testing.testing import application_testing, LightningTestApp
@@ -8,11 +9,13 @@
89
class LightningTestMultiNodeApp(LightningTestApp):
910
def on_before_run_once(self):
1011
res = super().on_before_run_once()
11-
if all(w.has_finished for w in self.works):
12+
if self.works and all(w.has_stopped for w in self.works):
13+
assert len([w for w in self.works]) == 2
1214
return True
1315
return res
1416

1517

18+
@pytest.mark.skipif(True, reason="flaky")
1619
def test_multi_node_example():
1720
cwd = os.getcwd()
1821
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
@@ -27,3 +30,28 @@ def test_multi_node_example():
2730
result = application_testing(LightningTestMultiNodeApp, command_line)
2831
assert result.exit_code == 0
2932
os.chdir(cwd)
33+
34+
35+
class LightningTestMultiNodeWorksApp(LightningTestApp):
36+
def on_before_run_once(self):
37+
res = super().on_before_run_once()
38+
if self.works and all(w.has_stopped for w in self.works):
39+
assert len([w for w in self.works]) == 2
40+
return True
41+
return res
42+
43+
44+
def test_multi_node_example_2():
45+
cwd = os.getcwd()
46+
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
47+
os.chdir(new_cwd)
48+
command_line = [
49+
"app_work.py",
50+
"--blocking",
51+
"False",
52+
"--open-ui",
53+
"False",
54+
]
55+
result = application_testing(LightningTestMultiNodeWorksApp, command_line)
56+
assert result.exit_code == 0
57+
os.chdir(cwd)

0 commit comments

Comments
 (0)