Skip to content

Commit 11196b1

Browse files
authored
Fix default CloudCompute for flows (#15371)
* Fix default CloudCompute for flows * Unit test added
1 parent 3fb98ad commit 11196b1

File tree

3 files changed

+42
-18
lines changed

3 files changed

+42
-18
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3838
- Fixed a bug where the upload files endpoint would raise an error when running locally ([#14924](https://github.com/Lightning-AI/lightning/pull/14924))
3939
- Fixed BYOC cluster region selector -> hiding it from help since only us-east-1 has been tested and is recommended ([#15277]https://github.com/Lightning-AI/lightning/pull/15277)
4040
- Fixed a bug when launching an app on multiple clusters ([#15226](https://github.com/Lightning-AI/lightning/pull/15226))
41+
- Fixed a bug with a default CloudCompute for Lightning flows ([#15371](https://github.com/Lightning-AI/lightning/pull/15371))
4142

4243
## [0.6.2] - 2022-09-21
4344

src/lightning_app/core/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103

104104
_validate_root_flow(root)
105105
self._root = root
106-
self.flow_cloud_compute = flow_cloud_compute or lightning_app.CloudCompute()
106+
self.flow_cloud_compute = flow_cloud_compute or lightning_app.CloudCompute(name="flow-lite")
107107

108108
# queues definition.
109109
self.delta_queue: Optional[BaseQueue] = None

tests/tests_app/runners/test_cloud.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ def run(self):
6868
pass
6969

7070

71+
def get_cloud_runtime_request_body(**kwargs) -> "Body8":
72+
default_request_body = dict(
73+
app_entrypoint_file=mock.ANY,
74+
enable_app_server=True,
75+
flow_servers=[],
76+
image_spec=None,
77+
works=[],
78+
local_source=True,
79+
dependency_cache_key=mock.ANY,
80+
user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig(
81+
name="flow-lite",
82+
preemptible=False,
83+
shm_size=0,
84+
),
85+
)
86+
87+
if kwargs.get("user_requested_flow_compute_config") is not None:
88+
default_request_body["user_requested_flow_compute_config"] = kwargs["user_requested_flow_compute_config"]
89+
90+
return Body8(**default_request_body)
91+
92+
7193
class TestAppCreationClient:
7294
"""Testing the calls made using GridRestClient to create the app."""
7395

@@ -138,8 +160,9 @@ def test_new_instance_on_different_cluster(self, monkeypatch):
138160
body=V1ProjectClusterBinding(cluster_id=new_cluster, project_id="default-project-id"),
139161
)
140162

163+
@pytest.mark.parametrize("flow_cloud_compute", [None, CloudCompute(name="t2.medium")])
141164
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
142-
def test_run_with_custom_flow_compute_config(self, monkeypatch):
165+
def test_run_with_default_flow_compute_config(self, monkeypatch, flow_cloud_compute):
143166
mock_client = mock.MagicMock()
144167
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
145168
memberships=[V1Membership(name="test-project", project_id="test-project-id")]
@@ -155,30 +178,30 @@ def test_run_with_custom_flow_compute_config(self, monkeypatch):
155178
cloud_backend.client = mock_client
156179
monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
157180
monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
158-
app = mock.MagicMock()
159-
app.flows = []
160-
app.frontend = {}
161-
app.flow_cloud_compute = CloudCompute(name="t2.medium")
181+
182+
dummy_flow = mock.MagicMock()
183+
monkeypatch.setattr(dummy_flow, "run", lambda *args, **kwargs: None)
184+
if flow_cloud_compute is None:
185+
app = LightningApp(dummy_flow)
186+
else:
187+
app = LightningApp(dummy_flow, flow_cloud_compute=flow_cloud_compute)
188+
162189
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file="entrypoint.py")
163190
cloud_runtime._check_uploaded_folder = mock.MagicMock()
164191

165192
monkeypatch.setattr(Path, "is_file", lambda *args, **kwargs: False)
166193
monkeypatch.setattr(cloud, "Path", Path)
167194
cloud_runtime.dispatch()
168-
body = Body8(
169-
app_entrypoint_file=mock.ANY,
170-
enable_app_server=True,
171-
flow_servers=[],
172-
image_spec=None,
173-
works=[],
174-
local_source=True,
175-
dependency_cache_key=mock.ANY,
176-
user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig(
177-
name="t2.medium",
195+
196+
user_requested_flow_compute_config = None
197+
if flow_cloud_compute is not None:
198+
user_requested_flow_compute_config = V1UserRequestedFlowComputeConfig(
199+
name=flow_cloud_compute.name,
178200
preemptible=False,
179201
shm_size=0,
180-
),
181-
)
202+
)
203+
204+
body = get_cloud_runtime_request_body(user_requested_flow_compute_config=user_requested_flow_compute_config)
182205
cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
183206
project_id="test-project-id", app_id=mock.ANY, body=body
184207
)

0 commit comments

Comments
 (0)