Skip to content

Commit 5e55200

Browse files
Multi-tensor axpby kernel for more flexible unscaling (groundwork for pytorch#163 and pytorch#179 fix)
1 parent 74c06d8 commit 5e55200

File tree

6 files changed

+270
-3
lines changed

6 files changed

+270
-3
lines changed

csrc/amp_C_frontend.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@ void multi_tensor_scale_cuda(
66
std::vector<std::vector<at::Tensor>> tensor_lists,
77
float scale);
88

9+
void multi_tensor_axpby_cuda(
10+
int chunk_size,
11+
at::Tensor noop_flag,
12+
std::vector<std::vector<at::Tensor>> tensor_lists,
13+
float a,
14+
float b);
15+
916
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1017
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
1118
"Fused overflow check + scale for a list of contiguous tensors");
19+
m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
20+
"out = a*x + b*y for a list of contiguous tensors");
1221
}

csrc/multi_tensor_axpby_kernel.cu

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/AccumulateType.h>
3+
#include <ATen/cuda/CUDAContext.h>
4+
#include <ATen/cuda/Exceptions.h>
5+
// Another possibility:
6+
// #include <torch/all.h>
7+
8+
#include <assert.h>
9+
10+
#include "type_shim.h"
11+
#include "multi_tensor_apply.cuh"
12+
13+
#define BLOCK_SIZE 512
14+
#define ILP 4
15+
16+
template<typename x_t, typename y_t, typename out_t>
17+
struct AxpbyFunctor
18+
{
19+
__device__ __forceinline__ void operator()(
20+
int chunk_size,
21+
volatile int* noop_gmem,
22+
TensorListMetadata<3>& tl,
23+
float a,
24+
float b)
25+
{
26+
// I'd like this kernel to propagate infs/nans.
27+
// if(*noop_gmem == 1)
28+
// return;
29+
30+
int tensor_loc = tl.block_to_tensor[blockIdx.x];
31+
int chunk_idx = tl.block_to_chunk[blockIdx.x];
32+
int n = tl.sizes[tensor_loc];
33+
34+
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
35+
x += chunk_idx*chunk_size;
36+
37+
y_t* y = (y_t*)tl.addresses[1][tensor_loc];
38+
y += chunk_idx*chunk_size;
39+
40+
out_t* out = (out_t*)tl.addresses[2][tensor_loc];
41+
out += chunk_idx*chunk_size;
42+
43+
n -= chunk_idx*chunk_size;
44+
45+
// Non-divergent exit condition for __syncthreads, not necessary here
46+
float xs[ILP];
47+
float ys[ILP];
48+
for(int i_start = 0;
49+
i_start < n && i_start < chunk_size;
50+
i_start += blockDim.x*ILP)
51+
{
52+
#pragma unroll
53+
for(int ii = 0; ii < ILP; ii++)
54+
{
55+
xs[ii] = 0;
56+
ys[ii] = 0;
57+
int i = i_start + threadIdx.x + ii*blockDim.x;
58+
if(i < n && i < chunk_size)
59+
{
60+
xs[ii] = static_cast<float>(x[i]);
61+
ys[ii] = static_cast<float>(y[i]);
62+
}
63+
}
64+
65+
// see note in multi_tensor_scale_kernel.cu
66+
#pragma unroll
67+
for(int ii = 0; ii < ILP; ii++)
68+
{
69+
int i = i_start + threadIdx.x + ii*blockDim.x;
70+
if(i < n && i < chunk_size)
71+
if(isfinite(xs[ii]) && isfinite(ys[ii]))
72+
out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
73+
else
74+
{
75+
out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
76+
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
77+
}
78+
}
79+
}
80+
}
81+
};
82+
83+
void multi_tensor_axpby_cuda(
84+
int chunk_size,
85+
at::Tensor noop_flag,
86+
std::vector<std::vector<at::Tensor>> tensor_lists,
87+
float a,
88+
float b)
89+
{
90+
using namespace at;
91+
// The output (downscaled) type is always float.
92+
// If build times suffer, think about where to put this dispatch,
93+
// and what logic should be moved out of multi_tensor_apply.
94+
95+
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
96+
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
97+
DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
98+
multi_tensor_apply<3>(
99+
BLOCK_SIZE,
100+
chunk_size,
101+
noop_flag,
102+
tensor_lists,
103+
AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
104+
a,
105+
b); )))
106+
107+
AT_CUDA_CHECK(cudaGetLastError());
108+
109+
// AT_CUDA_CHECK(cudaDeviceSynchronize());
110+
}

csrc/type_shim.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,22 @@ struct TypeShim
1212
// Enable dispatch switch statements to take *this directly for post-3aeb78
1313
operator at::ScalarType(){ return payload.scalarType(); };
1414
};
15+
16+
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
17+
switch(TYPE) \
18+
{ \
19+
case at::ScalarType::Float: \
20+
{ \
21+
using scalar_t_##LEVEL = float; \
22+
__VA_ARGS__; \
23+
break; \
24+
} \
25+
case at::ScalarType::Half: \
26+
{ \
27+
using scalar_t_##LEVEL = at::Half; \
28+
__VA_ARGS__; \
29+
break; \
30+
} \
31+
default: \
32+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
33+
}

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@
4848
ext_modules.append(
4949
CUDAExtension(name='amp_C',
5050
sources=['csrc/amp_C_frontend.cpp',
51-
'csrc/multi_tensor_scale_kernel.cu'],
51+
'csrc/multi_tensor_scale_kernel.cu',
52+
'csrc/multi_tensor_axpby_kernel.cu'],
5253
extra_compile_args={'cxx': ['-O3'],
5354
'nvcc':['-lineinfo',
5455
'-O3',
56+
# '--resource-usage',
5557
'--use_fast_math']}))
5658
ext_modules.append(
5759
CUDAExtension(name='fused_adam_cuda',
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import unittest
2+
3+
import functools as ft
4+
import itertools as it
5+
6+
from apex import amp
7+
import torch
8+
from torch import nn
9+
import torch.nn.functional as F
10+
11+
from utils import common_init, HALF, FLOAT,\
12+
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
13+
14+
try:
15+
import amp_C
16+
from amp_C import multi_tensor_axpby
17+
from apex.multi_tensor_apply import MultiTensorApply
18+
disabled = False
19+
except ImportError as err:
20+
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
21+
disabled = True
22+
23+
24+
class TestMultiTensorAxpby(unittest.TestCase):
25+
26+
def setUp(self):
27+
common_init(self)
28+
29+
self.a = 2.0
30+
self.b = 8.0
31+
self.xval = 4.0
32+
self.yval = 16.0
33+
self.overflow_buf = torch.cuda.IntTensor(1).zero_()
34+
self.ref = torch.cuda.FloatTensor([136.0])
35+
36+
def tearDown(self):
37+
pass
38+
39+
# The tensor creation here is written for convenience, not speed.
40+
def axpby(self, sizea, sizeb, applier, repeat_tensors,
41+
x_type, y_type, out_type, inplace=False):
42+
self.overflow_buf.zero_()
43+
t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
44+
t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)
45+
46+
y_list = []
47+
for i in range(repeat_tensors):
48+
y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]
49+
50+
x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]
51+
52+
if inplace:
53+
out_list = y_list
54+
else:
55+
out_list = [out.clone().to(out_type)*3.0 for out in y_list]
56+
57+
applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b)
58+
59+
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
60+
msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
61+
x_type, y_type, out_type, inplace))
62+
self.assertTrue(self.overflow_buf.item() == 0,
63+
msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
64+
x_type, y_type, out_type, inplace))
65+
66+
# def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
67+
# self.overflow_buf.zero_()
68+
# a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
69+
# b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
70+
71+
# out_list = []
72+
# for i in range(repeat_tensors):
73+
# out_list += [a.clone().to(out_type), b.clone().to(out_type)]
74+
75+
# if inplace:
76+
# in_list = out_list
77+
# else:
78+
# in_list = [out.clone().to(in_type) for out in out_list]
79+
80+
# applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
81+
82+
# self.overflow_buf.zero_()
83+
# in_list[t][ind] = val
84+
# applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
85+
# self.assertTrue(self.overflow_buf.item())
86+
87+
@unittest.skipIf(disabled, "amp_C is unavailable")
88+
def test_fuzz(self):
89+
input_size_pairs = (
90+
(7777*77, 555*555),
91+
(777, 555),
92+
(555, 2048*32+1),
93+
(2048*32+1, 555),
94+
(555, 2048*32),
95+
(2048*32, 555),
96+
(33333, 555),
97+
(555, 33333))
98+
appliers = (
99+
MultiTensorApply(2048*32),
100+
MultiTensorApply(333),
101+
MultiTensorApply(33333))
102+
repeat_tensors = (
103+
1,
104+
55)
105+
106+
for sizea, sizeb in input_size_pairs:
107+
for applier in appliers:
108+
for repeat in repeat_tensors:
109+
for x_type in (torch.float32, torch.float16):
110+
for y_type in (torch.float32, torch.float16):
111+
for out_type in (torch.float32, torch.float16):
112+
for inplace in (True, False):
113+
if inplace is True and (y_type is not out_type):
114+
continue
115+
else:
116+
self.axpby(sizea, sizeb, applier, repeat,
117+
x_type, y_type, out_type, inplace=inplace)
118+
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
119+
# 0, 0, float('nan'), inplace=inplace)
120+
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
121+
# 2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
122+
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
123+
# 2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
124+
125+
126+
127+
if __name__ == '__main__':
128+
unittest.main()

tests/L0/run_amp/test_multi_tensor_scale.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,11 @@
2424
class TestMultiTensorScale(unittest.TestCase):
2525

2626
def setUp(self):
27+
common_init(self)
2728
self.scale = 4.0
2829
self.overflow_buf = torch.cuda.IntTensor(1).zero_()
2930
self.ref = torch.cuda.FloatTensor([1.0])
3031

31-
common_init(self)
32-
3332
def tearDown(self):
3433
pass
3534

0 commit comments

Comments
 (0)