Commit eace71e

FLI-based Worker Manager (#622)
This PR adds a simple `TorchWorker` which performs inference. The output transform is not yet implemented, but it is not needed for the time being. [ committed by @al-rigazzi ] [ reviewed by @AlyssaCote @ankona ]
1 parent 52abd32 commit eace71e
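
The new `torch_worker.py` itself is not part of the excerpt below, so as rough orientation, here is a minimal sketch of what a Torch-based worker's inference path might look like. The class name, method names, and signatures are illustrative assumptions, not the actual SmartSim interface:

# Illustrative sketch only: NOT the actual smartsim TorchWorker.
# Names and signatures here are assumptions for exposition.
import io

import torch


class SketchTorchWorker:
    """Load a TorchScript model from raw bytes and run batched inference."""

    @staticmethod
    def load_model(model_bytes: bytes, device: str) -> torch.nn.Module:
        # The mock app ships the model as serialized TorchScript bytes
        model = torch.jit.load(io.BytesIO(model_bytes), map_location=device)
        model.eval()
        return model

    @staticmethod
    def execute(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
        # Inference only, so gradients are disabled
        with torch.no_grad():
            return model(batch)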

File tree: 22 files changed, +1103 −108 lines
doc/changelog.md (1 addition & 0 deletions)

@@ -13,6 +13,7 @@ Jump to:
 
 Description
 
+- Add TorchWorker first implementation and mock inference app example
 - Add EnvironmentConfigLoader for ML Worker Manager
 - Add Model schema with model metadata included
 - Removed device from schemas, MessageHandler and tests
New file (50 additions & 0 deletions)
import os
import base64
import cloudpickle
import sys
from smartsim import Experiment
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim.status import TERMINAL_STATUSES
import time
import typing as t

device = "gpu"
filedir = os.path.dirname(__file__)
worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py")
app_script_name = os.path.join(filedir, "mock_app.py")
model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")

transport: t.Literal["hsta", "tcp"] = "hsta"

os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport

exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}")
os.makedirs(exp_path, exist_ok=True)
exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path)

# Serialize the worker class so the worker manager can deserialize and use it
torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")

worker_manager_rs = exp.create_run_settings(
    sys.executable,
    [worker_manager_script_name, "--device", device, "--worker_class", torch_worker_str],
)
worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])

app_rs = exp.create_run_settings(sys.executable, exe_args=[app_script_name, "--device", device])
app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])

exp.generate(worker_manager, app, overwrite=True)
exp.start(worker_manager, app, block=False)

# When either process reaches a terminal state, stop the other and exit
while True:
    if exp.get_status(app)[0] in TERMINAL_STATUSES:
        exp.stop(worker_manager)
        break
    if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
        exp.stop(app)
        break
    time.sleep(5)

print("Exiting.")
New file (195 additions & 0 deletions)
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# isort: off
import dragon
from dragon import fli
from dragon.channels import Channel
import dragon.channels
from dragon.data.ddict.ddict import DDict
from dragon.globalservices.api_setup import connect_to_infrastructure
from dragon.utils import b64decode, b64encode

# isort: on

import argparse
import io
import numbers
import os
import time

import numpy
import torch

from collections import OrderedDict
from smartsim._core.mli.message_handler import MessageHandler
from smartsim.log import get_logger

logger = get_logger("App")


class ProtoClient:
    def __init__(self, timing_on: bool):
        connect_to_infrastructure()
        ddict_str = os.environ["SS_DRG_DDICT"]
        self._ddict = DDict.attach(ddict_str)
        # Poll the DDict until the worker manager has published its FLI descriptor
        to_worker_fli_str = None
        while to_worker_fli_str is None:
            try:
                to_worker_fli_str = self._ddict["to_worker_fli"]
                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
            except KeyError:
                time.sleep(1)
        self._from_worker_ch = Channel.make_process_local()
        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
        self._to_worker_ch = Channel.make_process_local()

        self._start = None
        self._interm = None
        self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
        self._timing_on = timing_on

    def _add_label_to_timings(self, label: str):
        if label not in self._timings:
            self._timings[label] = []

    @staticmethod
    def _format_number(number: numbers.Number):
        return f"{number:0.4e}"

    def start_timings(self, batch_size: int):
        if self._timing_on:
            self._add_label_to_timings("batch_size")
            self._timings["batch_size"].append(batch_size)
            self._start = time.perf_counter()
            self._interm = time.perf_counter()

    def end_timings(self):
        if self._timing_on:
            self._add_label_to_timings("total_time")
            self._timings["total_time"].append(
                self._format_number(time.perf_counter() - self._start)
            )

    def measure_time(self, label: str):
        if self._timing_on:
            self._add_label_to_timings(label)
            self._timings[label].append(
                self._format_number(time.perf_counter() - self._interm)
            )
            self._interm = time.perf_counter()

    def print_timings(self, to_file: bool = False):
        print(" ".join(self._timings.keys()))
        value_array = numpy.array(list(self._timings.values()), dtype=float)
        value_array = numpy.transpose(value_array)
        for i in range(value_array.shape[0]):
            print(" ".join(self._format_number(value) for value in value_array[i]))
        if to_file:
            numpy.save("timings.npy", value_array)
            numpy.savetxt("timings.txt", value_array)

    def run_model(self, model: bytes | str, batch: torch.Tensor):
        self.start_timings(batch.shape[0])
        built_tensor = MessageHandler.build_tensor(
            batch.numpy(), "c", "float32", list(batch.shape)
        )
        self.measure_time("build_tensor")
        # A string is treated as a key into the DDict; raw bytes as the model itself
        if isinstance(model, str):
            model_arg = MessageHandler.build_model_key(model)
        else:
            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
        request = MessageHandler.build_request(
            reply_channel=self._from_worker_ch_serialized,
            model=model_arg,
            inputs=[built_tensor],
            outputs=[],
            output_descriptors=[],
            custom_attributes=None,
        )
        self.measure_time("build_request")
        request_bytes = MessageHandler.serialize_request(request)
        self.measure_time("serialize_request")
        with self._to_worker_fli.sendh(
            timeout=None, stream_channel=self._to_worker_ch
        ) as to_sendh:
            to_sendh.send_bytes(request_bytes)
            logger.info(f"Message size: {len(request_bytes)} bytes")

        self.measure_time("send")
        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
            resp = from_recvh.recv_bytes(timeout=None)
            self.measure_time("receive")
            response = MessageHandler.deserialize_response(resp)
            self.measure_time("deserialize_response")
            result = torch.from_numpy(
                numpy.frombuffer(
                    response.result.data[0].blob,
                    dtype=str(response.result.data[0].tensorDescriptor.dataType),
                )
            )
            self.measure_time("deserialize_tensor")

        self.end_timings()
        return result

    def set_model(self, key: str, model: bytes):
        self._ddict[key] = model


class ResNetWrapper:
    def __init__(self, name: str, model: str):
        self._model = torch.jit.load(model)
        self._name = name
        buffer = io.BytesIO()
        scripted = torch.jit.trace(self._model, self.get_batch())
        torch.jit.save(scripted, buffer)
        self._serialized_model = buffer.getvalue()

    def get_batch(self, batch_size: int = 32):
        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)

    @property
    def model(self):
        return self._serialized_model

    @property
    def name(self):
        return self._name


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Mock application")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")

    client = ProtoClient(timing_on=True)
    client.set_model(resnet.name, resnet.model)

    total_iterations = 100

    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
        logger.info(f"Batch size: {batch_size}")
        # One extra iteration for the first batch size
        for iteration_number in range(total_iterations + int(batch_size == 1)):
            logger.info(f"Iteration: {iteration_number}")
            client.run_model(resnet.name, resnet.get_batch(batch_size))

    client.print_timings(to_file=True)
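
`print_timings(to_file=True)` writes one row per iteration to `timings.npy`; columns follow the label order printed on the first line (`batch_size` first, then `build_tensor` through `deserialize_tensor`, with `total_time` last). A short snippet to load and summarize the file:

import numpy

timings = numpy.load("timings.npy")  # shape: (iterations, labels)
print(f"{timings.shape[0]} iterations x {timings.shape[1]} timing columns")
print("column means:", timings.mean(axis=0))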
New file (88 additions & 0 deletions)
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import io
import time

import numpy
import torch

from smartredis import Client
from smartsim.log import get_logger

logger = get_logger("App")


class ResNetWrapper:
    def __init__(self, name: str, model: str):
        self._model = torch.jit.load(model)
        self._name = name
        buffer = io.BytesIO()
        scripted = torch.jit.trace(self._model, self.get_batch())
        torch.jit.save(scripted, buffer)
        self._serialized_model = buffer.getvalue()

    def get_batch(self, batch_size: int = 32):
        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)

    @property
    def model(self):
        return self._serialized_model

    @property
    def name(self):
        return self._name


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Mock application")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")

    client = Client(cluster=False, address=None)
    client.set_model(resnet.name, resnet.model, backend="TORCH", device=args.device.upper())

    total_iterations = 100
    timings = []
    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
        logger.info(f"Batch size: {batch_size}")
        # One extra iteration for the first batch size
        for iteration_number in range(total_iterations + int(batch_size == 1)):
            timing = [batch_size]
            logger.info(f"Iteration: {iteration_number}")
            start = time.perf_counter()
            client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy())
            client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"])
            result = client.get_tensor(name="result")
            end = time.perf_counter()
            timing.append(end - start)
            timings.append(timing)

    timings_np = numpy.asarray(timings)
    numpy.save("timings.npy", timings_np)
    for timing in timings:
        print(" ".join(str(t) for t in timing))
