Skip to content

Commit b918957

Browse files
committed
Add Ray Cluster Upgrade test for Ray Job Long Running scenarios
1 parent 9a8603d commit b918957

File tree

2 files changed

+428
-0
lines changed

2 files changed

+428
-0
lines changed

tests/e2e/mnist_sleep.py

+253
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import time
17+
import torch
18+
import requests
19+
from pytorch_lightning import LightningModule, Trainer
20+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
21+
from torch import nn
22+
from torch.nn import functional as F
23+
from torch.utils.data import DataLoader, random_split, RandomSampler
24+
from torchmetrics import Accuracy
25+
from torchvision import transforms
26+
from torchvision.datasets import MNIST
27+
import gzip
28+
import shutil
29+
from minio import Minio
30+
31+
32+
# Root directory for datasets; overridable through the environment.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

# Larger batches when a CUDA device is visible, smaller ones otherwise.
if torch.cuda.is_available():
    BATCH_SIZE = 256
else:
    BATCH_SIZE = 64

# Directory that contains this script; the model is pointed at it for data.
local_mnist_path = os.path.dirname(os.path.abspath(__file__))
# %%

# Dump the distributed-launch environment before the trainer is created.
print("prior to running the trainer")
for _env_name in ("MASTER_ADDR", "MASTER_PORT"):
    print(f"{_env_name}: is ", os.getenv(_env_name))

ACCELERATOR = os.getenv("ACCELERATOR")
print("ACCELERATOR: is ", ACCELERATOR)

# A storage bucket counts as configured as soon as the endpoint variable is set.
STORAGE_BUCKET_EXISTS = "AWS_DEFAULT_ENDPOINT" in os.environ
print("STORAGE_BUCKET_EXISTS: ", STORAGE_BUCKET_EXISTS)

# One line per bucket setting; an unset variable still prints an empty line.
for _label, _env_name in (
    ("Storage_Bucket_Default_Endpoint", "AWS_DEFAULT_ENDPOINT"),
    ("Storage_Bucket_Name", "AWS_STORAGE_BUCKET"),
    ("Storage_Bucket_Mnist_Directory", "AWS_STORAGE_BUCKET_MNIST_DIR"),
):
    if _env_name in os.environ:
        print(f"{_label} : is {os.environ.get(_env_name)}")
    else:
        print("")
63+
64+
65+
class LitMNIST(LightningModule):
    """Small fully-connected MNIST classifier with a long pre-training sleep.

    The module sleeps for 24 hours in ``on_train_start`` so that the job stays
    alive for the upgrade test scenario. The dataset is fetched either from a
    storage bucket (when ``AWS_DEFAULT_ENDPOINT`` is set and non-empty) or
    through torchvision's own MNIST download.

    Args:
        data_dir: Directory under which the ``MNIST`` dataset folder lives.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        """Return log-probabilities over the classes for input batch ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update validation accuracy and log ``val_loss`` / ``val_acc``."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Update test accuracy and log ``test_loss`` / ``test_acc``."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def on_train_start(self):
        """Block for 24 hours before training (upgrade test scenario)."""
        # Sleeping for 24 hours for upgrade test scenario
        print("Sleeping for 24 hours before starting training...")
        time.sleep(86400)
        print("Waking up from sleep...")

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Make the raw MNIST files available under ``self.data_dir``.

        When a storage bucket endpoint is configured, every object below the
        ``AWS_STORAGE_BUCKET_MNIST_DIR`` prefix is downloaded into
        ``<data_dir>/MNIST/raw``; ``.gz`` objects are decompressed next to the
        archive and the archive is deleted. Otherwise torchvision is asked to
        download the dataset itself.
        """
        # download
        print("Downloading MNIST dataset...")

        endpoint = os.environ.get("AWS_DEFAULT_ENDPOINT")
        if STORAGE_BUCKET_EXISTS and endpoint:
            print("Using storage bucket to download datasets...")

            dataset_dir = os.path.join(self.data_dir, "MNIST/raw")
            access_key = os.environ.get("AWS_ACCESS_KEY_ID")
            secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
            bucket_name = os.environ.get("AWS_STORAGE_BUCKET")

            # NOTE(review): TLS certificate verification is disabled here;
            # confirm this is only ever used against test endpoints.
            client = Minio(
                endpoint,
                access_key=access_key,
                secret_key=secret_key,
                cert_check=False,
            )

            if not os.path.exists(dataset_dir):
                os.makedirs(dataset_dir)
            else:
                print(f"Directory '{dataset_dir}' already exists")

            # To download datasets from storage bucket's specific directory,
            # use prefix to provide directory name. An unset variable means
            # "no prefix" rather than None, so len()/slicing below stay valid.
            prefix = os.environ.get("AWS_STORAGE_BUCKET_MNIST_DIR") or ""

            # Bound before the loop so an empty listing cannot leave it undefined.
            download_datasets = False

            # download all files from prefix folder of storage bucket recursively
            for item in client.list_objects(bucket_name, prefix=prefix, recursive=True):
                # Strip the prefix and the separator that follows it, whether
                # or not the configured prefix already ends with "/".
                file_name = item.object_name[len(prefix):].lstrip("/")
                if not file_name:
                    # The object is the prefix itself; nothing to store.
                    continue

                dataset_file_path = os.path.join(dataset_dir, file_name)
                if not os.path.exists(dataset_file_path):
                    client.fget_object(bucket_name, item.object_name, dataset_file_path)
                else:
                    print(f"File-path '{dataset_file_path}' already exists")

                # Only the final extension is removed; splitting on the first
                # "." would break on dotted directory names in the path.
                unzipped_filepath, extension = os.path.splitext(dataset_file_path)
                if extension != ".gz":
                    # Not a gzip archive: keep the downloaded file as it is.
                    continue

                # Unzip files
                with gzip.open(dataset_file_path, "rb") as f_in, open(
                    unzipped_filepath, "wb"
                ) as f_out:
                    shutil.copyfileobj(f_in, f_out)
                # delete zip file
                os.remove(dataset_file_path)
                if os.path.exists(unzipped_filepath):
                    print(
                        f"Unzipped and saved dataset file to path - {unzipped_filepath}"
                    )

        else:
            print("Using default MNIST mirror reference to download datasets...")
            download_datasets = True

        MNIST(self.data_dir, train=True, download=download_datasets)
        MNIST(self.data_dir, train=False, download=download_datasets)

    def setup(self, stage=None):
        """Create the train/val split and/or the test dataset for ``stage``."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(
                self.data_dir, train=True, transform=self.transform, download=False
            )
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform, download=False
            )

    def train_dataloader(self):
        """Return a loader that draws 1000 random training samples per epoch."""
        return DataLoader(
            self.mnist_train,
            batch_size=BATCH_SIZE,
            sampler=RandomSampler(self.mnist_train, num_samples=1000),
        )

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return the test loader."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
231+
232+
233+
# Build the model, pointing it at the directory that holds this script.
model = LitMNIST(data_dir=local_mnist_path)

# World sizes come from the launcher's environment; both default to 1.
group_world_size = int(os.environ.get("GROUP_WORLD_SIZE", 1))
print("GROUP: ", group_world_size)
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
print("LOCAL: ", local_world_size)

# Configure the trainer: DDP across `group_world_size` nodes with
# `local_world_size` devices each, keeping the module's own sampler.
trainer_options = dict(
    accelerator=ACCELERATOR,
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    num_nodes=group_world_size,
    devices=local_world_size,
    replace_sampler_ddp=False,
    strategy="ddp",
)
trainer = Trainer(**trainer_options)

# Train the model
trainer.fit(model)

0 commit comments

Comments
 (0)