diff --git a/examples/community/README.md b/examples/community/README.md
index aee6ffee09c7..f6044102641f 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -48,6 +48,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
 | Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
 | Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
 | Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
 | LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
 |
@@ -2524,6 +2525,181 @@ images[0].save("controlnet_and_adapter_inpaint.png")
 ```
+### Regional Prompting Pipeline
+This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
+It implements a Stable Diffusion pipeline that divides the canvas into multiple regions and applies a different prompt to each region. Regions can be specified in two ways: the `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode, where regions are calculated from the prompts themselves.
+
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)
+
+### Usage
+### Sample Code
+```
+import time
+
+from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
+
+pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
+
+rp_args = {
+    "mode": "rows",
+    "div": "1;1;1"
+}
+
+prompt = """
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+"""
+
+images = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    guidance_scale=7.5,
+    height=768,
+    width=512,
+    num_inference_steps=20,
+    num_images_per_prompt=1,
+    rp_args=rp_args,
+).images
+
+timestamp = time.strftime(r"%Y%m%d%H%M%S")
+for i, image in enumerate(images):
+    image.save(f"img-{timestamp}-{i}.png")
+```
+### Cols, Rows mode
+In the `Cols` and `Rows` modes, you can split the canvas vertically or horizontally and assign a prompt to each region. The split ratio is specified with `div`, for example `3;3;2` or `0.1;0.5`. As described later, the resulting columns or rows can be subdivided further to specify more complex regions.
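+
+To make the ratio syntax concrete, each value in `div` is normalized against the sum of the values, so `3;3;2` means 37.5% / 37.5% / 25% of the canvas. The small helper below is only an illustration of that arithmetic (it mirrors what the pipeline computes internally and is not part of its API):
+
+```
+def div_to_ranges(div: str):
+    # turn "3;3;2" into normalized (start, end) fractions of the canvas
+    ratios = [float(r) for r in div.split(";")]
+    total = sum(ratios)
+    ranges, start = [], 0.0
+    for r in ratios:
+        end = start + r / total
+        ranges.append((round(start, 3), round(end, 3)))
+        start = end
+    return ranges
+
+print(div_to_ranges("3;3;2"))    # [(0.0, 0.375), (0.375, 0.75), (0.75, 1.0)]
+print(div_to_ranges("0.1;0.5"))  # ratios are relative, so 0.1:0.5 behaves like 1:5
+```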
+
+In the image below, the canvas is divided into three parts and a separate prompt is applied to each. The prompts are separated by `BREAK`, and each one is applied to its respective region.
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)
+```
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+```
+
+### 2-Dimensional division
+The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is first split in the main splitting direction, in this case rows, because there is a single semicolon (`;`), dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. The rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is the first region, green hair the second, the bookshelf the third, and so on, in order of their position from the top left. The terrarium on the desk is the fourth region, and the orange dress and sofa are the fifth, conforming to their respective splits.
+```
+rp_args = {
+    "mode": "rows",
+    "div": "1,2,1,1;2,4,6"
+}
+
+prompt = """
+blue sky BREAK
+green hair BREAK
+book shelf BREAK
+terrarium on desk BREAK
+orange dress and sofa
+"""
+```
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)
+
+### Prompt Mode
+Specifying regions in advance has its limits: fixed regions get in the way when you want complex shapes or dynamic compositions. In Prompt mode, the regions are instead determined after image generation has begun, which makes it possible to handle such compositions and complex regions.
+For further information, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
+### Syntax
+```
+baseprompt target1 target2 BREAK
+effect1, target1 BREAK
+effect2, target2
+```
+
+First, write the base prompt. In the base prompt, include the words (target1, target2) for which you want to create masks. Next, add `BREAK`, then write the prompt corresponding to target1, followed by a comma and target1 itself. The targets do not have to appear in the same order in the base prompt and in the BREAK-separated prompts, so
+
+```
+target2 baseprompt target1 BREAK
+effect1, target1 BREAK
+effect2, target2
+```
+works as well.
+
+### Sample
+In this example, masks are calculated for `shirt`, `tie`, and `skirt`, and color prompts are applied only to those regions.
+```
+rp_args = {
+    "mode": "prompt-ex",
+    "save_mask": True,
+    "th": "0.4,0.6,0.6",
+}
+
+prompt = """
+a girl in street with shirt, tie, skirt BREAK
+red, shirt BREAK
+green, tie BREAK
+blue , skirt
+"""
+```
+![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)
+### Threshold
+The threshold determines the mask created from the prompt. Because the appropriate range varies widely depending on the target prompt, a separate value can be set for each mask; if multiple regions are used, enter the values separated by commas. For example, hair tends to be ambiguous and requires a small value, while a face tends to be large and requires a small value. The values should be given in the same order as the BREAK-separated regions.
+
+```
+a lady ,hair, face BREAK
+red, hair BREAK
+tanned ,face
+```
+`threshold : 0.4,0.6`
+If only one value is given for multiple regions, the same value is used for all of them.
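+
+Putting this together, a minimal sketch of a prompt-mode call for the example above (assuming `pipe` was created as in the sample code) looks like this; the thresholds pair with the BREAK-separated regions in order:
+
+```
+rp_args = {
+    "mode": "prompt",
+    "th": "0.4,0.6",  # 0.4 for hair, 0.6 for face, in BREAK order
+}
+
+prompt = """
+a lady ,hair, face BREAK
+red, hair BREAK
+tanned ,face
+"""
+
+images = pipe(prompt=prompt, negative_prompt="", num_inference_steps=20, rp_args=rp_args).images
+```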
+
+### Prompt and Prompt-EX
+The difference is that in Prompt mode overlapping regions are added together, whereas in Prompt-EX mode overlapping regions are overwritten sequentially. Since regions are processed in order, specifying a target with a large region first makes it easier for the effect of smaller regions to be preserved.
+
+### Accuracy
+In the case of a 512 x 512 image, Attention mode reduces the size of the region to about 8 x 8 pixels deep in the U-Net, so small regions get mixed up; Latent mode calculates at 64 x 64, so the regions are exact.
+```
+girl hair twintail frills,ribbons, dress, face BREAK
+girl, ,face
+```
+
+### Mask
+When an image is generated, the generated masks are also output. They are generated at the same size as the image, but are actually used at a much smaller size.
+
+### Use common prompt
+Everything written before `ADDCOMM` is attached to all regional prompts; separate it from the rest of the prompt with `ADDCOMM`. This is useful when you want to include elements common to all regions. For example, when generating a picture of three people with different appearances, the instruction "three people" needs to be included in every region. It is also useful for inserting quality tags and similar terms. For example, if you write the following:
+```
+best quality, 3persons in garden, ADDCOMM
+a girl white dress BREAK
+a boy blue shirt BREAK
+an old man red suit
+```
+the prompt is converted to the following:
+```
+best quality, 3persons in garden, a girl white dress BREAK
+best quality, 3persons in garden, a boy blue shirt BREAK
+best quality, 3persons in garden, an old man red suit
+```
+### Negative prompt
+A negative prompt applies equally to all regions, but region-specific negative prompts can be set as well. The number of BREAKs must match the number of positive prompts; if it does not, the negative prompt is used without being divided into regions.
+
+### Parameters
+To activate the Regional Prompting Pipeline, settings must be passed through `rp_args`, a dictionary. The entries that can be set are listed below.
+
+### Input Parameters
+Parameters are specified through `rp_args` (a dictionary):
+
+```
+rp_args = {
+    "mode": "rows",
+    "div": "1;1;1"
+}
+
+pipe(prompt=prompt, rp_args=rp_args)
+```
+
+### Required Parameters
+- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt`, or `Prompt-Ex`. This parameter is case-insensitive.
+- `div`: Used in `Cols` and `Rows` modes. Details on how to specify it are provided in the `Cols, Rows mode` section.
+- `th`: Used in `Prompt` mode. How to specify it is detailed in the `Threshold` section.
+
+### Optional Parameters
+- `save_mask`: In `Prompt` mode, choose whether to output the generated masks along with the images. The default is `False`.
+
+The pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
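+
+For example, a minimal sketch of a weighted regional prompt (assuming the `compel` package is installed; the `(word)weight` form follows `compel`'s weighting syntax, and each BREAK-separated chunk is encoded separately):
+
+```
+prompt = """
+(green hair)1.3 twintail BREAK
+red blouse BREAK
+blue skirt
+"""
+
+images = pipe(prompt=prompt, negative_prompt="", rp_args={"mode": "rows", "div": "1;1;1"}).images
+```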
+ ## Diffusion Posterior Sampling Pipeline * Reference paper ``` diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py new file mode 100644 index 000000000000..24a50d9b9011 --- /dev/null +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -0,0 +1,511 @@ +import torchvision.transforms.functional as FF +import torch +import torchvision +from typing import Dict, Optional +from diffusers import StableDiffusionPipeline +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +try: + from compel import Compel +except: + Compel = None + +KCOMM = "ADDCOMM" +KBRK = "BREAK" + +class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline): + r""" + Args for Regional Prompting Pipeline: + rp_args:dict + Required + rp_args["mode"]: cols, rows, prompt, prompt-ex + for cols, rows mode + rp_args["div"]: ex) 1;1;1(Divide into 3 regions) + for prompt, prompt-ex mode + rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode) + + Optional + rp_args["save_mask"]: True/False (save masks in prompt mode) + + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
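+
+    Example (illustrative sketch; the checkpoint path and prompt are placeholders):
+
+        >>> pipe = RegionalPromptingStableDiffusionPipeline.from_single_file("model.safetensors")
+        >>> prompt = "green hair twintail BREAK red blouse BREAK blue skirt"
+        >>> images = pipe(prompt, rp_args={"mode": "rows", "div": "1;1;1"}).images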
+ """ + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__(vae,text_encoder,tokenizer,unet,scheduler,safety_checker,feature_extractor,requires_safety_checker) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + @torch.no_grad() + def __call__( + self, + prompt: str, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + rp_args:Dict[str,str] = None, + ): + + active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt + if negative_prompt is None: negative_prompt = "" if type(prompt) == str else [""] * len(prompt) + + device = self._execution_device + regions = 0 + + self.power = int(rp_args["power"]) if "power" in rp_args else 1 + + prompts = prompt if type(prompt) == list else [prompt] + n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] + self.batch = batch = num_images_per_prompt * len(prompts) + all_prompts_cn, all_prompts_p = promptsmaker(prompts,num_images_per_prompt) + all_n_prompts_cn, _ = promptsmaker(n_prompts,num_images_per_prompt) + + cn = len(all_prompts_cn) == len(all_n_prompts_cn) + + if Compel: + compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder) + def getcompelembs(prps): + embl = [] + for prp in prps: + embl.append(compel.build_conditioning_tensor(prp)) + return torch.cat(embl) + conds = getcompelembs(all_prompts_cn) + unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts) + embs = getcompelembs(prompts) + n_embs = getcompelembs(n_prompts) + prompt = negative_prompt = None + else: + conds = self.encode_prompt(prompts, device, 1, True)[0] + unconds = self.encode_prompt(n_prompts, device, 1, True)[0] if cn else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] + embs = n_embs = None + + if not active: + pcallback = None + mode = None + else: + if any(x in rp_args["mode"].upper() for x in ["COL","ROW"]): + mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW" + ocells,icells,regions = make_cells(rp_args["div"]) + + elif "PRO" in rp_args["mode"].upper(): + regions = len(all_prompts_p[0]) + mode = "PROMPT" + reset_attnmaps(self) + self.ex = "EX" in rp_args["mode"].upper() + self.target_tokens = target_tokens = tokendealer(self, all_prompts_p) + thresholds = [float(x) for x in rp_args["th"].split(",")] + + orig_hw = (height,width) + revers = True + + def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor,selfs=None): + if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps + self.step = step + + if len(self.attnmaps_sizes) > 3: + self.history[step] = self.attnmaps.copy() + for hw in self.attnmaps_sizes: + allmasks = [] + basemasks = [None] * batch + for tt, th in zip(target_tokens, thresholds): + for b in range(batch): + key = f"{tt}-{b}" + _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step) + 
mask = mask.unsqueeze(0).unsqueeze(-1) + if self.ex: + allmasks[b::batch] = [x - mask for x in allmasks[b::batch]] + allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]] + allmasks.append(mask) + basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask + basemasks = [1 -mask for mask in basemasks] + basemasks = [torch.where(x > 0, 1, 0) for x in basemasks] + allmasks = basemasks + allmasks + + self.attnmasks[hw] = torch.cat(allmasks) + self.maskready = True + return latents + + def hook_forward(module): + #diffusers==0.23.2 + def forward( + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.Tensor: + + attn = module + xshape = hidden_states.shape + self.hw = (h,w) = split_dims(xshape[1], *orig_hw) + + if revers: + nx,px = hidden_states.chunk(2) + else: + px,nx = hidden_states.chunk(2) + + if cn: + hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)],0) + encoder_hidden_states = torch.cat([conds]+[unconds]) + else: + hidden_states = torch.cat([px for i in range(regions)] + [nx],0) + encoder_hidden_states = torch.cat([conds]+[unconds]) + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = scaled_dot_product_attention( + self, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, getattn = "PRO" in mode + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = 
hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + #### Regional Prompting Col/Row mode + if any(x in mode for x in ["COL", "ROW"]): + reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2]) + center = reshaped.shape[0] // 2 + px = reshaped[0:center] if cn else reshaped[0:-batch] + nx = reshaped[center:] if cn else reshaped[-batch:] + outs = [px,nx] if cn else [px] + for out in outs: + c = 0 + for i,ocell in enumerate(ocells): + for icell in icells[i]: + if "ROW" in mode: + out[0:batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:] = out[c*batch:(c+1)*batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:] + else: + out[0:batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:] = out[c*batch:(c+1)*batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:] + c += 1 + px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0) + hidden_states = hidden_states.reshape(xshape) + + #### Regional Prompting Prompt mode + elif "PRO" in mode: + center = reshaped.shape[0] // 2 + px = reshaped[0:center] if cn else reshaped[0:-batch] + nx = reshaped[center:] if cn else reshaped[-batch:] + + if (h,w) in self.attnmasks and self.maskready: + def mask(input): + out = torch.multiply(input,self.attnmasks[(h,w)]) + for b in range(batch): + for r in range(1, regions): + out[b] = out[b] + out[r * batch + b] + return out + px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx) + px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0) + return hidden_states + + return forward + + def hook_forwards(root_module: torch.nn.Module): + for name, module in root_module.named_modules(): + if "attn2" in name and module.__class__.__name__ == "Attention": + module.forward = hook_forward(module) + + hook_forwards(self.unet) + + output = StableDiffusionPipeline(**self.components)( + prompt=prompt, + prompt_embeds=embs, + negative_prompt=negative_prompt, + negative_prompt_embeds=n_embs, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end = pcallback + ) + + if "save_mask" in rp_args: + save_mask = rp_args["save_mask"] + else: + save_mask = False + + if mode == "PROMPT" and save_mask: saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions) + + return output + + +### Make prompt list for each regions +def promptsmaker(prompts,batch): + out_p = [] + plen = len(prompts) + for prompt in prompts: + add = "" + if KCOMM in prompt: + add, prompt = prompt.split(KCOMM) + add = add + " " + prompts = prompt.split(KBRK) + out_p.append([add + p for p in prompts]) + out = [None]*batch*len(out_p[0]) * len(out_p) + for p, prs in enumerate(out_p): # inputs prompts + for r, pr in enumerate(prs): # prompts for regions + start = (p + r * plen) * batch + out[start : start + batch]= [pr] * batch #P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1... 
+ return out, out_p + +### make regions from ratios +### ";" makes outercells, "," makes inner cells +def make_cells(ratios): + if ";" not in ratios and "," in ratios:ratios = ratios.replace(",",";") + ratios = ratios.split(";") + ratios = [inratios.split(",") for inratios in ratios] + + icells = [] + ocells = [] + + def startend(cells,array): + current_start = 0 + array = [float(x) for x in array] + for value in array: + end = current_start + (value / sum(array)) + cells.append([current_start, end]) + current_start = end + + startend(ocells,[r[0] for r in ratios]) + + for inratios in ratios: + if 2 > len(inratios): + icells.append([[0,1]]) + else: + add = [] + startend(add,inratios[1:]) + icells.append(add) + + return ocells, icells, sum(len(cell) for cell in icells) + +def make_emblist(self, prompts): + with torch.no_grad(): + tokens = self.tokenizer(prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids.to(self.device) + embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype = self.dtype) + return embs +import math + +def split_dims(xs, height, width): + xs = xs + def repeat_div(x,y): + while y > 0: + x = math.ceil(x / 2) + y = y - 1 + return x + scale = math.ceil(math.log2(math.sqrt(height * width / xs))) + dsh = repeat_div(height,scale) + dsw = repeat_div(width,scale) + return dsh,dsw + +##### for prompt mode +def get_attn_maps(self,attn): + height,width = self.hw + target_tokens = self.target_tokens + if (height,width) not in self.attnmaps_sizes: + self.attnmaps_sizes.append((height,width)) + + for b in range(self.batch): + for t in target_tokens: + power = self.power + add = attn[b,:,:,t[0]:t[0]+len(t)]**(power)*(self.attnmaps_sizes.index((height,width)) + 1) + add = torch.sum(add,dim = 2) + key = f"{t}-{b}" + if key not in self.attnmaps: + self.attnmaps[key] = add + else: + if self.attnmaps[key].shape[1] != add.shape[1]: + add = add.view(8,height,width) + add = FF.resize(add,self.attnmaps_sizes[0],antialias=None) + add = add.reshape_as(self.attnmaps[key]) + + self.attnmaps[key] = self.attnmaps[key] + add + +def reset_attnmaps(self): # init parameters in every batch + self.step = 0 + self.attnmaps = {} #maked from attention maps + self.attnmaps_sizes =[] #height,width set of u-net blocks + self.attnmasks = {} #maked from attnmaps for regions + self.maskready = False + self.history = {} + +def saveattnmaps(self,output,h,w,th,step,regions): + masks = [] + for i, mask in enumerate(self.history[step].values()): + img, _ , mask = makepmask(self, mask, h, w, th[i % len(th)], step) + if self.ex: + masks = [x - mask for x in masks] + masks.append(mask) + if len(masks) == regions - 1: + output.images.extend([FF.to_pil_image(mask) for mask in masks]) + masks = [] + else: + output.images.append(img) + +def makepmask(self, mask, h, w, th, step): # make masks from attention cache return [for preview, for attention, for Latent] + th = th - step * 0.005 + if 0.05 >= th: th = 0.05 + mask = torch.mean(mask,dim=0) + mask = mask / mask.max().item() + mask = torch.where(mask > th ,1,0) + mask = mask.float() + mask = mask.view(1,*self.attnmaps_sizes[0]) + img = FF.to_pil_image(mask) + img = img.resize((w,h)) + mask = FF.resize(mask,(h,w),interpolation=FF.InterpolationMode.NEAREST,antialias=None) + lmask = mask + mask = mask.reshape(h*w) + mask = torch.where(mask > 0.1 ,1,0) + return img, mask, lmask + +def tokendealer(self, all_prompts): + for prompts in all_prompts: + targets =[p.split(",")[-1] for 
p in prompts[1:]] + tt = [] + + for target in targets: + ptokens = (self.tokenizer(prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0] + ttokens = (self.tokenizer(target, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0] + + tlist = [] + + for t in range(ttokens.shape[0] -2): + for p in range(ptokens.shape[0]): + if ttokens[t + 1] == ptokens[p]: + tlist.append(p) + if tlist != [] : tt.append(tlist) + + return tt + +def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn = False) -> torch.Tensor: + # Efficient implementation equivalent to the following: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype,device=self.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + if getattn: get_attn_maps(self,attn_weight) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value \ No newline at end of file