|
12 | 12 | def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): |
13 | 13 | source = source.to(destination.device) |
14 | 14 | if resize_source: |
15 | | - source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") |
| 15 | + source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear") |
16 | 16 |
|
17 | 17 | source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) |
18 | 18 |
|
19 | | - x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) |
20 | | - y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) |
| 19 | + x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier)) |
| 20 | + y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier)) |
21 | 21 |
|
22 | 22 | left, top = (x // multiplier, y // multiplier) |
23 | | - right, bottom = (left + source.shape[3], top + source.shape[2],) |
| 23 | + right, bottom = (left + source.shape[-1], top + source.shape[-2],) |
24 | 24 |
|
25 | 25 | if mask is None: |
26 | 26 | mask = torch.ones_like(source) |
27 | 27 | else: |
28 | 28 | mask = mask.to(destination.device, copy=True) |
29 | | - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") |
| 29 | + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear") |
30 | 30 | mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) |
31 | 31 |
|
32 | 32 | # calculate the bounds of the source that will be overlapping the destination |
33 | 33 | # this prevents the source trying to overwrite latent pixels that are out of bounds |
34 | 34 | # of the destination |
35 | | - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) |
| 35 | + visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),) |
36 | 36 |
|
37 | 37 | mask = mask[:, :, :visible_height, :visible_width] |
| 38 | + if mask.ndim < source.ndim: |
| 39 | + mask = mask.unsqueeze(1) |
| 40 | + |
38 | 41 | inverse_mask = torch.ones_like(mask) - mask |
39 | 42 |
|
40 | | - source_portion = mask * source[:, :, :visible_height, :visible_width] |
41 | | - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] |
| 43 | + source_portion = mask * source[..., :visible_height, :visible_width] |
| 44 | + destination_portion = inverse_mask * destination[..., top:bottom, left:right] |
42 | 45 |
|
43 | | - destination[:, :, top:bottom, left:right] = source_portion + destination_portion |
| 46 | + destination[..., top:bottom, left:right] = source_portion + destination_portion |
44 | 47 | return destination |
45 | 48 |
|
46 | 49 | class LatentCompositeMasked: |
|
0 commit comments