@@ -15,35 +15,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 import gc
-import numpy as np
-from copy import copy
+import os
 from collections import OrderedDict
+from copy import copy
 from typing import List, Optional, Union

+import numpy as np
 import onnx
-from onnx import shape_inference
-import torch
-import tensorrt as trt
 import onnx_graphsurgeon as gs
+import tensorrt as trt
+import torch
+from huggingface_hub import snapshot_download
+from onnx import shape_inference
 from polygraphy import cuda
-from polygraphy.backend.onnx.loader import fold_constants
 from polygraphy.backend.common import bytes_from_path
-from polygraphy.backend.trt import CreateConfig, Profile
-from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+    CreateConfig,
+    Profile,
+    engine_from_bytes,
+    engine_from_network,
+    network_from_onnx_path,
+    save_engine,
+)
 from polygraphy.backend.trt import util as trt_util
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from huggingface_hub import snapshot_download

-from diffusers.utils import DIFFUSERS_CACHE, logging
-from diffusers.schedulers import DDIMScheduler
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
     StableDiffusionPipeline,
     StableDiffusionPipelineOutput,
     StableDiffusionSafetyChecker,
 )
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import DIFFUSERS_CACHE, logging
+

 """
 Installation instructions
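For reference, the polygraphy helpers imported above typically compose like this when building and serializing a TensorRT engine. This is a minimal sketch, not this pipeline's actual build code: the ONNX path, plan path, fp16 flag, and shape values are illustrative assumptions ("sample" matches the input name used in the profiles further down).

# Hedged sketch of how the polygraphy imports fit together; the paths, shapes,
# and fp16=True are assumptions for illustration, not values from this commit.
from polygraphy.backend.trt import CreateConfig, Profile, engine_from_network, network_from_onnx_path, save_engine

profile = Profile()
profile.add("sample", min=(2, 4, 32, 32), opt=(2, 4, 64, 64), max=(2, 4, 96, 96))
engine = engine_from_network(
    network_from_onnx_path("model.onnx"),  # assumed ONNX export of the model
    config=CreateConfig(fp16=True, profiles=[profile]),
)
save_engine(engine, path="model.plan")  # serialized engine, reloadable via engine_from_bytes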
@@ -162,7 +169,7 @@ def infer(self, feed_dict, stream):
         bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
         noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
         if not noerror:
-            raise ValueError(f"ERROR: inference failed.")
+            raise ValueError("ERROR: inference failed.")

         return self.tensors

@@ -469,9 +476,18 @@ def get_dynamic_axes(self):

     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
-        min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims(
-            batch_size, image_height, image_width, static_batch, static_shape
-        )
+        (
+            min_batch,
+            max_batch,
+            _,
+            _,
+            _,
+            _,
+            min_latent_height,
+            max_latent_height,
+            min_latent_width,
+            max_latent_width,
+        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
         return {
             "sample": [
                 (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
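A dict of this shape maps each input name to [min, opt, max] shape tuples; downstream, such a dict is typically unrolled into a polygraphy Profile before engine build. A minimal sketch under that assumption, where input_profile stands in for the return value of get_input_profile:

# Hedged sketch: feeding a {name: [min, opt, max]} dict into a polygraphy Profile.
# input_profile is assumed to be the dict returned by get_input_profile above.
from polygraphy.backend.trt import Profile

profile = Profile()
for name, dims in input_profile.items():
    assert len(dims) == 3  # expect exactly (min, opt, max) shapes per input
    profile.add(name, min=dims[0], opt=dims[1], max=dims[2])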
@@ -534,9 +550,18 @@ def get_dynamic_axes(self):

     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
-        min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims(
-            batch_size, image_height, image_width, static_batch, static_shape
-        )
+        (
+            min_batch,
+            max_batch,
+            _,
+            _,
+            _,
+            _,
+            min_latent_height,
+            max_latent_height,
+            min_latent_width,
+            max_latent_width,
+        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
         return {
             "latent": [
                 (min_batch, 4, min_latent_height, min_latent_width),
@@ -774,7 +799,6 @@ def __denoise_latent(
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

             # Predict the noise residual
-            embeddings_dtype = np.float16
             timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

             sample_inp = device_view(latent_model_input)