22 modeled after the textual_inversion.py / train_dreambooth.py and the work
33 of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
44"""
5+ import inspect
56import warnings
67from typing import List , Optional , Union
7- import inspect
88
99import numpy as np
1010import torch
1818from diffusers .pipelines .stable_diffusion import StableDiffusionPipelineOutput
1919from diffusers .pipelines .stable_diffusion .safety_checker import StableDiffusionSafetyChecker
2020from diffusers .schedulers import DDIMScheduler , LMSDiscreteScheduler , PNDMScheduler
21+ from diffusers .utils import logging
2122from tqdm .auto import tqdm
2223from transformers import CLIPFeatureExtractor , CLIPTextModel , CLIPTokenizer
23- from diffusers . utils import logging
24+
2425
2526logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
2627
@@ -119,7 +120,6 @@ def disable_attention_slicing(self):
119120 # set slice_size = `None` to disable `attention slicing`
120121 self .enable_attention_slicing (None )
121122
122-
123123 def train (
124124 self ,
125125 prompt : Union [str , List [str ]],
@@ -379,14 +379,13 @@ def __call__(
379379 """
380380 if height % 8 != 0 or width % 8 != 0 :
381381 raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
382- if self .text_embeddings is None :
382+ if self .text_embeddings is None :
383383 raise ValueError ("Please run the pipe.train() before trying to generate an image." )
384384 if self .text_embeddings_orig is None :
385385 raise ValueError ("Please run the pipe.train() before trying to generate an image." )
386386
387387 text_embeddings = alpha * self .text_embeddings_orig + (1 - alpha ) * self .text_embeddings
388388
389-
390389 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
391390 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
392391 # corresponds to doing no classifier free guidance.
@@ -487,4 +486,4 @@ def __call__(
487486 if not return_dict :
488487 return (image , has_nsfw_concept )
489488
490- return StableDiffusionPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
489+ return StableDiffusionPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
0 commit comments