Skip to content

Commit e71205a

Browse files
committed
Updated tests and load_components
1 parent f004ef9 commit e71205a

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,11 +678,23 @@ def write_components(
678678
print(f"Written to: {output_file_name}")
679679

680680

681-
def load_components(user_yaml: dict, name: str):
681+
def load_components(
682+
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
683+
):
682684
component_list = []
683685
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
686+
lq_name = local_queue or get_default_kueue_name(namespace)
684687
for component in components:
685688
if "generictemplate" in component:
689+
if (
690+
"workload.codeflare.dev/appwrapper"
691+
in component["generictemplate"]["metadata"]["labels"]
692+
):
693+
del component["generictemplate"]["metadata"]["labels"][
694+
"workload.codeflare.dev/appwrapper"
695+
]
696+
labels = component["generictemplate"]["metadata"]["labels"]
697+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
686698
component_list.append(component["generictemplate"])
687699

688700
resources = "---\n" + "---\n".join(
@@ -790,11 +802,11 @@ def generate_appwrapper(
790802
if mcad:
791803
write_user_appwrapper(user_yaml, outfile)
792804
else:
793-
write_components(user_yaml, outfile, local_queue)
805+
write_components(user_yaml, outfile, namespace, local_queue)
794806
return outfile
795807
else:
796808
if mcad:
797809
user_yaml = load_appwrapper(user_yaml, name)
798810
else:
799-
user_yaml = load_components(user_yaml, name)
811+
user_yaml = load_components(user_yaml, name, namespace, local_queue)
800812
return user_yaml

tests/unit_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,6 @@ def get_local_queue(group, version, namespace, plural):
333333

334334

335335
def test_cluster_creation_no_mcad(mocker):
336-
# With written resources
337336
# Create Ray Cluster with no local queue specified
338337
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
339338
mocker.patch(
@@ -359,6 +358,7 @@ def test_cluster_creation_no_mcad(mocker):
359358

360359

361360
def test_cluster_creation_no_mcad_local_queue(mocker):
361+
# With written resources
362362
# Create Ray Cluster with local queue specified
363363
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
364364
mocker.patch(
@@ -395,6 +395,7 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
395395
image="quay.io/project-codeflare/ray:latest-py39-cu118",
396396
write_to_file=False,
397397
mcad=False,
398+
local_queue="local-queue-default",
398399
)
399400
cluster = Cluster(config)
400401
test_resources = []

0 commit comments

Comments
 (0)