diff --git a/docs/cluster-configuration.md b/docs/cluster-configuration.md
index 7684db2ca..b83600fe7 100644
--- a/docs/cluster-configuration.md
+++ b/docs/cluster-configuration.md
@@ -3,7 +3,7 @@
To create Ray Clusters using the CodeFlare SDK a cluster configuration needs to be created first.
This is what a typical cluster configuration would look like; Note: The values for CPU and Memory are at the minimum requirements for creating the Ray Cluster.
-```
+```python
from codeflare_sdk import Cluster, ClusterConfiguration
cluster = Cluster(ClusterConfiguration(
@@ -20,8 +20,8 @@ cluster = Cluster(ClusterConfiguration(
num_gpus=0, # Default 0
mcad=True, # Default True
image="quay.io/project-codeflare/ray:latest-py39-cu118", # Mandatory Field
- instascale=False, # Default False
machine_types=["m5.xlarge", "g4dn.xlarge"],
+ labels={"exampleLabel": "example", "secondLabel": "example"},
))
```
@@ -30,3 +30,5 @@ From there a user can call `cluster.up()` and `cluster.down()` to create and rem
In cases where `mcad=False` a yaml file will be created with the individual Ray Cluster, Route/Ingress and Secret included.
The Ray Cluster and service will be created by KubeRay directly and the other components will be individually created.
+
+The `labels={"exampleLabel": "example"}` parameter can be used to apply additional labels to the RayCluster resource.
diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py
index 76f64287d..12a90f4af 100644
--- a/src/codeflare_sdk/cluster/cluster.py
+++ b/src/codeflare_sdk/cluster/cluster.py
@@ -187,6 +187,7 @@ def create_app_wrapper(self):
write_to_file = self.config.write_to_file
verify_tls = self.config.verify_tls
local_queue = self.config.local_queue
+ labels = self.config.labels
return generate_appwrapper(
name=name,
namespace=namespace,
@@ -211,6 +212,7 @@ def create_app_wrapper(self):
write_to_file=write_to_file,
verify_tls=verify_tls,
local_queue=local_queue,
+ labels=labels,
)
# creates a new cluster with the provided or default spec
diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py
index 064b51cd4..f8010ea92 100644
--- a/src/codeflare_sdk/cluster/config.py
+++ b/src/codeflare_sdk/cluster/config.py
@@ -54,6 +54,7 @@ class ClusterConfiguration:
dispatch_priority: str = None
write_to_file: bool = False
verify_tls: bool = True
+ labels: dict = field(default_factory=dict)
def __post_init__(self):
if not self.verify_tls:
diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py
index 97dda5ba7..f5de1fbae 100755
--- a/src/codeflare_sdk/utils/generate_yaml.py
+++ b/src/codeflare_sdk/utils/generate_yaml.py
@@ -309,7 +309,11 @@ def get_default_kueue_name(namespace: str):
def write_components(
- user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
+ user_yaml: dict,
+ output_file_name: str,
+ namespace: str,
+ local_queue: Optional[str],
+ labels: dict,
):
# Create the directory if it doesn't exist
directory_path = os.path.dirname(output_file_name)
@@ -319,6 +323,7 @@ def write_components(
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
open(output_file_name, "w").close()
lq_name = local_queue or get_default_kueue_name(namespace)
+ cluster_labels = labels
with open(output_file_name, "a") as outfile:
for component in components:
if "generictemplate" in component:
@@ -331,6 +336,7 @@ def write_components(
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
+ labels.update(cluster_labels)
outfile.write("---\n")
yaml.dump(
component["generictemplate"], outfile, default_flow_style=False
@@ -339,11 +345,16 @@ def write_components(
def load_components(
- user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
+ user_yaml: dict,
+ name: str,
+ namespace: str,
+ local_queue: Optional[str],
+ labels: dict,
):
component_list = []
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
lq_name = local_queue or get_default_kueue_name(namespace)
+ cluster_labels = labels
for component in components:
if "generictemplate" in component:
if (
@@ -355,6 +366,7 @@ def load_components(
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
+ labels.update(cluster_labels)
component_list.append(component["generictemplate"])
resources = "---\n" + "---\n".join(
@@ -395,6 +407,7 @@ def generate_appwrapper(
write_to_file: bool,
verify_tls: bool,
local_queue: Optional[str],
+ labels,
):
user_yaml = read_template(template)
appwrapper_name, cluster_name = gen_names(name)
@@ -446,11 +459,11 @@ def generate_appwrapper(
if mcad:
write_user_appwrapper(user_yaml, outfile)
else:
- write_components(user_yaml, outfile, namespace, local_queue)
+ write_components(user_yaml, outfile, namespace, local_queue, labels)
return outfile
else:
if mcad:
user_yaml = load_appwrapper(user_yaml, name)
else:
- user_yaml = load_components(user_yaml, name, namespace, local_queue)
+ user_yaml = load_components(user_yaml, name, namespace, local_queue, labels)
return user_yaml
diff --git a/tests/test-case-no-mcad.yamls b/tests/test-case-no-mcad.yamls
index aaf9324e6..7fcf1fdc4 100644
--- a/tests/test-case-no-mcad.yamls
+++ b/tests/test-case-no-mcad.yamls
@@ -5,6 +5,8 @@ metadata:
labels:
controller-tools.k8s.io: '1.0'
kueue.x-k8s.io/queue-name: local-queue-default
+ testlabel: test
+ testlabel2: test
name: unit-test-cluster-ray
namespace: ns
spec:
diff --git a/tests/unit_test.py b/tests/unit_test.py
index 6f2ccee1c..53c888889 100644
--- a/tests/unit_test.py
+++ b/tests/unit_test.py
@@ -324,6 +324,7 @@ def test_cluster_creation_no_mcad(mocker):
config.name = "unit-test-cluster-ray"
config.write_to_file = True
config.mcad = False
+ config.labels = {"testlabel": "test", "testlabel2": "test"}
cluster = Cluster(config)
assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-cluster-ray.yaml"
@@ -348,6 +349,7 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
config.mcad = False
config.write_to_file = True
config.local_queue = "local-queue-default"
+ config.labels = {"testlabel": "test", "testlabel2": "test"}
cluster = Cluster(config)
assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-cluster-ray.yaml"
assert cluster.app_wrapper_name == "unit-test-cluster-ray"
@@ -373,6 +375,7 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
write_to_file=True,
mcad=False,
local_queue="local-queue-default",
+ labels={"testlabel": "test", "testlabel2": "test"},
)
cluster = Cluster(config)
assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-cluster-ray.yaml"