Skip to content

Commit ac3ba94

Browse files
authored
Enable autocast for NMS and ROIAlign on ROCm (#2637)
* add autocasting on ROCm * enable ROIAlign autocasting on ROCm * enable NMS autocasting on ROCm * fix to use correct torch CUDA APIs
1 parent 02f46a5 commit ac3ba94

File tree

5 files changed

+7
-12
lines changed

5 files changed

+7
-12
lines changed

torchvision/csrc/ROIAlign.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cuda/vision_cuda.h"
88
#endif
99
#ifdef WITH_HIP
10+
#include "autocast.h"
1011
#include "hip/vision_cuda.h"
1112
#endif
1213

@@ -37,7 +38,7 @@ at::Tensor roi_align(
3738
aligned);
3839
}
3940

40-
#ifdef WITH_CUDA
41+
#if defined(WITH_CUDA) || defined(WITH_HIP)
4142
at::Tensor ROIAlign_autocast(
4243
const at::Tensor& input,
4344
const at::Tensor& rois,

torchvision/csrc/autocast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#ifdef WITH_CUDA
3+
#if defined(WITH_CUDA) || defined(WITH_HIP)
44
namespace autocast {
55

66
inline bool is_eligible(const at::Tensor& arg) {

torchvision/csrc/cuda/nms_cuda.cu

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
#include <ATen/ATen.h>
22
#include <ATen/cuda/CUDAContext.h>
3-
4-
#if defined(WITH_CUDA)
53
#include <c10/cuda/CUDAGuard.h>
6-
#elif defined(WITH_HIP)
7-
#include <c10/hip/HIPGuard.h>
8-
#endif
94

105
#include "cuda_helpers.h"
116

@@ -98,10 +93,8 @@ at::Tensor nms_cuda(const at::Tensor& dets,
9893
" and ",
9994
scores.size(0))
10095

101-
#if defined(WITH_CUDA)
96+
#if defined(WITH_CUDA) || defined(WITH_HIP)
10297
at::cuda::CUDAGuard device_guard(dets.device());
103-
#elif defined(WITH_HIP)
104-
at::cuda::HIPGuard device_guard(dets.device());
10598
#else
10699
AT_ERROR("Not compiled with GPU support");
107100
#endif

torchvision/csrc/nms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "cuda/vision_cuda.h"
77
#endif
88
#ifdef WITH_HIP
9+
#include "autocast.h"
910
#include "hip/vision_cuda.h"
1011
#endif
1112

@@ -20,7 +21,7 @@ at::Tensor nms(
2021
return op.call(dets, scores, iou_threshold);
2122
}
2223

23-
#ifdef WITH_CUDA
24+
#if defined(WITH_CUDA) || defined(WITH_HIP)
2425
at::Tensor nms_autocast(
2526
const at::Tensor& dets,
2627
const at::Tensor& scores,

torchvision/csrc/vision.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
7272
#endif
7373

7474
// Autocast only needs to wrap forward pass ops.
75-
#if defined(WITH_CUDA)
75+
#if defined(WITH_CUDA) || defined(WITH_HIP)
7676
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
7777
m.impl("roi_align", ROIAlign_autocast);
7878
m.impl("nms", nms_autocast);

0 commit comments

Comments
 (0)