11from functools import partial
22
3- import numpy as np
43import torch
54import torch .nn as nn
65import torch .nn .functional as F
@@ -134,10 +133,10 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
134133 kernel = [1 ] * factor
135134
136135 # setup kernel
137- kernel = np . asarray (kernel , dtype = np .float32 )
136+ kernel = torch . tensor (kernel , dtype = torch .float32 )
138137 if kernel .ndim == 1 :
139- kernel = np .outer (kernel , kernel )
140- kernel /= np .sum (kernel )
138+ kernel = torch .outer (kernel , kernel )
139+ kernel /= torch .sum (kernel )
141140
142141 kernel = kernel * (gain * (factor ** 2 ))
143142
@@ -219,10 +218,10 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
219218 kernel = [1 ] * factor
220219
221220 # setup kernel
222- kernel = np . asarray (kernel , dtype = np .float32 )
221+ kernel = torch . tensor (kernel , dtype = torch .float32 )
223222 if kernel .ndim == 1 :
224- kernel = np .outer (kernel , kernel )
225- kernel /= np .sum (kernel )
223+ kernel = torch .outer (kernel , kernel )
224+ kernel /= torch .sum (kernel )
226225
227226 kernel = kernel * gain
228227
@@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
391390 if kernel is None :
392391 kernel = [1 ] * factor
393392
394- kernel = np . asarray (kernel , dtype = np .float32 )
393+ kernel = torch . tensor (kernel , dtype = torch .float32 )
395394 if kernel .ndim == 1 :
396- kernel = np .outer (kernel , kernel )
397- kernel /= np .sum (kernel )
395+ kernel = torch .outer (kernel , kernel )
396+ kernel /= torch .sum (kernel )
398397
399398 kernel = kernel * (gain * (factor ** 2 ))
400399 p = kernel .shape [0 ] - factor
401- return upfirdn2d_native (
402- x , torch .tensor (kernel , device = x .device ), up = factor , pad = ((p + 1 ) // 2 + factor - 1 , p // 2 )
403- )
400+ return upfirdn2d_native (x , kernel .to (device = x .device ), up = factor , pad = ((p + 1 ) // 2 + factor - 1 , p // 2 ))
404401
405402
406403def downsample_2d (x , kernel = None , factor = 2 , gain = 1 ):
@@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
425422 if kernel is None :
426423 kernel = [1 ] * factor
427424
428- kernel = np . asarray (kernel , dtype = np .float32 )
425+ kernel = torch . tensor (kernel , dtype = torch .float32 )
429426 if kernel .ndim == 1 :
430- kernel = np .outer (kernel , kernel )
431- kernel /= np .sum (kernel )
427+ kernel = torch .outer (kernel , kernel )
428+ kernel /= torch .sum (kernel )
432429
433430 kernel = kernel * gain
434431 p = kernel .shape [0 ] - factor
435- return upfirdn2d_native (x , torch . tensor ( kernel , device = x .device ), down = factor , pad = ((p + 1 ) // 2 , p // 2 ))
432+ return upfirdn2d_native (x , kernel . to ( device = x .device ), down = factor , pad = ((p + 1 ) // 2 , p // 2 ))
436433
437434
438435def upfirdn2d_native (input , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
0 commit comments