Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions aten/src/ATen/native/cuda/SoftMax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/ceil_div.h>
#include <c10/macros/Macros.h>

#include <ATen/AccumulateType.h>
Expand Down Expand Up @@ -486,6 +487,10 @@ ilpReduce(index_t shift,
}

offset = size - last + threadIdx.x;
if (offset < 0) {
// Ensure offset >= 0
offset += round_up<long>(-offset, blockDim.x);
}
// Epilogue
for (; offset < size; offset += blockDim.x)
threadVal = r(threadVal, data[offset]);
Expand Down Expand Up @@ -543,6 +548,10 @@ WriteFpropResultsVectorized(
}

offset = size - last + threadIdx.x;
if (offset < 0) {
// Ensure offset >= 0
offset += round_up<long>(-offset, blockDim.x);
}
// handle the tail
for (; offset < size; offset += blockDim.x) {
output[offset] = epilogue(input[offset]);
Expand Down Expand Up @@ -603,6 +612,10 @@ WriteBpropResultsVectorized(
}

offset = size - last + threadIdx.x;
if (offset < 0) {
// Ensure offset >= 0
offset += round_up<long>(-offset, blockDim.x);
}
for (; offset < size; offset += blockDim.x) {
gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
}
Expand Down