From f38dbdabb26374db146d91f791677e28ac9b108b Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 13 Sep 2022 19:43:22 +0200
Subject: [PATCH 1/4] Fix PT up/down sample_2d

---
 src/diffusers/models/resnet.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 27fae24f71d8..e663a51f56cb 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -134,10 +134,10 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             kernel = [1] * factor
 
         # setup kernel
-        kernel = np.asarray(kernel, dtype=np.float32)
+        kernel = torch.tensor(kernel, dtype=torch.float32)
         if kernel.ndim == 1:
-            kernel = np.outer(kernel, kernel)
-        kernel /= np.sum(kernel)
+            kernel = torch.outer(kernel, kernel)
+        kernel /= torch.sum(kernel)
 
         kernel = kernel * (gain * (factor**2))
 
@@ -219,10 +219,10 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             kernel = [1] * factor
 
         # setup kernel
-        kernel = np.asarray(kernel, dtype=np.float32)
+        kernel = torch.tensor(kernel, dtype=torch.float32)
         if kernel.ndim == 1:
-            kernel = np.outer(kernel, kernel)
-        kernel /= np.sum(kernel)
+            kernel = torch.outer(kernel, kernel)
+        kernel /= torch.sum(kernel)
 
         kernel = kernel * gain
 
@@ -391,15 +391,15 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
     if kernel is None:
         kernel = [1] * factor
 
-    kernel = np.asarray(kernel, dtype=np.float32)
+    kernel = torch.tensor(kernel, dtype=torch.float32)
     if kernel.ndim == 1:
-        kernel = np.outer(kernel, kernel)
-    kernel /= np.sum(kernel)
+        kernel = torch.outer(kernel, kernel)
+    kernel /= torch.sum(kernel)
 
     kernel = kernel * (gain * (factor**2))
     p = kernel.shape[0] - factor
     return upfirdn2d_native(
-        x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+        x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
     )
 
 
@@ -425,14 +425,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
     if kernel is None:
         kernel = [1] * factor
 
-    kernel = np.asarray(kernel, dtype=np.float32)
+    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
-        kernel = np.outer(kernel, kernel)
-    kernel /= np.sum(kernel)
+        kernel = torch.outer(kernel, kernel)
+    kernel /= torch.sum(kernel)
 
     kernel = kernel * gain
     p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):

From 4c05beda742b6973bc5c8cdd2aacdb9855183f19 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 13 Sep 2022 19:59:04 +0200
Subject: [PATCH 2/4] empty commit

From 7a1ebbfc02b38679c6b9747731acfdb874e8d5ea Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 13 Sep 2022 20:04:49 +0200
Subject: [PATCH 3/4] style

---
 src/diffusers/models/resnet.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index e663a51f56cb..24513044dede 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -398,9 +398,7 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
 
     kernel = kernel * (gain * (factor**2))
     p = kernel.shape[0] - factor
-    return upfirdn2d_native(
-        x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
-    )
+    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
 
 
 def downsample_2d(x, kernel=None, factor=2, gain=1):

From 720f034a93b3ce4c818718da748222b6ba93a70e Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 13 Sep 2022 20:06:13 +0200
Subject: [PATCH 4/4] style

---
 src/diffusers/models/resnet.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 24513044dede..0623b895ac5c 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -1,6 +1,5 @@
 from functools import partial
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
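
The following is a minimal sketch (not part of the patches above) illustrating the change in PATCH 1/4: the numpy-based kernel setup is replaced by an equivalent torch-based one, using the default box kernel and the `factor`/`gain` defaults from the diff. The variable names `k_np`/`k_pt` are illustrative only.

# Sketch: the old numpy kernel setup vs. the new torch-based one from PATCH 1/4.
import numpy as np
import torch

factor, gain = 2, 1
kernel = [1] * factor  # default box kernel, as when kernel is None

# Previous numpy path (removed by the patch)
k_np = np.asarray(kernel, dtype=np.float32)
if k_np.ndim == 1:
    k_np = np.outer(k_np, k_np)
k_np /= np.sum(k_np)
k_np = k_np * (gain * (factor**2))

# New torch path (added by the patch)
k_pt = torch.tensor(kernel, dtype=torch.float32)
if k_pt.ndim == 1:
    k_pt = torch.outer(k_pt, k_pt)
k_pt /= torch.sum(k_pt)
k_pt = k_pt * (gain * (factor**2))

# The values match; since the kernel is already a torch tensor, it can be moved
# to the input's device with `.to(device=x.device)` instead of being re-wrapped
# in `torch.tensor(...)`, which is what the later hunks change.
assert np.allclose(k_np, k_pt.numpy())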