Skip to content

Commit 046a86f

Browse files
committed
Change Default namespace logic to use user's current namespace
Signed-off-by: Anish Asthana <[email protected]>
1 parent 45aae2a commit 046a86f

File tree

3 files changed

+46
-18
lines changed

3 files changed

+46
-18
lines changed

src/codeflare_sdk/cluster/cluster.py

+8
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def create_app_wrapper(self):
6363
Called upon cluster object creation, creates an AppWrapper yaml based on
6464
the specifications of the ClusterConfiguration.
6565
"""
66+
67+
if self.config.namespace is None:
68+
self.config.namespace = get_current_namespace()
69+
if type(self.config.namespace) is not str:
70+
raise TypeError(
71+
f"Namespace {self.config.namespace} is of type {type(self.config.namespace)}. Check your Kubernetes Authentication."
72+
)
73+
6674
name = self.config.name
6775
namespace = self.config.namespace
6876
min_cpu = self.config.min_cpus

src/codeflare_sdk/cluster/config.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from dataclasses import dataclass, field
2222
from .auth import Authentication
2323
import pathlib
24+
import openshift
2425

2526
dir = pathlib.Path(__file__).parent.parent.resolve()
2627

@@ -33,7 +34,7 @@ class ClusterConfiguration:
3334
"""
3435

3536
name: str
36-
namespace: str = "default"
37+
namespace: str = None
3738
head_info: list = field(default_factory=list)
3839
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
3940
min_cpus: int = 1

tests/unit_test.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,23 @@ def test_cluster_creation():
240240
return cluster
241241

242242

243+
def test_default_cluster_creation(mocker):
244+
mocker.patch(
245+
"codeflare_sdk.cluster.cluster.get_current_namespace",
246+
return_value="opendatahub",
247+
)
248+
default_config = ClusterConfiguration(
249+
name="unit-test-default-cluster",
250+
)
251+
cluster = Cluster(default_config)
252+
253+
assert cluster.app_wrapper_yaml == "unit-test-default-cluster.yaml"
254+
assert cluster.app_wrapper_name == "unit-test-default-cluster"
255+
assert cluster.config.namespace == "opendatahub"
256+
257+
return cluster
258+
259+
243260
def arg_check_apply_effect(*args):
244261
assert args[0] == "apply"
245262
assert args[1] == ["-f", "unit-test-cluster.yaml"]
@@ -1593,22 +1610,6 @@ def test_wait_ready(mocker, capsys):
15931610
)
15941611

15951612

1596-
def test_cmd_line_generation():
1597-
os.system(
1598-
f"python3 {parent}/src/codeflare_sdk/utils/generate_yaml.py --name=unit-cmd-cluster --min-cpu=1 --max-cpu=1 --min-memory=2 --max-memory=2 --gpu=1 --workers=2 --template=src/codeflare_sdk/templates/new-template.yaml"
1599-
)
1600-
assert filecmp.cmp(
1601-
"unit-cmd-cluster.yaml", f"{parent}/tests/test-case-cmd.yaml", shallow=True
1602-
)
1603-
os.remove("unit-test-cluster.yaml")
1604-
os.remove("unit-cmd-cluster.yaml")
1605-
1606-
1607-
def test_cleanup():
1608-
os.remove("test.yaml")
1609-
os.remove("raytest2.yaml")
1610-
1611-
16121613
def test_jobdefinition_coverage():
16131614
abstract = JobDefinition()
16141615
cluster = Cluster(test_config_creation())
@@ -1673,7 +1674,6 @@ def test_DDPJobDefinition_dry_run():
16731674
assert type(ddp_job._scheduler) == type(str())
16741675

16751676
assert ddp_job.request.app_id.startswith("test")
1676-
assert ddp_job.request.working_dir.startswith("/tmp/torchx_workspace")
16771677
assert ddp_job.request.cluster_name == "unit-test-cluster"
16781678
assert ddp_job.request.requirements == "test"
16791679

@@ -1916,3 +1916,22 @@ def parse_j(cmd):
19161916
max_worker = args[1]
19171917
gpu = args[3]
19181918
return f"{max_worker}x{gpu}"
1919+
1920+
1921+
# Make sure to keep this function and the efollowing function at the end of the file
1922+
def test_cmd_line_generation():
1923+
os.system(
1924+
f"python3 {parent}/src/codeflare_sdk/utils/generate_yaml.py --name=unit-cmd-cluster --min-cpu=1 --max-cpu=1 --min-memory=2 --max-memory=2 --gpu=1 --workers=2 --template=src/codeflare_sdk/templates/new-template.yaml"
1925+
)
1926+
assert filecmp.cmp(
1927+
"unit-cmd-cluster.yaml", f"{parent}/tests/test-case-cmd.yaml", shallow=True
1928+
)
1929+
os.remove("unit-test-cluster.yaml")
1930+
os.remove("unit-test-default-cluster.yaml")
1931+
os.remove("unit-cmd-cluster.yaml")
1932+
1933+
1934+
# Make sure to always keep this function last
1935+
def test_cleanup():
1936+
os.remove("test.yaml")
1937+
os.remove("raytest2.yaml")

0 commit comments

Comments
 (0)