Commit c049372

Remove the usage of numpy in up/down sample_2d (#503)
* Fix PT up/down sample_2d
* empty commit
* style
* style

Co-authored-by: ydshieh <[email protected]>
1 parent c727a6a commit c049372

File tree

1 file changed

+14 −17 lines changed


src/diffusers/models/resnet.py

Lines changed: 14 additions & 17 deletions
@@ -1,6 +1,5 @@
 from functools import partial

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -134,10 +133,10 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             kernel = [1] * factor

         # setup kernel
-        kernel = np.asarray(kernel, dtype=np.float32)
+        kernel = torch.tensor(kernel, dtype=torch.float32)
         if kernel.ndim == 1:
-            kernel = np.outer(kernel, kernel)
-        kernel /= np.sum(kernel)
+            kernel = torch.outer(kernel, kernel)
+        kernel /= torch.sum(kernel)

         kernel = kernel * (gain * (factor**2))

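Pulled out of the class, the kernel setup that this hunk ports from numpy to torch reduces to the following minimal sketch (the factor and gain values here are illustrative, not taken from the diff):

import torch

factor, gain = 2, 1  # illustrative values
kernel = [1] * factor                               # 1-D box filter, e.g. [1, 1]
kernel = torch.tensor(kernel, dtype=torch.float32)  # replaces np.asarray(..., dtype=np.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)            # separable 2-D kernel, replaces np.outer
kernel /= torch.sum(kernel)                         # normalize to unit sum, replaces np.sum
kernel = kernel * (gain * (factor**2))              # upsampling rescales by factor**2

The result stays a float32 torch tensor throughout, so nothing has to be converted back from numpy later in the call.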

@@ -219,10 +218,10 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             kernel = [1] * factor

         # setup kernel
-        kernel = np.asarray(kernel, dtype=np.float32)
+        kernel = torch.tensor(kernel, dtype=torch.float32)
         if kernel.ndim == 1:
-            kernel = np.outer(kernel, kernel)
-        kernel /= np.sum(kernel)
+            kernel = torch.outer(kernel, kernel)
+        kernel /= torch.sum(kernel)

         kernel = kernel * gain

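The downsampling hunk is identical except for the gain line: it keeps kernel * gain, while the upsampling variant uses gain * (factor**2) to compensate for spreading each input pixel over factor² output positions. A quick check of the normalized default kernel (outputs assumed, not printed in the commit):

import torch

factor = 2
k = torch.outer(torch.ones(factor), torch.ones(factor))
k /= k.sum()
print(k)        # tensor([[0.2500, 0.2500], [0.2500, 0.2500]])
print(k.sum())  # tensor(1.), unchanged when gain=1 for downsampling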

@@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
     if kernel is None:
         kernel = [1] * factor

-    kernel = np.asarray(kernel, dtype=np.float32)
+    kernel = torch.tensor(kernel, dtype=torch.float32)
     if kernel.ndim == 1:
-        kernel = np.outer(kernel, kernel)
-    kernel /= np.sum(kernel)
+        kernel = torch.outer(kernel, kernel)
+    kernel /= torch.sum(kernel)

     kernel = kernel * (gain * (factor**2))
     p = kernel.shape[0] - factor
-    return upfirdn2d_native(
-        x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
-    )
+    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))


 def downsample_2d(x, kernel=None, factor=2, gain=1):
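Because kernel is already a torch tensor at the return site, the old torch.tensor(kernel, device=x.device) wrapper (which copies, and emits a UserWarning when handed a tensor) becomes kernel.to(device=x.device), a no-op when the devices already match. A usage sketch, assuming the module path of this commit (diffusers.models.resnet; it may move in later releases):

import torch
from diffusers.models.resnet import upsample_2d  # path as of this commit

x = torch.randn(1, 3, 32, 32)  # NCHW input
y = upsample_2d(x, factor=2)   # default [1, 1] separable kernel
print(y.shape)                 # expected: torch.Size([1, 3, 64, 64])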
@@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
     if kernel is None:
         kernel = [1] * factor

-    kernel = np.asarray(kernel, dtype=np.float32)
+    kernel = torch.tensor(kernel, dtype=torch.float32)
     if kernel.ndim == 1:
-        kernel = np.outer(kernel, kernel)
-    kernel /= np.sum(kernel)
+        kernel = torch.outer(kernel, kernel)
+    kernel /= torch.sum(kernel)

     kernel = kernel * gain
     p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
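A faithfulness check for the port can compare the kernel built by the removed numpy path against the new torch path (a test sketch, not part of this commit):

import numpy as np
import torch

factor, gain = 2, 1
base = [1] * factor

# removed numpy path
k_np = np.asarray(base, dtype=np.float32)
if k_np.ndim == 1:
    k_np = np.outer(k_np, k_np)
k_np /= np.sum(k_np)
k_np = k_np * gain

# new torch path
k_pt = torch.tensor(base, dtype=torch.float32)
if k_pt.ndim == 1:
    k_pt = torch.outer(k_pt, k_pt)
k_pt /= torch.sum(k_pt)
k_pt = k_pt * gain

torch.testing.assert_close(k_pt, torch.from_numpy(k_np))  # passes when the two paths agree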
