Skip to content

Commit 4ecccdd

Browse files
committed
Add Python backend based PyTorch runtime
1 parent 304c2e8 commit 4ecccdd

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ install(
504504
${INSTALL_CONFIGDIR}
505505
)
506506

507+
install(
508+
FILES
509+
src/model.py
510+
DESTINATION
511+
${CMAKE_INSTALL_PREFIX}/backends/pytorch
512+
)
513+
507514
include(CMakePackageConfigHelpers)
508515
configure_package_config_file(
509516
${CMAKE_CURRENT_LIST_DIR}/cmake/TritonPyTorchBackendConfig.cmake.in

src/model.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions
7+
# are met:
8+
# * Redistributions of source code must retain the above copyright
9+
# notice, this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright
11+
# notice, this list of conditions and the following disclaimer in the
12+
# documentation and/or other materials provided with the distribution.
13+
# * Neither the name of NVIDIA CORPORATION nor the names of its
14+
# contributors may be used to endorse or promote products derived
15+
# from this software without specific prior written permission.
16+
#
17+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import importlib
30+
import json
31+
import os
32+
33+
try:
34+
import torch
35+
except ModuleNotFoundError as error:
36+
raise RuntimeError("Missing/Incomplete PyTorch package installation") from error
37+
38+
# triton_python_backend_utils is available in every Triton Python model. You
39+
# need to use this module to create inference requests and responses. It also
40+
# contains some utility functions for extracting information from model_config
41+
# and converting Triton input/output types to numpy types.
42+
import triton_python_backend_utils as pb_utils
43+
44+
45+
def _get_model_path(config):
46+
filenames = ["model.py", "model.pt"]
47+
if config["default_model_filename"]:
48+
filenames.insert(0, config["default_model_filename"])
49+
for filename in filenames:
50+
model_path = os.path.join(pb_utils.get_model_dir(), filename)
51+
if os.path.exists(model_path):
52+
return model_path
53+
raise pb_utils.TritonModelException(
54+
"No model found in " + pb_utils.get_model_dir() + "/" + str(filenames)
55+
)
56+
57+
58+
def _get_model_data_path(model_path):
59+
data_path_extensions = [".pt"]
60+
model_path_no_extension = model_path[: -(len(model_path.split(".")[-1]) + 1)]
61+
for extension in data_path_extensions:
62+
data_path = model_path_no_extension + extension
63+
if os.path.exists(data_path):
64+
return data_path
65+
# data file not provided
66+
return ""
67+
68+
69+
def _is_py_class_model(model_path):
70+
return model_path[-3:] == ".py"
71+
72+
73+
def _import_module_from_path(module_name, file_path):
74+
spec = importlib.util.spec_from_file_location(module_name, file_path)
75+
module = importlib.util.module_from_spec(spec)
76+
spec.loader.exec_module(module)
77+
return module
78+
79+
80+
def _get_model_class_from_module(module):
81+
names = dir(module)
82+
for name in names:
83+
attr = getattr(module, name)
84+
try:
85+
if issubclass(attr, torch.nn.Module):
86+
return attr
87+
except TypeError:
88+
# attr may not be a class
89+
pass
90+
raise pb_utils.TritonModelException("Cannot find a subclass of torch.nn.Module")
91+
92+
93+
def _parse_io_config(io_config):
94+
io = []
95+
for conf in io_config:
96+
io.append({"name": conf["name"]})
97+
return io
98+
99+
100+
def _get_device_name(kind, device_id):
101+
if kind == "GPU":
102+
return "cuda:" + device_id
103+
if kind == "CPU":
104+
return "cpu"
105+
# unspecified device
106+
return ""
107+
108+
109+
def _get_device(kind, device_id, model):
110+
device_name = _get_device_name(kind, device_id)
111+
if device_name == "":
112+
for param in model.parameters():
113+
return param.device
114+
raise pb_utils.TritonModelException("Cannot determine model device")
115+
return torch.device(device_name)
116+
117+
118+
def _set_torch_parallelism(config):
119+
log_msg = ""
120+
parallelism_settings = ["NUM_THREADS", "NUM_INTEROP_THREADS"]
121+
for setting in parallelism_settings:
122+
val = "1"
123+
if setting in config["parameters"]:
124+
val = config["parameters"][setting]["string_value"]
125+
getattr(torch, "set_" + setting.lower())(int(val))
126+
log_msg += setting + " = " + val + "; "
127+
return log_msg
128+
129+
130+
def _get_torch_compile_params(config):
131+
params = {}
132+
if "TORCH_COMPILE_OPTIONAL_PARAMETERS" in config["parameters"]:
133+
val = config["parameters"]["TORCH_COMPILE_OPTIONAL_PARAMETERS"]["string_value"]
134+
params = json.loads(val)
135+
if "model" in params:
136+
raise pb_utils.TritonModelException(
137+
"'model' is not an optional parameter for 'torch.compile'"
138+
)
139+
return params
140+
141+
142+
def _gather_torch_tensors(scatter_tensors):
143+
gather_tensors = []
144+
sections = []
145+
for i in range(len(scatter_tensors)):
146+
tensors = scatter_tensors[i]
147+
for j in range(len(tensors)):
148+
tensor = tensors[j]
149+
if j < len(gather_tensors):
150+
# add to existing tensor
151+
gather_tensors[j] = torch.cat((gather_tensors[j], tensor), 0)
152+
else:
153+
# start a new tensor
154+
gather_tensors.append(tensor)
155+
# record section
156+
section_length = tensors[0].size()[0]
157+
sections.append(section_length)
158+
return gather_tensors, sections
159+
160+
161+
def _scatter_torch_tensors(gather_tensors, sections):
162+
scatter_tensors = []
163+
for j in range(len(gather_tensors)):
164+
scatter_tensor = torch.split(gather_tensors[j], sections)
165+
for i in range(len(scatter_tensor)):
166+
tensor = scatter_tensor[i]
167+
if i < len(scatter_tensors):
168+
# add to existing response
169+
scatter_tensors[i].append(tensor)
170+
else:
171+
# start a new response
172+
scatter_tensors.append([tensor])
173+
return scatter_tensors
174+
175+
176+
class TritonPythonModel:
177+
"""Your Python model must use the same class name. Every Python model
178+
that is created must have "TritonPythonModel" as the class name.
179+
"""
180+
181+
def initialize(self, args):
182+
"""`initialize` is called only once when the model is being loaded.
183+
Implementing `initialize` function is optional. This function allows
184+
the model to initialize any state associated with this model.
185+
Parameters
186+
----------
187+
args : dict
188+
Both keys and values are strings. The dictionary keys and values are:
189+
* model_config: A JSON string containing the model configuration
190+
* model_instance_kind: A string containing model instance kind
191+
* model_instance_device_id: A string containing model instance device ID
192+
* model_repository: Model repository path
193+
* model_version: Model version
194+
* model_name: Model name
195+
"""
196+
self._model_name = args["model_name"]
197+
for_model = "for '" + self._model_name + "'"
198+
self._logger = pb_utils.Logger
199+
self._logger.log_info("Initializing model instance " + for_model)
200+
201+
self._model_config = json.loads(args["model_config"])
202+
self._kind = args["model_instance_kind"]
203+
self._device_id = args["model_instance_device_id"]
204+
self._support_batching = self._model_config["max_batch_size"] > 0
205+
self._inputs = _parse_io_config(self._model_config["input"])
206+
self._outputs = _parse_io_config(self._model_config["output"])
207+
208+
setting_msg = _set_torch_parallelism(self._model_config)
209+
self._logger.log_verbose(
210+
"Torch parallelism settings " + for_model + ": " + setting_msg
211+
)
212+
213+
self._infer_mode = torch.inference_mode(mode=True)
214+
self._infer_mode.__enter__()
215+
216+
params = _get_torch_compile_params(self._model_config)
217+
self._logger.log_verbose(
218+
"'torch.compile' optional parameter(s) " + for_model + ": " + str(params)
219+
)
220+
if self._support_batching:
221+
self._gather = torch.compile(_gather_torch_tensors, **params)
222+
self._scatter = torch.compile(_scatter_torch_tensors, **params)
223+
224+
model_path = _get_model_path(self._model_config)
225+
if not _is_py_class_model(model_path):
226+
self._logger.log_info("Loading '" + self._model_name + "' as TorchScript")
227+
self._model = torch.jit.load(model_path)
228+
self._device = _get_device(self._kind, self._device_id, self._model)
229+
self._model.to(self._device)
230+
self._model.eval()
231+
return
232+
233+
self._model_module = _import_module_from_path(self._model_name, model_path)
234+
self._model_class = _get_model_class_from_module(self._model_module)
235+
self._raw_model = self._model_class()
236+
self._device = _get_device(self._kind, self._device_id, self._raw_model)
237+
data_path = _get_model_data_path(model_path)
238+
if data_path != "":
239+
self._raw_model.load_state_dict(
240+
torch.load(data_path, map_location=self._device)
241+
)
242+
else:
243+
self._logger.log_info("Model parameter file not found " + for_model)
244+
self._raw_model.to(self._device)
245+
self._raw_model.eval()
246+
self._model = torch.compile(self._raw_model, **params)
247+
248+
def execute(self, requests):
249+
"""`execute` MUST be implemented in every Python model. `execute`
250+
function receives a list of pb_utils.InferenceRequest as the only
251+
argument. This function is called when an inference request is made
252+
for this model. Depending on the batching configuration (e.g. Dynamic
253+
Batching) used, `requests` may contain multiple requests. Every
254+
Python model, must create one pb_utils.InferenceResponse for every
255+
pb_utils.InferenceRequest in `requests`. If there is an error, you can
256+
set the error argument when creating a pb_utils.InferenceResponse
257+
Parameters
258+
----------
259+
requests : list
260+
A list of pb_utils.InferenceRequest
261+
Returns
262+
-------
263+
list
264+
A list of pb_utils.InferenceResponse. The length of this list must
265+
be the same as `requests`
266+
"""
267+
268+
responses = []
269+
270+
requests_tensors = []
271+
for request in requests:
272+
tensors = []
273+
for io in self._inputs:
274+
tensor = pb_utils.get_input_tensor_by_name(
275+
request, io["name"]
276+
).to_dlpack()
277+
tensor = torch.from_dlpack(tensor).to(self._device)
278+
tensors.append(tensor)
279+
requests_tensors.append(tensors)
280+
281+
sections = None
282+
if self._support_batching:
283+
requests_tensors, sections = self._gather(requests_tensors)
284+
requests_tensors = [requests_tensors]
285+
286+
responses_tensors = []
287+
for input_tensors in requests_tensors:
288+
output_tensors = self._model(*input_tensors)
289+
if not isinstance(output_tensors, tuple) and not isinstance(
290+
output_tensors, list
291+
):
292+
output_tensors = [output_tensors]
293+
responses_tensors.append(output_tensors)
294+
295+
if self._support_batching:
296+
responses_tensors = self._scatter(responses_tensors[0], sections)
297+
298+
for response_tensors in responses_tensors:
299+
output_tensors = []
300+
for i in range(len(self._outputs)):
301+
io = self._outputs[i]
302+
tensor = response_tensors[i].detach()
303+
tensor = pb_utils.Tensor.from_dlpack(io["name"], tensor)
304+
output_tensors.append(tensor)
305+
inference_response = pb_utils.InferenceResponse(
306+
output_tensors=output_tensors
307+
)
308+
responses.append(inference_response)
309+
310+
return responses
311+
312+
def finalize(self):
313+
"""`finalize` is called only once when the model is being unloaded.
314+
Implementing `finalize` function is OPTIONAL. This function allows
315+
the model to perform any necessary clean ups before exit.
316+
"""
317+
self._logger.log_info("Removing model instance for '" + self._model_name + "'")
318+
self._infer_mode.__exit__(exc_type=None, exc_value=None, traceback=None)

0 commit comments

Comments
 (0)