Skip to content

Commit b50e1d9

Browse files
apply make style
1 parent 7c6e7f7 commit b50e1d9

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

examples/community/stable_diffusion_tensorrt_txt2img.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,42 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
import os
1918
import gc
20-
import numpy as np
21-
from copy import copy
19+
import os
2220
from collections import OrderedDict
21+
from copy import copy
2322
from typing import List, Optional, Union
2423

24+
import numpy as np
2525
import onnx
26-
from onnx import shape_inference
27-
import torch
28-
import tensorrt as trt
2926
import onnx_graphsurgeon as gs
27+
import tensorrt as trt
28+
import torch
29+
from huggingface_hub import snapshot_download
30+
from onnx import shape_inference
3031
from polygraphy import cuda
31-
from polygraphy.backend.onnx.loader import fold_constants
3232
from polygraphy.backend.common import bytes_from_path
33-
from polygraphy.backend.trt import CreateConfig, Profile
34-
from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
33+
from polygraphy.backend.onnx.loader import fold_constants
34+
from polygraphy.backend.trt import (
35+
CreateConfig,
36+
Profile,
37+
engine_from_bytes,
38+
engine_from_network,
39+
network_from_onnx_path,
40+
save_engine,
41+
)
3542
from polygraphy.backend.trt import util as trt_util
3643
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
37-
from huggingface_hub import snapshot_download
3844

39-
from diffusers.utils import DIFFUSERS_CACHE, logging
40-
from diffusers.schedulers import DDIMScheduler
4145
from diffusers.models import AutoencoderKL, UNet2DConditionModel
4246
from diffusers.pipelines.stable_diffusion import (
4347
StableDiffusionPipeline,
4448
StableDiffusionPipelineOutput,
4549
StableDiffusionSafetyChecker,
4650
)
51+
from diffusers.schedulers import DDIMScheduler
52+
from diffusers.utils import DIFFUSERS_CACHE, logging
53+
4754

4855
"""
4956
Installation instructions
@@ -162,7 +169,7 @@ def infer(self, feed_dict, stream):
162169
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
163170
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
164171
if not noerror:
165-
raise ValueError(f"ERROR: inference failed.")
172+
raise ValueError("ERROR: inference failed.")
166173

167174
return self.tensors
168175

@@ -469,9 +476,18 @@ def get_dynamic_axes(self):
469476

470477
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
471478
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
472-
min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims(
473-
batch_size, image_height, image_width, static_batch, static_shape
474-
)
479+
(
480+
min_batch,
481+
max_batch,
482+
_,
483+
_,
484+
_,
485+
_,
486+
min_latent_height,
487+
max_latent_height,
488+
min_latent_width,
489+
max_latent_width,
490+
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
475491
return {
476492
"sample": [
477493
(2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
@@ -534,9 +550,18 @@ def get_dynamic_axes(self):
534550

535551
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
536552
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
537-
min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims(
538-
batch_size, image_height, image_width, static_batch, static_shape
539-
)
553+
(
554+
min_batch,
555+
max_batch,
556+
_,
557+
_,
558+
_,
559+
_,
560+
min_latent_height,
561+
max_latent_height,
562+
min_latent_width,
563+
max_latent_width,
564+
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
540565
return {
541566
"latent": [
542567
(min_batch, 4, min_latent_height, min_latent_width),
@@ -774,7 +799,6 @@ def __denoise_latent(
774799
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
775800

776801
# Predict the noise residual
777-
embeddings_dtype = np.float16
778802
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
779803

780804
sample_inp = device_view(latent_model_input)

0 commit comments

Comments
 (0)