
Error while running inference #4

Description

@aisagarw

Hi,
I am getting the following error when running the inference.ipynb notebook:
TypeError: prepare_attention_mask() takes 3 positional arguments but 4 were given

Traceback:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:5                                                                                    │
│                                                                                                  │
│   2 │   ps = pm.embed_prompt("a colorful photo of a <yc> in the jungles")                        │
│   3 torch.manual_seed(0)                                                                         │
│   4 print(ps["CONTEXT_TENSOR_0"].shape)                                                          │
│ ❱ 5 im = overwrite_call(pipe, prompt_embeds=ps).images[0]  # .save("contents/yc_.jpg")           │
│   6 # im.save("contents/yc_eti_coarse.jpg")                                                      │
│   7 im                                                                                           │
│   8                                                                                              │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in               │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/ppp/pipeline_call.py:191 in overwrite_call      │
│                                                                                                  │
│   188 │   │   │   │   # compare the inputs going to this part in this repo vs the baseline SD    │
│   189 │   │   │   │   # need to check if the difference is mainly in how the inputs are given    │
│   190 │   │   │   │                                                                              │
│ ❱ 191 │   │   │   │   noise_pred_uncond = self.unet(                                             │
│   192 │   │   │   │   │   latent_model_input,                                                    │
│   193 │   │   │   │   │   t,                                                                     │
│   194 │   │   │   │   │   encoder_hidden_states=negative_prompt_embeds,                          │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in _call_impl   │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:481 in    │
│ forward                                                                                          │
│                                                                                                  │
│   478 │   │   down_block_res_samples = (sample,)                                                 │
│   479 │   │   for downsample_block in self.down_blocks:                                          │
│   480 │   │   │   if hasattr(downsample_block, "has_cross_attention") and downsample_block.has   │
│ ❱ 481 │   │   │   │   sample, res_samples = downsample_block(                                    │
│   482 │   │   │   │   │   hidden_states=sample,                                                  │
│   483 │   │   │   │   │   temb=emb,                                                              │
│   484 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                           │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in _call_impl   │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py:789 in       │
│ forward                                                                                          │
│                                                                                                  │
│    786 │   │   │   │   )[0]                                                                      │
│    787 │   │   │   else:                                                                         │
│    788 │   │   │   │   hidden_states = resnet(hidden_states, temb)                               │
│ ❱  789 │   │   │   │   hidden_states = attn(                                                     │
│    790 │   │   │   │   │   hidden_states,                                                        │
│    791 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│    792 │   │   │   │   │   cross_attention_kwargs=cross_attention_kwargs,                        │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in _call_impl   │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/diffusers/models/transformer_2d.py:265 in       │
│ forward                                                                                          │
│                                                                                                  │
│   262 │   │                                                                                      │
│   263 │   │   # 2. Blocks                                                                        │
│   264 │   │   for block in self.transformer_blocks:                                              │
│ ❱ 265 │   │   │   hidden_states = block(                                                         │
│   266 │   │   │   │   hidden_states,                                                             │
│   267 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│   268 │   │   │   │   timestep=timestep,                                                         │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in _call_impl   │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/diffusers/models/attention.py:291 in forward    │
│                                                                                                  │
│   288 │   │                                                                                      │
│   289 │   │   # 1. Self-Attention                                                                │
│   290 │   │   cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not   │
│ ❱ 291 │   │   attn_output = self.attn1(                                                          │
│   292 │   │   │   norm_hidden_states,                                                            │
│   293 │   │   │   encoder_hidden_states=encoder_hidden_states if self.only_cross_attention els   │
│   294 │   │   │   attention_mask=attention_mask,                                                 │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in _call_impl   │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/diffusers/models/cross_attention.py:160 in      │
│ forward                                                                                          │
│                                                                                                  │
│   157 │   │   # The `CrossAttention` class can call different attention processors / attention   │
│   158 │   │   # here we simply pass along all tensors to the selected processor class            │
│   159 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e   │
│ ❱ 160 │   │   return self.processor(                                                             │
│   161 │   │   │   self,                                                                          │
│   162 │   │   │   hidden_states,                                                                 │
│   163 │   │   │   encoder_hidden_states=encoder_hidden_states,                                   │
│                                                                                                  │
│ /opt/conda/envs/ldm2/lib/python3.8/site-packages/ppp/utils.py:48 in __call__                     │
│                                                                                                  │
│    45 │   │   print("attention_mask:", attention_mask)                                           │
│    46 │   │   print("sequence_length: ", sequence_length)                                        │
│    47 │   │                                                                                      │
│ ❱  48 │   │   attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, ba   │
│    49 │   │   query = attn.to_q(hidden_states)                                                   │
│    50 │   │                                                                                      │
│    51 │   │   if _ehs is None:                                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: prepare_attention_mask() takes 3 positional arguments but 4 were given
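
From the traceback, this looks like a diffusers version mismatch: the custom processor in ppp/utils.py passes three arguments (attention_mask, sequence_length, batch_size) to attn.prepare_attention_mask, while the installed diffusers release (old enough to still route through models/cross_attention.py, as shown above) appears to accept only two (the mask and the target sequence length). Below is a minimal compatibility sketch, assuming the call at ppp/utils.py line 48 is the only affected spot; the helper name is hypothetical:

```python
# Compatibility sketch (assumption: the installed diffusers release predates
# the batch_size parameter of prepare_attention_mask, while ppp/utils.py
# passes three arguments). The helper name is hypothetical.
import inspect


def prepare_mask_compat(attn, attention_mask, sequence_length, batch_size):
    """Call attn.prepare_attention_mask with only the arguments the
    installed diffusers version accepts."""
    params = inspect.signature(attn.prepare_attention_mask).parameters
    if "batch_size" in params:
        # Newer diffusers signature also takes batch_size.
        return attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
    # Older diffusers signature: (attention_mask, target_length) only.
    return attn.prepare_attention_mask(attention_mask, sequence_length)
```

Alternatively, upgrading diffusers to a release whose prepare_attention_mask accepts a batch_size argument should remove the mismatch without any shim.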
