diff --git a/cuda/lltm_cuda_kernel.cu b/cuda/lltm_cuda_kernel.cu
index e8759fb..fd408c9 100644
--- a/cuda/lltm_cuda_kernel.cu
+++ b/cuda/lltm_cuda_kernel.cu
@@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
 
 template <typename scalar_t>
 __global__ void lltm_cuda_forward_kernel(
-    const scalar_t* __restrict__ gates,
-    const scalar_t* __restrict__ old_cell,
-    scalar_t* __restrict__ new_h,
-    scalar_t* __restrict__ new_cell,
-    scalar_t* __restrict__ input_gate,
-    scalar_t* __restrict__ output_gate,
-    scalar_t* __restrict__ candidate_cell,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    input_gate[index] = sigmoid(gates[gates_row + column]);
-    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
-    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
-    new_cell[index] =
-        old_cell[index] + candidate_cell[index] * input_gate[index];
-    new_h[index] = tanh(new_cell[index]) * output_gate[index];
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  //batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)){
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
   }
 }
 
 template <typename scalar_t>
 __global__ void lltm_cuda_backward_kernel(
-    scalar_t* __restrict__ d_old_cell,
-    scalar_t* __restrict__ d_gates,
-    const scalar_t* __restrict__ grad_h,
-    const scalar_t* __restrict__ grad_cell,
-    const scalar_t* __restrict__ new_cell,
-    const scalar_t* __restrict__ input_gate,
-    const scalar_t* __restrict__ output_gate,
-    const scalar_t* __restrict__ candidate_cell,
-    const scalar_t* __restrict__ gate_weights,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
-    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  //batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)){
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
     const auto d_new_cell =
-        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
 
-    d_old_cell[index] = d_new_cell;
-    const auto d_candidate_cell = input_gate[index] * d_new_cell;
-    const auto d_input_gate = candidate_cell[index] * d_new_cell;
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
 
-
-    const auto input_gate_index = gates_row + column;
-    const auto output_gate_index = gates_row + state_size + column;
-    const auto candidate_cell_index = gates_row + 2 * state_size + column;
-
-    d_gates[input_gate_index] =
-        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
-    d_gates[output_gate_index] =
-        d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
-    d_gates[candidate_cell_index] =
-        d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
   }
 }
 
 } // namespace
@@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
     torch::Tensor old_h,
     torch::Tensor old_cell) {
   auto X = torch::cat({old_h, input}, /*dim=*/1);
-  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
 
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);
 
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
   auto new_h = torch::zeros_like(old_cell);
   auto new_cell = torch::zeros_like(old_cell);
   auto input_gate = torch::zeros_like(old_cell);
@@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
-        new_h.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        state_size);
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
 
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
@@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
     torch::Tensor output_gate,
    torch::Tensor candidate_cell,
     torch::Tensor X,
-    torch::Tensor gate_weights,
+    torch::Tensor gates,
     torch::Tensor weights) {
   auto d_old_cell = torch::zeros_like(new_cell);
-  auto d_gates = torch::zeros_like(gate_weights);
+  auto d_gates = torch::zeros_like(gates);
 
   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
@@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(
 
   AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.data<scalar_t>(),
-        d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
-        state_size);
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
   }));
 
-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
 
-  auto d_X = d_gates.mm(weights);
+  auto d_X = d_gate_weights.mm(weights);
   auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
   auto d_input = d_X.slice(/*dim=*/1, state_size);
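
For context, both dispatch sites launch with a <<<blocks, threads>>> configuration that is defined just above the hunks shown here and is untouched by this diff. A minimal sketch of that setup, assuming the tutorial's usual one-thread-per-gate-column layout (the kernels read the batch index from blockIdx.y and the column index from blockIdx.x * blockDim.x + threadIdx.x, so the grid must be sized to match):

// Sketch of the launch configuration the kernels above assume (not part of
// this diff). One thread per (batch, column) pair: grid.y spans the batch,
// grid.x covers the state_size columns in blocks of up to 1024 threads.
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);

The boundary check if (c < gates.size(2)) in each kernel exists because the last block in grid.x may be only partially filled when state_size is not a multiple of the block width.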
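The backward epilogue relies on d_gates keeping the {batch_size, 3, state_size} shape that the forward pass introduced via reshape, with flatten(1, 2) acting as its exact inverse so the matrix-multiply gradients see the original 2-D gate matrix again. A small smoke test of that round trip, as a sketch assuming a plain libtorch build (the sizes B and S are illustrative, not taken from the diff):

#include <torch/torch.h>
#include <cassert>

int main() {
  const int64_t B = 4, S = 5;                       // illustrative sizes
  auto gate_weights = torch::randn({B, 3 * S});     // shape of the addmm output
  auto gates = gate_weights.reshape({B, 3, S});     // forward-pass view
  auto round_trip = gates.flatten(1, 2);            // backward-pass inverse
  assert(round_trip.sizes() == gate_weights.sizes());
  assert(torch::equal(round_trip, gate_weights));   // round trip is lossless
  return 0;
}

Because both tensors are contiguous, reshape and flatten here are pure metadata changes; no copy is made on either side of the round trip.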