Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 68 additions & 73 deletions cuda/lltm_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
const scalar_t* __restrict__ gates,
const scalar_t* __restrict__ old_cell,
scalar_t* __restrict__ new_h,
scalar_t* __restrict__ new_cell,
scalar_t* __restrict__ input_gate,
scalar_t* __restrict__ output_gate,
scalar_t* __restrict__ candidate_cell,
size_t state_size) {
const int column = blockIdx.x * blockDim.x + threadIdx.x;
const int index = blockIdx.y * state_size + column;
const int gates_row = blockIdx.y * (state_size * 3);
if (column < state_size) {
input_gate[index] = sigmoid(gates[gates_row + column]);
output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
new_cell[index] =
old_cell[index] + candidate_cell[index] * input_gate[index];
new_h[index] = tanh(new_cell[index]) * output_gate[index];
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < gates.size(2)){
input_gate[n][c] = sigmoid(gates[n][0][c]);
output_gate[n][c] = sigmoid(gates[n][1][c]);
candidate_cell[n][c] = elu(gates[n][2][c]);
new_cell[n][c] =
old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
}
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
scalar_t* __restrict__ d_old_cell,
scalar_t* __restrict__ d_gates,
const scalar_t* __restrict__ grad_h,
const scalar_t* __restrict__ grad_cell,
const scalar_t* __restrict__ new_cell,
const scalar_t* __restrict__ input_gate,
const scalar_t* __restrict__ output_gate,
const scalar_t* __restrict__ candidate_cell,
const scalar_t* __restrict__ gate_weights,
size_t state_size) {
const int column = blockIdx.x * blockDim.x + threadIdx.x;
const int index = blockIdx.y * state_size + column;
const int gates_row = blockIdx.y * (state_size * 3);
if (column < state_size) {
const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < d_gates.size(2)){
const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
const auto d_new_cell =
d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];


d_old_cell[index] = d_new_cell;
const auto d_candidate_cell = input_gate[index] * d_new_cell;
const auto d_input_gate = candidate_cell[index] * d_new_cell;
d_old_cell[n][c] = d_new_cell;
const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
const auto d_input_gate = candidate_cell[n][c] * d_new_cell;


const auto input_gate_index = gates_row + column;
const auto output_gate_index = gates_row + state_size + column;
const auto candidate_cell_index = gates_row + 2 * state_size + column;

d_gates[input_gate_index] =
d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
d_gates[output_gate_index] =
d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
d_gates[candidate_cell_index] =
d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
d_gates[n][0][c] =
d_input_gate * d_sigmoid(gate_weights[n][0][c]);
d_gates[n][1][c] =
d_output_gate * d_sigmoid(gate_weights[n][1][c]);
d_gates[n][2][c] =
d_candidate_cell * d_elu(gate_weights[n][2][c]);
}
}
} // namespace
Expand All @@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

const auto batch_size = old_cell.size(0);
const auto state_size = old_cell.size(1);

auto gates = gate_weights.reshape({batch_size, 3, state_size});
auto new_h = torch::zeros_like(old_cell);
auto new_cell = torch::zeros_like(old_cell);
auto input_gate = torch::zeros_like(old_cell);
Expand All @@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(

AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
gates.data<scalar_t>(),
old_cell.data<scalar_t>(),
new_h.data<scalar_t>(),
new_cell.data<scalar_t>(),
input_gate.data<scalar_t>(),
output_gate.data<scalar_t>(),
candidate_cell.data<scalar_t>(),
state_size);
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));

return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
Expand All @@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor gates,
torch::Tensor weights) {
auto d_old_cell = torch::zeros_like(new_cell);
auto d_gates = torch::zeros_like(gate_weights);
auto d_gates = torch::zeros_like(gates);

const auto batch_size = new_cell.size(0);
const auto state_size = new_cell.size(1);
Expand All @@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(

AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
d_old_cell.data<scalar_t>(),
d_gates.data<scalar_t>(),
grad_h.data<scalar_t>(),
grad_cell.data<scalar_t>(),
new_cell.data<scalar_t>(),
input_gate.data<scalar_t>(),
output_gate.data<scalar_t>(),
candidate_cell.data<scalar_t>(),
gate_weights.data<scalar_t>(),
state_size);
d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
}));

auto d_weights = d_gates.t().mm(X);
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
auto d_gate_weights = d_gates.flatten(1, 2);
auto d_weights = d_gate_weights.t().mm(X);
auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

auto d_X = d_gates.mm(weights);
auto d_X = d_gate_weights.mm(weights);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);

Expand Down