1+ # BSD 2-Clause License
2+ #
3+ # Copyright (c) 2021-2024, Hewlett Packard Enterprise
4+ # All rights reserved.
5+ #
6+ # Redistribution and use in source and binary forms, with or without
7+ # modification, are permitted provided that the following conditions are met:
8+ #
9+ # 1. Redistributions of source code must retain the above copyright notice, this
10+ # list of conditions and the following disclaimer.
11+ #
12+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13+ # this list of conditions and the following disclaimer in the documentation
14+ # and/or other materials provided with the distribution.
15+ #
16+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+ # isort: off
28+ import dragon
29+ from dragon import fli
30+ from dragon .channels import Channel
31+ import dragon .channels
32+ from dragon .data .ddict .ddict import DDict
33+ from dragon .globalservices .api_setup import connect_to_infrastructure
34+ from dragon .utils import b64decode , b64encode
35+
36+ # isort: on
37+
38+ import argparse
39+ import io
40+ import numpy
41+ import os
42+ import time
43+ import torch
44+ import numbers
45+
46+ from collections import OrderedDict
47+ from smartsim ._core .mli .message_handler import MessageHandler
48+ from smartsim .log import get_logger
49+
50+ logger = get_logger ("App" )
51+
52+ class ProtoClient :
53+ def __init__ (self , timing_on : bool ):
54+ connect_to_infrastructure ()
55+ ddict_str = os .environ ["SS_DRG_DDICT" ]
56+ self ._ddict = DDict .attach (ddict_str )
57+ to_worker_fli_str = None
58+ while to_worker_fli_str is None :
59+ try :
60+ to_worker_fli_str = self ._ddict ["to_worker_fli" ]
61+ self ._to_worker_fli = fli .FLInterface .attach (to_worker_fli_str )
62+ except KeyError :
63+ time .sleep (1 )
64+ self ._from_worker_ch = Channel .make_process_local ()
65+ self ._from_worker_ch_serialized = self ._from_worker_ch .serialize ()
66+ self ._to_worker_ch = Channel .make_process_local ()
67+
68+ self ._start = None
69+ self ._interm = None
70+ self ._timings : OrderedDict [str , list [numbers .Number ]] = OrderedDict ()
71+ self ._timing_on = timing_on
72+
73+ def _add_label_to_timings (self , label : str ):
74+ if label not in self ._timings :
75+ self ._timings [label ] = []
76+
77+ @staticmethod
78+ def _format_number (number : numbers .Number ):
79+ return f"{ number :0.4e} "
80+
81+ def start_timings (self , batch_size : int ):
82+ if self ._timing_on :
83+ self ._add_label_to_timings ("batch_size" )
84+ self ._timings ["batch_size" ].append (batch_size )
85+ self ._start = time .perf_counter ()
86+ self ._interm = time .perf_counter ()
87+
88+ def end_timings (self ):
89+ if self ._timing_on :
90+ self ._add_label_to_timings ("total_time" )
91+ self ._timings ["total_time" ].append (self ._format_number (time .perf_counter ()- self ._start ))
92+
93+ def measure_time (self , label : str ):
94+ if self ._timing_on :
95+ self ._add_label_to_timings (label )
96+ self ._timings [label ].append (self ._format_number (time .perf_counter ()- self ._interm ))
97+ self ._interm = time .perf_counter ()
98+
99+ def print_timings (self , to_file : bool = False ):
100+ print (" " .join (self ._timings .keys ()))
101+ value_array = numpy .array ([value for value in self ._timings .values ()], dtype = float )
102+ value_array = numpy .transpose (value_array )
103+ for i in range (value_array .shape [0 ]):
104+ print (" " .join (self ._format_number (value ) for value in value_array [i ]))
105+ if to_file :
106+ numpy .save ("timings.npy" , value_array )
107+ numpy .savetxt ("timings.txt" , value_array )
108+
109+
110+ def run_model (self , model : bytes | str , batch : torch .Tensor ):
111+ self .start_timings (batch .shape [0 ])
112+ built_tensor = MessageHandler .build_tensor (
113+ batch .numpy (), "c" , "float32" , list (batch .shape ))
114+ self .measure_time ("build_tensor" )
115+ built_model = None
116+ if isinstance (model , str ):
117+ model_arg = MessageHandler .build_model_key (model )
118+ else :
119+ model_arg = MessageHandler .build_model (model , "resnet-50" , "1.0" )
120+ request = MessageHandler .build_request (
121+ reply_channel = self ._from_worker_ch_serialized ,
122+ model = model_arg ,
123+ inputs = [built_tensor ],
124+ outputs = [],
125+ output_descriptors = [],
126+ custom_attributes = None ,
127+ )
128+ self .measure_time ("build_request" )
129+ request_bytes = MessageHandler .serialize_request (request )
130+ self .measure_time ("serialize_request" )
131+ with self ._to_worker_fli .sendh (timeout = None , stream_channel = self ._to_worker_ch ) as to_sendh :
132+ to_sendh .send_bytes (request_bytes )
133+ logger .info (f"Message size: { len (request_bytes )} bytes" )
134+
135+ self .measure_time ("send" )
136+ with self ._from_worker_ch .recvh (timeout = None ) as from_recvh :
137+ resp = from_recvh .recv_bytes (timeout = None )
138+ self .measure_time ("receive" )
139+ response = MessageHandler .deserialize_response (resp )
140+ self .measure_time ("deserialize_response" )
141+ result = torch .from_numpy (
142+ numpy .frombuffer (
143+ response .result .data [0 ].blob ,
144+ dtype = str (response .result .data [0 ].tensorDescriptor .dataType ),
145+ )
146+ )
147+ self .measure_time ("deserialize_tensor" )
148+
149+ self .end_timings ()
150+ return result
151+
152+ def set_model (self , key : str , model : bytes ):
153+ self ._ddict [key ] = model
154+
155+
156+ class ResNetWrapper ():
157+ def __init__ (self , name : str , model : str ):
158+ self ._model = torch .jit .load (model )
159+ self ._name = name
160+ buffer = io .BytesIO ()
161+ scripted = torch .jit .trace (self ._model , self .get_batch ())
162+ torch .jit .save (scripted , buffer )
163+ self ._serialized_model = buffer .getvalue ()
164+
165+ def get_batch (self , batch_size : int = 32 ):
166+ return torch .randn ((batch_size , 3 , 224 , 224 ), dtype = torch .float32 )
167+
168+ @property
169+ def model (self ):
170+ return self ._serialized_model
171+
172+ @property
173+ def name (self ):
174+ return self ._name
175+
176+ if __name__ == "__main__" :
177+
178+ parser = argparse .ArgumentParser ("Mock application" )
179+ parser .add_argument ("--device" , default = "cpu" )
180+ args = parser .parse_args ()
181+
182+ resnet = ResNetWrapper ("resnet50" , f"resnet50.{ args .device .upper ()} .pt" )
183+
184+ client = ProtoClient (timing_on = True )
185+ client .set_model (resnet .name , resnet .model )
186+
187+ total_iterations = 100
188+
189+ for batch_size in [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 ]:
190+ logger .info (f"Batch size: { batch_size } " )
191+ for iteration_number in range (total_iterations + int (batch_size == 1 )):
192+ logger .info (f"Iteration: { iteration_number } " )
193+ client .run_model (resnet .name , resnet .get_batch (batch_size ))
194+
195+ client .print_timings (to_file = True )
0 commit comments