Skip to content

Commit fe75256

Browse files
committed
wrap api client to add defaults
Signed-off-by: Kevin <[email protected]>
1 parent fb59ba6 commit fe75256

File tree

5 files changed

+60
-41
lines changed

5 files changed

+60
-41
lines changed

src/codeflare_sdk/cluster/auth.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -119,25 +119,14 @@ def login(self) -> str:
119119
configuration.host = self.server
120120
configuration.api_key["authorization"] = self.token
121121

122+
api_client = client.ApiClient(configuration)
122123
if not self.skip_tls:
123-
if self.ca_cert_path is None:
124-
configuration.ssl_ca_cert = None
125-
elif os.path.isfile(self.ca_cert_path):
126-
print(
127-
f"Authenticated with certificate located at {self.ca_cert_path}"
128-
)
129-
configuration.ssl_ca_cert = self.ca_cert_path
130-
else:
131-
raise FileNotFoundError(
132-
f"Certificate file not found at {self.ca_cert_path}"
133-
)
134-
configuration.verify_ssl = True
124+
_client_with_cert(api_client, self.ca_cert_path)
135125
else:
136126
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
137127
print("Insecure request warnings have been disabled")
138128
configuration.verify_ssl = False
139129

140-
api_client = client.ApiClient(configuration)
141130
client.AuthenticationApi(api_client).get_api_group()
142131
config_path = None
143132
return "Logged into %s" % self.server
@@ -211,6 +200,36 @@ def config_check() -> str:
211200
return config_path
212201

213202

203+
def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
204+
cert_path = _gen_ca_cert_path(ca_cert_path)
205+
if os.path.isfile(cert_path):
206+
print(f"Authenticated with certificate located at {ca_cert_path}")
207+
client.configuration.ssl_ca_cert = ca_cert_path
208+
else:
209+
raise FileNotFoundError(f"Certificate file not found at {ca_cert_path}")
210+
211+
212+
def _gen_ca_cert_path(ca_cert_path: str):
213+
"""Gets the path to the default CA certificate file either through env config or default path"""
214+
if ca_cert_path is not None:
215+
return ca_cert_path
216+
elif "CF_SDK_CA_CERT_PATH" in os.environ:
217+
return os.environ.get("CF_SDK_CA_CERT_PATH")
218+
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
219+
return WORKBENCH_CA_CERT_PATH
220+
else:
221+
return None
222+
223+
224+
def get_api_client() -> client.ApiClient:
225+
"This function should load the api client with defaults"
226+
if api_client != None:
227+
return api_client
228+
to_return = client.ApiClient()
229+
_client_with_cert(to_return)
230+
return to_return
231+
232+
214233
def api_config_handler() -> Optional[client.ApiClient]:
215234
"""
216235
This function is used to load the api client if the user has logged in

src/codeflare_sdk/cluster/awload.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from kubernetes import client, config
2626
from ..utils.kube_api_helpers import _kube_api_error_handling
27-
from .auth import config_check, api_config_handler
27+
from .auth import config_check, get_api_client
2828

2929

3030
class AWManager:
@@ -59,7 +59,7 @@ def submit(self) -> None:
5959
"""
6060
try:
6161
config_check()
62-
api_instance = client.CustomObjectsApi(api_config_handler())
62+
api_instance = client.CustomObjectsApi(get_api_client())
6363
api_instance.create_namespaced_custom_object(
6464
group="workload.codeflare.dev",
6565
version="v1beta2",
@@ -84,7 +84,7 @@ def remove(self) -> None:
8484

8585
try:
8686
config_check()
87-
api_instance = client.CustomObjectsApi(api_config_handler())
87+
api_instance = client.CustomObjectsApi(get_api_client())
8888
api_instance.delete_namespaced_custom_object(
8989
group="workload.codeflare.dev",
9090
version="v1beta2",

src/codeflare_sdk/cluster/cluster.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from kubernetes import config
2626
from ray.job_submission import JobSubmissionClient
2727

28-
from .auth import config_check, api_config_handler
28+
from .auth import config_check, get_api_client
2929
from ..utils import pretty_print
3030
from ..utils.generate_yaml import (
3131
generate_appwrapper,
@@ -74,7 +74,7 @@ def __init__(self, config: ClusterConfiguration):
7474

7575
@property
7676
def _client_headers(self):
77-
k8_client = api_config_handler() or client.ApiClient()
77+
k8_client = get_api_client()
7878
return {
7979
"Authorization": k8_client.configuration.get_api_key_with_prefix(
8080
"authorization"
@@ -89,7 +89,7 @@ def _client_verify_tls(self):
8989

9090
@property
9191
def job_client(self):
92-
k8client = api_config_handler() or client.ApiClient()
92+
k8client = get_api_client()
9393
if self._job_submission_client:
9494
return self._job_submission_client
9595
if is_openshift_cluster():
@@ -135,7 +135,7 @@ def up(self):
135135

136136
try:
137137
config_check()
138-
api_instance = client.CustomObjectsApi(api_config_handler())
138+
api_instance = client.CustomObjectsApi(get_api_client())
139139
if self.config.appwrapper:
140140
if self.config.write_to_file:
141141
with open(self.app_wrapper_yaml) as f:
@@ -162,7 +162,7 @@ def up(self):
162162
return _kube_api_error_handling(e)
163163

164164
def _throw_for_no_raycluster(self):
165-
api_instance = client.CustomObjectsApi(api_config_handler())
165+
api_instance = client.CustomObjectsApi(get_api_client())
166166
try:
167167
api_instance.list_namespaced_custom_object(
168168
group="ray.io",
@@ -189,7 +189,7 @@ def down(self):
189189
self._throw_for_no_raycluster()
190190
try:
191191
config_check()
192-
api_instance = client.CustomObjectsApi(api_config_handler())
192+
api_instance = client.CustomObjectsApi(get_api_client())
193193
if self.config.appwrapper:
194194
api_instance.delete_namespaced_custom_object(
195195
group="workload.codeflare.dev",
@@ -344,7 +344,7 @@ def cluster_dashboard_uri(self) -> str:
344344
config_check()
345345
if is_openshift_cluster():
346346
try:
347-
api_instance = client.CustomObjectsApi(api_config_handler())
347+
api_instance = client.CustomObjectsApi(get_api_client())
348348
routes = api_instance.list_namespaced_custom_object(
349349
group="route.openshift.io",
350350
version="v1",
@@ -366,7 +366,7 @@ def cluster_dashboard_uri(self) -> str:
366366
return f"{protocol}://{route['spec']['host']}"
367367
else:
368368
try:
369-
api_instance = client.NetworkingV1Api(api_config_handler())
369+
api_instance = client.NetworkingV1Api(get_api_client())
370370
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
371371
except Exception as e: # pragma no cover
372372
return _kube_api_error_handling(e)
@@ -546,7 +546,7 @@ def list_all_queued(
546546

547547

548548
def get_current_namespace(): # pragma: no cover
549-
if api_config_handler() != None:
549+
if get_api_client() != None:
550550
if os.path.isfile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"):
551551
try:
552552
file = open(
@@ -591,7 +591,7 @@ def get_cluster(
591591
):
592592
try:
593593
config_check()
594-
api_instance = client.CustomObjectsApi(api_config_handler())
594+
api_instance = client.CustomObjectsApi(get_api_client())
595595
rcs = api_instance.list_namespaced_custom_object(
596596
group="ray.io",
597597
version="v1",
@@ -646,7 +646,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
646646
def _check_aw_exists(name: str, namespace: str) -> bool:
647647
try:
648648
config_check()
649-
api_instance = client.CustomObjectsApi(api_config_handler())
649+
api_instance = client.CustomObjectsApi(get_api_client())
650650
aws = api_instance.list_namespaced_custom_object(
651651
group="workload.codeflare.dev",
652652
version="v1beta2",
@@ -673,7 +673,7 @@ def _get_ingress_domain(self): # pragma: no cover
673673

674674
if is_openshift_cluster():
675675
try:
676-
api_instance = client.CustomObjectsApi(api_config_handler())
676+
api_instance = client.CustomObjectsApi(get_api_client())
677677

678678
routes = api_instance.list_namespaced_custom_object(
679679
group="route.openshift.io",
@@ -692,7 +692,7 @@ def _get_ingress_domain(self): # pragma: no cover
692692
domain = route["spec"]["host"]
693693
else:
694694
try:
695-
api_client = client.NetworkingV1Api(api_config_handler())
695+
api_client = client.NetworkingV1Api(get_api_client())
696696
ingresses = api_client.list_namespaced_ingress(namespace)
697697
except Exception as e: # pragma: no cover
698698
return _kube_api_error_handling(e)
@@ -706,7 +706,7 @@ def _get_ingress_domain(self): # pragma: no cover
706706
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
707707
try:
708708
config_check()
709-
api_instance = client.CustomObjectsApi(api_config_handler())
709+
api_instance = client.CustomObjectsApi(get_api_client())
710710
aws = api_instance.list_namespaced_custom_object(
711711
group="workload.codeflare.dev",
712712
version="v1beta2",
@@ -725,7 +725,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
725725
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
726726
try:
727727
config_check()
728-
api_instance = client.CustomObjectsApi(api_config_handler())
728+
api_instance = client.CustomObjectsApi(get_api_client())
729729
rcs = api_instance.list_namespaced_custom_object(
730730
group="ray.io",
731731
version="v1",
@@ -747,7 +747,7 @@ def _get_ray_clusters(
747747
list_of_clusters = []
748748
try:
749749
config_check()
750-
api_instance = client.CustomObjectsApi(api_config_handler())
750+
api_instance = client.CustomObjectsApi(get_api_client())
751751
rcs = api_instance.list_namespaced_custom_object(
752752
group="ray.io",
753753
version="v1",
@@ -776,7 +776,7 @@ def _get_app_wrappers(
776776

777777
try:
778778
config_check()
779-
api_instance = client.CustomObjectsApi(api_config_handler())
779+
api_instance = client.CustomObjectsApi(get_api_client())
780780
aws = api_instance.list_namespaced_custom_object(
781781
group="workload.codeflare.dev",
782782
version="v1beta2",
@@ -805,7 +805,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
805805
dashboard_url = None
806806
if is_openshift_cluster():
807807
try:
808-
api_instance = client.CustomObjectsApi(api_config_handler())
808+
api_instance = client.CustomObjectsApi(get_api_client())
809809
routes = api_instance.list_namespaced_custom_object(
810810
group="route.openshift.io",
811811
version="v1",
@@ -824,7 +824,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
824824
dashboard_url = f"{protocol}://{route['spec']['host']}"
825825
else:
826826
try:
827-
api_instance = client.NetworkingV1Api(api_config_handler())
827+
api_instance = client.NetworkingV1Api(get_api_client())
828828
ingresses = api_instance.list_namespaced_ingress(
829829
rc["metadata"]["namespace"]
830830
)

src/codeflare_sdk/utils/generate_cert.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from cryptography import x509
2020
from cryptography.x509.oid import NameOID
2121
import datetime
22-
from ..cluster.auth import config_check, api_config_handler
22+
from ..cluster.auth import config_check, get_api_client
2323
from kubernetes import client, config
2424
from .kube_api_helpers import _kube_api_error_handling
2525

@@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
103103
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
104104
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
105105
config_check()
106-
v1 = client.CoreV1Api(api_config_handler())
106+
v1 = client.CoreV1Api(get_api_client())
107107

108108
# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
109109
secret_name = get_secret_name(cluster_name, namespace, v1)

src/codeflare_sdk/utils/generate_yaml.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import uuid
2828
from kubernetes import client, config
2929
from .kube_api_helpers import _kube_api_error_handling
30-
from ..cluster.auth import api_config_handler, config_check
30+
from ..cluster.auth import get_api_client, config_check
3131
from os import urandom
3232
from base64 import b64encode
3333
from urllib3.util import parse_url
@@ -57,7 +57,7 @@ def gen_names(name):
5757
def is_openshift_cluster():
5858
try:
5959
config_check()
60-
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
60+
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
6161
for v in api.versions:
6262
if "route.openshift.io/v1" in v.group_version:
6363
return True
@@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
235235
# If the local queue is set, use it. Otherwise, try to use the default queue.
236236
try:
237237
config_check()
238-
api_instance = client.CustomObjectsApi(api_config_handler())
238+
api_instance = client.CustomObjectsApi(get_api_client())
239239
local_queues = api_instance.list_namespaced_custom_object(
240240
group="kueue.x-k8s.io",
241241
version="v1beta1",
@@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
261261
# get all local queues in the namespace
262262
try:
263263
config_check()
264-
api_instance = client.CustomObjectsApi(api_config_handler())
264+
api_instance = client.CustomObjectsApi(get_api_client())
265265
local_queues = api_instance.list_namespaced_custom_object(
266266
group="kueue.x-k8s.io",
267267
version="v1beta1",

0 commit comments

Comments
 (0)