Skip to content

Commit 97436b4

Browse files
committed
wrap api client to add defaults
Signed-off-by: Kevin <[email protected]>
1 parent 1235fc8 commit 97436b4

File tree

6 files changed

+69
-60
lines changed

6 files changed

+69
-60
lines changed

src/codeflare_sdk/cluster/auth.py

+35-24
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,7 @@ def __init__(
9393
self.token = token
9494
self.server = server
9595
self.skip_tls = skip_tls
96-
self.ca_cert_path = self._gen_ca_cert_path(ca_cert_path)
97-
98-
def _gen_ca_cert_path(self, ca_cert_path: str):
99-
if ca_cert_path is not None:
100-
return ca_cert_path
101-
elif "CF_SDK_CA_CERT_PATH" in os.environ:
102-
return os.environ.get("CF_SDK_CA_CERT_PATH")
103-
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
104-
return WORKBENCH_CA_CERT_PATH
105-
else:
106-
return None
96+
self.ca_cert_path = _gen_ca_cert_path(ca_cert_path)
10797

10898
def login(self) -> str:
10999
"""
@@ -119,25 +109,14 @@ def login(self) -> str:
119109
configuration.host = self.server
120110
configuration.api_key["authorization"] = self.token
121111

112+
api_client = client.ApiClient(configuration)
122113
if not self.skip_tls:
123-
if self.ca_cert_path is None:
124-
configuration.ssl_ca_cert = None
125-
elif os.path.isfile(self.ca_cert_path):
126-
print(
127-
f"Authenticated with certificate located at {self.ca_cert_path}"
128-
)
129-
configuration.ssl_ca_cert = self.ca_cert_path
130-
else:
131-
raise FileNotFoundError(
132-
f"Certificate file not found at {self.ca_cert_path}"
133-
)
134-
configuration.verify_ssl = True
114+
_client_with_cert(api_client, self.ca_cert_path)
135115
else:
136116
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
137117
print("Insecure request warnings have been disabled")
138118
configuration.verify_ssl = False
139119

140-
api_client = client.ApiClient(configuration)
141120
client.AuthenticationApi(api_client).get_api_group()
142121
config_path = None
143122
return "Logged into %s" % self.server
@@ -211,6 +190,38 @@ def config_check() -> str:
211190
return config_path
212191

213192

193+
def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
194+
client.configuration.verify_ssl = True
195+
cert_path = _gen_ca_cert_path(ca_cert_path)
196+
if cert_path is None:
197+
client.configuration.ssl_ca_cert = None
198+
elif os.path.isfile(cert_path):
199+
client.configuration.ssl_ca_cert = ca_cert_path
200+
else:
201+
raise FileNotFoundError(f"Certificate file not found at {cert_path}")
202+
203+
204+
def _gen_ca_cert_path(ca_cert_path: str):
205+
"""Gets the path to the default CA certificate file either through env config or default path"""
206+
if ca_cert_path is not None:
207+
return ca_cert_path
208+
elif "CF_SDK_CA_CERT_PATH" in os.environ:
209+
return os.environ.get("CF_SDK_CA_CERT_PATH")
210+
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
211+
return WORKBENCH_CA_CERT_PATH
212+
else:
213+
return None
214+
215+
216+
def get_api_client() -> client.ApiClient:
217+
"This function should load the api client with defaults"
218+
if api_client != None:
219+
return api_client
220+
to_return = client.ApiClient()
221+
_client_with_cert(to_return)
222+
return to_return
223+
224+
214225
def api_config_handler() -> Optional[client.ApiClient]:
215226
"""
216227
This function is used to load the api client if the user has logged in

src/codeflare_sdk/cluster/awload.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from kubernetes import client, config
2626
from ..utils.kube_api_helpers import _kube_api_error_handling
27-
from .auth import config_check, api_config_handler
27+
from .auth import config_check, get_api_client
2828

2929

3030
class AWManager:
@@ -59,7 +59,7 @@ def submit(self) -> None:
5959
"""
6060
try:
6161
config_check()
62-
api_instance = client.CustomObjectsApi(api_config_handler())
62+
api_instance = client.CustomObjectsApi(get_api_client())
6363
api_instance.create_namespaced_custom_object(
6464
group="workload.codeflare.dev",
6565
version="v1beta2",
@@ -84,7 +84,7 @@ def remove(self) -> None:
8484

8585
try:
8686
config_check()
87-
api_instance = client.CustomObjectsApi(api_config_handler())
87+
api_instance = client.CustomObjectsApi(get_api_client())
8888
api_instance.delete_namespaced_custom_object(
8989
group="workload.codeflare.dev",
9090
version="v1beta2",

src/codeflare_sdk/cluster/cluster.py

+18-21
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from kubernetes import config
2626
from ray.job_submission import JobSubmissionClient
2727

28-
from .auth import config_check, api_config_handler
28+
from .auth import config_check, get_api_client
2929
from ..utils import pretty_print
3030
from ..utils.generate_yaml import (
3131
generate_appwrapper,
@@ -80,7 +80,7 @@ def __init__(self, config: ClusterConfiguration):
8080

8181
@property
8282
def _client_headers(self):
83-
k8_client = api_config_handler() or client.ApiClient()
83+
k8_client = get_api_client()
8484
return {
8585
"Authorization": k8_client.configuration.get_api_key_with_prefix(
8686
"authorization"
@@ -95,7 +95,7 @@ def _client_verify_tls(self):
9595

9696
@property
9797
def job_client(self):
98-
k8client = api_config_handler() or client.ApiClient()
98+
k8client = get_api_client()
9999
if self._job_submission_client:
100100
return self._job_submission_client
101101
if is_openshift_cluster():
@@ -141,7 +141,7 @@ def up(self):
141141

142142
try:
143143
config_check()
144-
api_instance = client.CustomObjectsApi(api_config_handler())
144+
api_instance = client.CustomObjectsApi(get_api_client())
145145
if self.config.appwrapper:
146146
if self.config.write_to_file:
147147
with open(self.app_wrapper_yaml) as f:
@@ -172,7 +172,7 @@ def up(self):
172172
return _kube_api_error_handling(e)
173173

174174
def _throw_for_no_raycluster(self):
175-
api_instance = client.CustomObjectsApi(api_config_handler())
175+
api_instance = client.CustomObjectsApi(get_api_client())
176176
try:
177177
api_instance.list_namespaced_custom_object(
178178
group="ray.io",
@@ -199,7 +199,7 @@ def down(self):
199199
self._throw_for_no_raycluster()
200200
try:
201201
config_check()
202-
api_instance = client.CustomObjectsApi(api_config_handler())
202+
api_instance = client.CustomObjectsApi(get_api_client())
203203
if self.config.appwrapper:
204204
api_instance.delete_namespaced_custom_object(
205205
group="workload.codeflare.dev",
@@ -358,7 +358,7 @@ def cluster_dashboard_uri(self) -> str:
358358
config_check()
359359
if is_openshift_cluster():
360360
try:
361-
api_instance = client.CustomObjectsApi(api_config_handler())
361+
api_instance = client.CustomObjectsApi(get_api_client())
362362
routes = api_instance.list_namespaced_custom_object(
363363
group="route.openshift.io",
364364
version="v1",
@@ -380,7 +380,7 @@ def cluster_dashboard_uri(self) -> str:
380380
return f"{protocol}://{route['spec']['host']}"
381381
else:
382382
try:
383-
api_instance = client.NetworkingV1Api(api_config_handler())
383+
api_instance = client.NetworkingV1Api(get_api_client())
384384
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
385385
except Exception as e: # pragma no cover
386386
return _kube_api_error_handling(e)
@@ -579,9 +579,6 @@ def get_current_namespace(): # pragma: no cover
579579
return active_context
580580
except Exception as e:
581581
print("Unable to find current namespace")
582-
583-
if api_config_handler() != None:
584-
return None
585582
print("trying to gather from current context")
586583
try:
587584
_, active_context = config.list_kube_config_contexts(config_check())
@@ -601,7 +598,7 @@ def get_cluster(
601598
):
602599
try:
603600
config_check()
604-
api_instance = client.CustomObjectsApi(api_config_handler())
601+
api_instance = client.CustomObjectsApi(get_api_client())
605602
rcs = api_instance.list_namespaced_custom_object(
606603
group="ray.io",
607604
version="v1",
@@ -656,7 +653,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
656653
def _check_aw_exists(name: str, namespace: str) -> bool:
657654
try:
658655
config_check()
659-
api_instance = client.CustomObjectsApi(api_config_handler())
656+
api_instance = client.CustomObjectsApi(get_api_client())
660657
aws = api_instance.list_namespaced_custom_object(
661658
group="workload.codeflare.dev",
662659
version="v1beta2",
@@ -683,7 +680,7 @@ def _get_ingress_domain(self): # pragma: no cover
683680

684681
if is_openshift_cluster():
685682
try:
686-
api_instance = client.CustomObjectsApi(api_config_handler())
683+
api_instance = client.CustomObjectsApi(get_api_client())
687684

688685
routes = api_instance.list_namespaced_custom_object(
689686
group="route.openshift.io",
@@ -702,7 +699,7 @@ def _get_ingress_domain(self): # pragma: no cover
702699
domain = route["spec"]["host"]
703700
else:
704701
try:
705-
api_client = client.NetworkingV1Api(api_config_handler())
702+
api_client = client.NetworkingV1Api(get_api_client())
706703
ingresses = api_client.list_namespaced_ingress(namespace)
707704
except Exception as e: # pragma: no cover
708705
return _kube_api_error_handling(e)
@@ -716,7 +713,7 @@ def _get_ingress_domain(self): # pragma: no cover
716713
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
717714
try:
718715
config_check()
719-
api_instance = client.CustomObjectsApi(api_config_handler())
716+
api_instance = client.CustomObjectsApi(get_api_client())
720717
aws = api_instance.list_namespaced_custom_object(
721718
group="workload.codeflare.dev",
722719
version="v1beta2",
@@ -735,7 +732,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
735732
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
736733
try:
737734
config_check()
738-
api_instance = client.CustomObjectsApi(api_config_handler())
735+
api_instance = client.CustomObjectsApi(get_api_client())
739736
rcs = api_instance.list_namespaced_custom_object(
740737
group="ray.io",
741738
version="v1",
@@ -757,7 +754,7 @@ def _get_ray_clusters(
757754
list_of_clusters = []
758755
try:
759756
config_check()
760-
api_instance = client.CustomObjectsApi(api_config_handler())
757+
api_instance = client.CustomObjectsApi(get_api_client())
761758
rcs = api_instance.list_namespaced_custom_object(
762759
group="ray.io",
763760
version="v1",
@@ -786,7 +783,7 @@ def _get_app_wrappers(
786783

787784
try:
788785
config_check()
789-
api_instance = client.CustomObjectsApi(api_config_handler())
786+
api_instance = client.CustomObjectsApi(get_api_client())
790787
aws = api_instance.list_namespaced_custom_object(
791788
group="workload.codeflare.dev",
792789
version="v1beta2",
@@ -815,7 +812,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
815812
dashboard_url = None
816813
if is_openshift_cluster():
817814
try:
818-
api_instance = client.CustomObjectsApi(api_config_handler())
815+
api_instance = client.CustomObjectsApi(get_api_client())
819816
routes = api_instance.list_namespaced_custom_object(
820817
group="route.openshift.io",
821818
version="v1",
@@ -834,7 +831,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
834831
dashboard_url = f"{protocol}://{route['spec']['host']}"
835832
else:
836833
try:
837-
api_instance = client.NetworkingV1Api(api_config_handler())
834+
api_instance = client.NetworkingV1Api(get_api_client())
838835
ingresses = api_instance.list_namespaced_ingress(
839836
rc["metadata"]["namespace"]
840837
)

src/codeflare_sdk/utils/generate_cert.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from cryptography import x509
2020
from cryptography.x509.oid import NameOID
2121
import datetime
22-
from ..cluster.auth import config_check, api_config_handler
22+
from ..cluster.auth import config_check, get_api_client
2323
from kubernetes import client, config
2424
from .kube_api_helpers import _kube_api_error_handling
2525

@@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
103103
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
104104
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
105105
config_check()
106-
v1 = client.CoreV1Api(api_config_handler())
106+
v1 = client.CoreV1Api(get_api_client())
107107

108108
# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
109109
secret_name = get_secret_name(cluster_name, namespace, v1)

src/codeflare_sdk/utils/generate_yaml.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import uuid
2828
from kubernetes import client, config
2929
from .kube_api_helpers import _kube_api_error_handling
30-
from ..cluster.auth import api_config_handler, config_check
30+
from ..cluster.auth import get_api_client, config_check
3131
from os import urandom
3232
from base64 import b64encode
3333
from urllib3.util import parse_url
@@ -57,7 +57,7 @@ def gen_names(name):
5757
def is_openshift_cluster():
5858
try:
5959
config_check()
60-
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
60+
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
6161
for v in api.versions:
6262
if "route.openshift.io/v1" in v.group_version:
6363
return True
@@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
235235
# If the local queue is set, use it. Otherwise, try to use the default queue.
236236
try:
237237
config_check()
238-
api_instance = client.CustomObjectsApi(api_config_handler())
238+
api_instance = client.CustomObjectsApi(get_api_client())
239239
local_queues = api_instance.list_namespaced_custom_object(
240240
group="kueue.x-k8s.io",
241241
version="v1beta1",
@@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
261261
# get all local queues in the namespace
262262
try:
263263
config_check()
264-
api_instance = client.CustomObjectsApi(api_config_handler())
264+
api_instance = client.CustomObjectsApi(get_api_client())
265265
local_queues = api_instance.list_namespaced_custom_object(
266266
group="kueue.x-k8s.io",
267267
version="v1beta1",

0 commit comments

Comments
 (0)