diff --git a/pytorch_superpixpool/suppixpool_cuda.cpp b/pytorch_superpixpool/suppixpool_cuda.cpp index 0ec2819..d785fde 100644 --- a/pytorch_superpixpool/suppixpool_cuda.cpp +++ b/pytorch_superpixpool/suppixpool_cuda.cpp @@ -35,10 +35,10 @@ std::vector suppixpool_max_forward( const int batch_size = img.size(0); const int channels_size = img.size(1); - at::Tensor output = at::zeros(torch::CUDA(at::kInt), {batch_size, channels_size, K}); + at::Tensor output = torch::zeros({batch_size, channels_size, K}, torch::CUDA(at::kInt)); output = output.type_as(img); // torch::set_requires_grad(output, true); - at::Tensor outIdx = -at::ones(torch::CUDA(at::kInt), {batch_size, channels_size, K}); + at::Tensor outIdx = -torch::ones({batch_size, channels_size, K}, torch::CUDA(at::kInt)); return suppixpool_max_cuda_forward(img, spx_labels, output, outIdx, K); // return {output, outIdx}; // return {img, spx_labels}; diff --git a/pytorch_superpixpool/suppixpool_cuda_kernel.cu b/pytorch_superpixpool/suppixpool_cuda_kernel.cu index de745cf..d32046c 100644 --- a/pytorch_superpixpool/suppixpool_cuda_kernel.cu +++ b/pytorch_superpixpool/suppixpool_cuda_kernel.cu @@ -268,7 +268,10 @@ std::vector suppixpool_max_cuda_forward( output.data() ); })); - return {output, outIdx}; + return { + output, + outIdx + }; } std::vector suppixpool_max_cuda_backward( diff --git a/pytorch_superpixpool/suppixpool_layer.py b/pytorch_superpixpool/suppixpool_layer.py index acee15f..9b8bbef 100644 --- a/pytorch_superpixpool/suppixpool_layer.py +++ b/pytorch_superpixpool/suppixpool_layer.py @@ -16,6 +16,7 @@ def forward(ctx, img, spx): # print("number of -1: ", indices.eq(-1).sum()) # print indices # assert np.all(indices.cpu().numpy()>=0) + ctx.save_for_backward(indices, img, spx, K) return outputs @@ -44,6 +45,6 @@ def __init__(self): def forward(self, pooled, spx): outShape = pooled.size()[0:2]+spx.size()[-2:] out = pooled.new_zeros(outShape) - for batch in xrange(pooled.size()[0]): + for batch in range(pooled.size()[0]): out[batch, :, :, :] = pooled[batch, :, spx[batch,:,:]] return out diff --git a/pytorch_superpixpool/test_GPUpool.py b/pytorch_superpixpool/test_GPUpool.py index 0c03305..cce09ae 100644 --- a/pytorch_superpixpool/test_GPUpool.py +++ b/pytorch_superpixpool/test_GPUpool.py @@ -5,32 +5,39 @@ import numpy as np import time from skimage.segmentation import slic +from torch.autograd import Variable -if __name__ == "__main__": - GPU = torch.device("cuda:0") +if __name__ == "__main__": + + GPU = torch.device('cuda') batch_size = 1 - n_channels = 16 - xSize = 256 - ySize = 512 + n_channels = 2 + xSize = 4 + ySize = 4 - X = torch.randn((batch_size,n_channels,xSize,ySize), dtype=torch.float32, device=GPU) + X = torch.randn((batch_size,n_channels,xSize,ySize), dtype=torch.float32, device=GPU,requires_grad=True) spx = np.array([np.arange(xSize*ySize).reshape(xSize,ySize)]*batch_size) # spx = np.zeros((batch_size, xSize, ySize)) spx = torch.from_numpy(spx) spx = spx.to(GPU) + # X + X print ("INPUT ARRAY ----------------- \n", X) pool = SupPixPool() pld = pool(X, spx) + + + print ("POOLED ARRAY ----------------- \n", pld) print ("Shape of pooled array: ", pld.size()) - unpool = SupPixUnpool() - unpld = unpool(pld, spx) - print ("Unpooling back to original: ", np.all(unpld == X)) + # unpool = SupPixUnpool() + # unpld = unpool(pld, spx) + # print(unpld.shape, X.shape) + #print ("Unpooling back to original: ", np.all(unpld.detach().cpu().numpy() == X.detach().cpu().numpy())) - res = torch.autograd.gradcheck(pool, (X, spx), raise_exception=False) - resUnpool = torch.autograd.gradcheck(unpool, (pld, spx), raise_exception=False) + res = torch.autograd.gradcheck(pool, (X.double(), spx), raise_exception=True) + # resUnpool = torch.autograd.gradcheck(unpool, (pld, spx), raise_exception=False) - print ("Gradients of pooling are {}.".format("correct" if res else "wrong")) # res should be True if the gradients are correct. - print ("Gradients of unpooling are {}.".format("correct" if resUnpool else "wrong")) \ No newline at end of file + # print ("Gradients of pooling are {}.".format("correct" if res else "wrong")) # res should be True if the gradients are correct. + # print ("Gradients of unpooling are {}.".format("correct" if resUnpool else "wrong")) \ No newline at end of file