@@ -25,9 +25,18 @@ def test_mnist_ray_cluster_sdk_kind(self):
25
25
self .setup_method ()
26
26
create_namespace (self )
27
27
create_kueue_resources (self )
28
- self .run_mnist_raycluster_sdk_kind ()
28
+ self .run_mnist_raycluster_sdk_kind (accelerator = "cpu" )
29
29
30
- def run_mnist_raycluster_sdk_kind (self ):
30
# GPU variant of the MNIST RayCluster SDK test; it carries the pytest "nvidia_gpu" marker.
# NOTE(review): presumably the marker is used to select/deselect this test on GPU-capable clusters - confirm against the pytest marker configuration.
+ @pytest .mark .nvidia_gpu
31
+ def test_mnist_ray_cluster_sdk_kind_nvidia_gpu (self ):
32
# Same preparation sequence as the CPU test: setup_method, then namespace, then Kueue resources.
+ self .setup_method ()
33
+ create_namespace (self )
34
+ create_kueue_resources (self )
35
# Runs the shared flow with accelerator="gpu" and one GPU; gpu_resource_name is not passed, so its default "nvidia.com/gpu" applies.
+ self .run_mnist_raycluster_sdk_kind (accelerator = "gpu" , number_of_gpus = 1 )
36
+
37
+ def run_mnist_raycluster_sdk_kind (
38
+ self , accelerator , gpu_resource_name = "nvidia.com/gpu" , number_of_gpus = 0
39
+ ):
31
40
ray_image = get_ray_image ()
32
41
33
42
cluster = Cluster (
@@ -40,7 +49,8 @@ def run_mnist_raycluster_sdk_kind(self):
40
49
worker_cpu_requests = "500m" ,
41
50
worker_cpu_limits = 1 ,
42
51
worker_memory_requests = 1 ,
43
- worker_memory_limits = 2 ,
52
+ worker_memory_limits = 4 ,
53
+ worker_extended_resource_requests = {gpu_resource_name : number_of_gpus },
44
54
image = ray_image ,
45
55
write_to_file = True ,
46
56
verify_tls = False ,
@@ -57,11 +67,11 @@ def run_mnist_raycluster_sdk_kind(self):
57
67
58
68
cluster .details ()
59
69
60
- self .assert_jobsubmit_withoutlogin_kind (cluster )
70
+ self .assert_jobsubmit_withoutlogin_kind (cluster , accelerator , number_of_gpus )
61
71
62
72
# Assertions
63
73
64
- def assert_jobsubmit_withoutlogin_kind (self , cluster ):
74
+ def assert_jobsubmit_withoutlogin_kind (self , cluster , accelerator , number_of_gpus ):
65
75
ray_dashboard = cluster .cluster_dashboard_uri ()
66
76
client = RayJobClient (address = ray_dashboard , verify = False )
67
77
@@ -70,7 +80,9 @@ def assert_jobsubmit_withoutlogin_kind(self, cluster):
70
80
runtime_env = {
71
81
"working_dir" : "./tests/e2e/" ,
72
82
"pip" : "./tests/e2e/mnist_pip_requirements.txt" ,
83
+ "env_vars" : {"ACCELERATOR" : accelerator },
73
84
},
85
+ entrypoint_num_gpus = number_of_gpus ,
74
86
)
75
87
print (f"Submitted job with ID: { submission_id } " )
76
88
done = False
0 commit comments