#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import importlib.util
import json
import os

try:
    import torch
except ModuleNotFoundError as error:
    raise RuntimeError("Missing/Incomplete PyTorch package installation") from error

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


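# Locate the model file inside the model directory. A filename supplied via
# "default_model_filename" in the model config takes precedence over the
# conventional "model.py" / "model.pt" names.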
def _get_model_path(config):
    filenames = ["model.py", "model.pt"]
    if config["default_model_filename"]:
        filenames.insert(0, config["default_model_filename"])
    for filename in filenames:
        model_path = os.path.join(pb_utils.get_model_dir(), filename)
        if os.path.exists(model_path):
            return model_path
    raise pb_utils.TritonModelException(
        "No model found in " + pb_utils.get_model_dir() + "/" + str(filenames)
    )


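# Look for an optional parameter file (e.g. a ".pt" state dict) stored next to
# a Python model file, by swapping the file extension.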
def _get_model_data_path(model_path):
    data_path_extensions = [".pt"]
    model_path_no_extension = model_path[: -(len(model_path.split(".")[-1]) + 1)]
    for extension in data_path_extensions:
        data_path = model_path_no_extension + extension
        if os.path.exists(data_path):
            return data_path
    # data file not provided
    return ""


def _is_py_class_model(model_path):
    return model_path.endswith(".py")


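# Import a Python source file as a module without requiring it to be on
# sys.path, using the standard importlib machinery.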
def _import_module_from_path(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


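# Scan the imported module for the first attribute that is a subclass of
# torch.nn.Module and treat it as the model class.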
def _get_model_class_from_module(module):
    names = dir(module)
    for name in names:
        attr = getattr(module, name)
        try:
            if issubclass(attr, torch.nn.Module):
                return attr
        except TypeError:
            # attr may not be a class
            pass
    raise pb_utils.TritonModelException("Cannot find a subclass of torch.nn.Module")


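# Only the tensor names are needed from the config "input"/"output" sections;
# data types and shapes come from the runtime tensors themselves.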
def _parse_io_config(io_config):
    io = []
    for conf in io_config:
        io.append({"name": conf["name"]})
    return io


def _get_device_name(kind, device_id):
    if kind == "GPU":
        return "cuda:" + device_id
    if kind == "CPU":
        return "cpu"
    # unspecified device
    return ""


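# Resolve the torch device for this instance. If the instance kind does not map
# to an explicit device, fall back to the device of the model's first parameter.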
def _get_device(kind, device_id, model):
    device_name = _get_device_name(kind, device_id)
    if device_name == "":
        for param in model.parameters():
            return param.device
        raise pb_utils.TritonModelException("Cannot determine model device")
    return torch.device(device_name)


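# Configure the torch intra-op and inter-op thread pools. Both default to 1 and
# can be overridden through the NUM_THREADS / NUM_INTEROP_THREADS parameters in
# the model config.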
def _set_torch_parallelism(config):
    log_msg = ""
    parallelism_settings = ["NUM_THREADS", "NUM_INTEROP_THREADS"]
    for setting in parallelism_settings:
        val = "1"
        if setting in config["parameters"]:
            val = config["parameters"][setting]["string_value"]
        getattr(torch, "set_" + setting.lower())(int(val))
        log_msg += setting + " = " + val + "; "
    return log_msg


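# Read optional keyword arguments for torch.compile from the
# TORCH_COMPILE_OPTIONAL_PARAMETERS model config parameter (a JSON string).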
def _get_torch_compile_params(config):
    params = {}
    if "TORCH_COMPILE_OPTIONAL_PARAMETERS" in config["parameters"]:
        val = config["parameters"]["TORCH_COMPILE_OPTIONAL_PARAMETERS"]["string_value"]
        params = json.loads(val)
        if "model" in params:
            raise pb_utils.TritonModelException(
                "'model' is not an optional parameter for 'torch.compile'"
            )
    return params


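# Concatenate the per-request input tensors along the batch dimension so the
# model can run a single batched forward pass, and record each request's batch
# size so the outputs can be split back per request afterwards.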
def _gather_torch_tensors(scatter_tensors):
    gather_tensors = []
    sections = []
    for tensors in scatter_tensors:
        for j, tensor in enumerate(tensors):
            if j < len(gather_tensors):
                # add to existing tensor
                gather_tensors[j] = torch.cat((gather_tensors[j], tensor), 0)
            else:
                # start a new tensor
                gather_tensors.append(tensor)
        # record section
        section_length = tensors[0].size()[0]
        sections.append(section_length)
    return gather_tensors, sections


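# Split the batched output tensors back into per-request chunks, using the
# section sizes recorded by _gather_torch_tensors.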
def _scatter_torch_tensors(gather_tensors, sections):
    scatter_tensors = []
    for gather_tensor in gather_tensors:
        scatter_tensor = torch.split(gather_tensor, sections)
        for i, tensor in enumerate(scatter_tensor):
            if i < len(scatter_tensors):
                # add to existing response
                scatter_tensors[i].append(tensor)
            else:
                # start a new response
                scatter_tensors.append([tensor])
    return scatter_tensors


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing the `initialize` function is optional. This function
        allows the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self._model_name = args["model_name"]
        for_model = "for '" + self._model_name + "'"
        self._logger = pb_utils.Logger
        self._logger.log_info("Initializing model instance " + for_model)

        self._model_config = json.loads(args["model_config"])
        self._kind = args["model_instance_kind"]
        self._device_id = args["model_instance_device_id"]
        self._support_batching = self._model_config["max_batch_size"] > 0
        self._inputs = _parse_io_config(self._model_config["input"])
        self._outputs = _parse_io_config(self._model_config["output"])

        setting_msg = _set_torch_parallelism(self._model_config)
        self._logger.log_verbose(
            "Torch parallelism settings " + for_model + ": " + setting_msg
        )

        self._infer_mode = torch.inference_mode(mode=True)
        self._infer_mode.__enter__()

        params = _get_torch_compile_params(self._model_config)
        self._logger.log_verbose(
            "'torch.compile' optional parameter(s) " + for_model + ": " + str(params)
        )
        if self._support_batching:
            self._gather = torch.compile(_gather_torch_tensors, **params)
            self._scatter = torch.compile(_scatter_torch_tensors, **params)

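        # Anything other than a ".py" file is loaded as TorchScript via
        # torch.jit; a ".py" file is imported and its torch.nn.Module subclass
        # is instantiated below.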
        model_path = _get_model_path(self._model_config)
        if not _is_py_class_model(model_path):
            self._logger.log_info("Loading '" + self._model_name + "' as TorchScript")
            self._model = torch.jit.load(model_path)
            self._device = _get_device(self._kind, self._device_id, self._model)
            self._model.to(self._device)
            self._model.eval()
            return

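        # Python class model: import the module, instantiate the model class,
        # load its parameters if a data file is present, then wrap it with
        # torch.compile.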
        self._model_module = _import_module_from_path(self._model_name, model_path)
        self._model_class = _get_model_class_from_module(self._model_module)
        self._raw_model = self._model_class()
        self._device = _get_device(self._kind, self._device_id, self._raw_model)
        data_path = _get_model_data_path(model_path)
        if data_path != "":
            self._raw_model.load_state_dict(
                torch.load(data_path, map_location=self._device)
            )
        else:
            self._logger.log_info("Model parameter file not found " + for_model)
        self._raw_model.to(self._device)
        self._raw_model.eval()
        self._model = torch.compile(self._raw_model, **params)

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. The `execute`
        function receives a list of pb_utils.InferenceRequest as its only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []

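        # Convert every request's input tensors to torch tensors via DLPack and
        # move them onto the model's device.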
        requests_tensors = []
        for request in requests:
            tensors = []
            for io in self._inputs:
                tensor = pb_utils.get_input_tensor_by_name(
                    request, io["name"]
                ).to_dlpack()
                tensor = torch.from_dlpack(tensor).to(self._device)
                tensors.append(tensor)
            requests_tensors.append(tensors)

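        # When batching is enabled, merge all requests into a single batched
        # call and remember the per-request sections for the scatter step below.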
        sections = None
        if self._support_batching:
            requests_tensors, sections = self._gather(requests_tensors)
            requests_tensors = [requests_tensors]

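        # Run the model; normalize a single-tensor output to a list so the
        # response-building loop can treat all outputs uniformly.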
        responses_tensors = []
        for input_tensors in requests_tensors:
            output_tensors = self._model(*input_tensors)
            if not isinstance(output_tensors, (tuple, list)):
                output_tensors = [output_tensors]
            responses_tensors.append(output_tensors)

        if self._support_batching:
            responses_tensors = self._scatter(responses_tensors[0], sections)

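        # Wrap each request's output tensors into an InferenceResponse, again
        # going through DLPack.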
        for response_tensors in responses_tensors:
            output_tensors = []
            for i in range(len(self._outputs)):
                io = self._outputs[i]
                tensor = response_tensors[i].detach()
                tensor = pb_utils.Tensor.from_dlpack(io["name"], tensor)
                output_tensors.append(tensor)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=output_tensors
            )
            responses.append(inference_response)

        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing the `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary cleanup before exit.
        """
        self._logger.log_info("Removing model instance for '" + self._model_name + "'")
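        # Leave the inference_mode context entered in initialize().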
        self._infer_mode.__exit__(exc_type=None, exc_value=None, traceback=None)