Skip to content

Commit 0ce7c83

Browse files
irexyccx
andauthored
mmrotate sdk module (open-mmlab#450)
* support mmrotate * fix name * windows default link to cudart_static.lib, which is not compatible with static build && python_api * python api * fix ci * fix type & remove unused meta info * fix doxygen, add [out] to @param * fix mmrotate-c-api * refactor naming * refactor naming * fix lint * fix lint * move replace_RResize -> get_preprocess * Update cuda.cmake On windows, make static lib and python api build success. * fix ptr * Use unique ptr to prevent memory leaks * move unique_ptr * remove deleter Co-authored-by: chenxin2 <[email protected]> Co-authored-by: cx <[email protected]>
1 parent 1a8d7ac commit 0ce7c83

File tree

18 files changed

+631
-6
lines changed

18 files changed

+631
-6
lines changed

cmake/cuda.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.18.0")
66
cmake_policy(SET CMP0104 OLD)
77
endif ()
88

9+
if (MSVC)
10+
# use shared, on windows, python api can't build with static lib.
11+
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
12+
endif ()
13+
914
# nvcc compiler settings
1015
find_package(CUDA REQUIRED)
1116
#message(STATUS "CUDA VERSION: ${CUDA_VERSION_STRING}")
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
_base_ = ['./rotated-detection_static.py', '../_base_/backends/sdk.py']
2+
3+
codebase_config = dict(model_type='sdk')
4+
5+
backend_config = dict(pipeline=[
6+
dict(type='LoadImageFromFile'),
7+
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
8+
])

csrc/apis/c/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ project(capis)
55
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
66

77
if ("all" IN_LIST MMDEPLOY_CODEBASES)
8-
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;pose_detector;restorer;model")
8+
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;"
9+
"pose_detector;restorer;model;rotated_detector")
910
else ()
1011
set(TASK_LIST "model")
1112
if ("mmcls" IN_LIST MMDEPLOY_CODEBASES)
@@ -27,6 +28,9 @@ else ()
2728
if ("mmpose" IN_LIST MMDEPLOY_CODEBASES)
2829
list(APPEND TASK_LIST "pose_detector")
2930
endif ()
31+
if ("mmrotate" IN_LIST MMDEPLOY_CODEBASES)
32+
list(APPEND TASK_LIST "rotated_detector")
33+
endif()
3034
endif ()
3135

3236
foreach (TASK ${TASK_LIST})

csrc/apis/c/detector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const
5757
* @param[in] mat_count number of images in the batch
5858
* @param[out] results a linear buffer to save detection results of each image. It must be released
5959
* by \ref mmdeploy_detector_release_result
60-
* @param result_count a linear buffer with length being \p mat_count to save the number of
60+
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of
6161
* detection results of each image. And it must be released by \ref
6262
* mmdeploy_detector_release_result
6363
* @return status of inference

csrc/apis/c/rotated_detector.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include "rotated_detector.h"
4+
5+
#include <numeric>
6+
7+
#include "codebase/mmrotate/mmrotate.h"
8+
#include "core/device.h"
9+
#include "core/graph.h"
10+
#include "core/mat.h"
11+
#include "core/utils/formatter.h"
12+
#include "handle.h"
13+
14+
using namespace std;
15+
using namespace mmdeploy;
16+
17+
namespace {
18+
19+
Value& config_template() {
20+
// clang-format off
21+
static Value v{
22+
{
23+
"pipeline", {
24+
{"input", {"image"}},
25+
{"output", {"det"}},
26+
{
27+
"tasks",{
28+
{
29+
{"name", "mmrotate"},
30+
{"type", "Inference"},
31+
{"params", {{"model", "TBD"}}},
32+
{"input", {"image"}},
33+
{"output", {"det"}}
34+
}
35+
}
36+
}
37+
}
38+
}
39+
};
40+
// clang-format on
41+
return v;
42+
}
43+
44+
template <class ModelType>
45+
int mmdeploy_rotated_detector_create_impl(ModelType&& m, const char* device_name, int device_id,
46+
mm_handle_t* handle) {
47+
try {
48+
auto value = config_template();
49+
value["pipeline"]["tasks"][0]["params"]["model"] = std::forward<ModelType>(m);
50+
51+
auto pose_estimator = std::make_unique<Handle>(device_name, device_id, std::move(value));
52+
53+
*handle = pose_estimator.release();
54+
return MM_SUCCESS;
55+
56+
} catch (const std::exception& e) {
57+
MMDEPLOY_ERROR("exception caught: {}", e.what());
58+
} catch (...) {
59+
MMDEPLOY_ERROR("unknown exception caught");
60+
}
61+
return MM_E_FAIL;
62+
}
63+
64+
} // namespace
65+
66+
int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name, int device_id,
67+
mm_handle_t* handle) {
68+
return mmdeploy_rotated_detector_create_impl(*static_cast<Model*>(model), device_name, device_id,
69+
handle);
70+
}
71+
72+
int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name,
73+
int device_id, mm_handle_t* handle) {
74+
return mmdeploy_rotated_detector_create_impl(model_path, device_name, device_id, handle);
75+
}
76+
77+
int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats, int mat_count,
78+
mm_rotated_detect_t** results, int** result_count) {
79+
if (handle == nullptr || mats == nullptr || mat_count == 0 || results == nullptr ||
80+
result_count == nullptr) {
81+
return MM_E_INVALID_ARG;
82+
}
83+
84+
try {
85+
auto detector = static_cast<Handle*>(handle);
86+
87+
Value input{Value::kArray};
88+
for (int i = 0; i < mat_count; ++i) {
89+
mmdeploy::Mat _mat{mats[i].height, mats[i].width, PixelFormat(mats[i].format),
90+
DataType(mats[i].type), mats[i].data, Device{"cpu"}};
91+
input.front().push_back({{"ori_img", _mat}});
92+
}
93+
94+
auto output = detector->Run(std::move(input)).value().front();
95+
auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(output);
96+
97+
vector<int> _result_count;
98+
_result_count.reserve(mat_count);
99+
for (const auto& det_output : detector_outputs) {
100+
_result_count.push_back((int)det_output.detections.size());
101+
}
102+
103+
auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
104+
105+
std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
106+
std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
107+
108+
std::unique_ptr<mm_rotated_detect_t[]> result_data(new mm_rotated_detect_t[total]{});
109+
auto result_ptr = result_data.get();
110+
111+
for (const auto& det_output : detector_outputs) {
112+
for (const auto& detection : det_output.detections) {
113+
result_ptr->label_id = detection.label_id;
114+
result_ptr->score = detection.score;
115+
const auto& rbbox = detection.rbbox;
116+
for (int i = 0; i < 5; i++) {
117+
result_ptr->rbbox[i] = rbbox[i];
118+
}
119+
++result_ptr;
120+
}
121+
}
122+
123+
*result_count = result_count_data.release();
124+
*results = result_data.release();
125+
126+
return MM_SUCCESS;
127+
128+
} catch (const std::exception& e) {
129+
MMDEPLOY_ERROR("exception caught: {}", e.what());
130+
} catch (...) {
131+
MMDEPLOY_ERROR("unknown exception caught");
132+
}
133+
return MM_E_FAIL;
134+
}
135+
136+
void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
137+
const int* result_count) {
138+
delete[] results;
139+
delete[] result_count;
140+
}
141+
142+
void mmdeploy_rotated_detector_destroy(mm_handle_t handle) { delete static_cast<Handle*>(handle); }

csrc/apis/c/rotated_detector.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
/**
4+
* @file rotated_detector.h
5+
* @brief Interface to MMRotate task
6+
*/
7+
8+
#ifndef MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
9+
#define MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
10+
11+
#include "common.h"
12+
13+
#ifdef __cplusplus
14+
extern "C" {
15+
#endif
16+
17+
typedef struct mm_rotated_detect_t {
18+
int label_id;
19+
float score;
20+
float rbbox[5]; // cx, cy, w, h, angle
21+
} mm_rotated_detect_t;
22+
23+
/**
24+
* @brief Create rotated detector's handle
25+
* @param[in] model an instance of mmrotate sdk model created by
26+
* \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
27+
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
28+
* @param[in] device_id id of device.
29+
* @param[out] handle instance of a rotated detector
30+
* @return status of creating rotated detector's handle
31+
*/
32+
MMDEPLOY_API int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name,
33+
int device_id, mm_handle_t* handle);
34+
35+
/**
36+
* @brief Create rotated detector's handle
37+
* @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter
38+
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
39+
* @param[in] device_id id of device.
40+
* @param[out] handle instance of a rotated detector
41+
* @return status of creating rotated detector's handle
42+
*/
43+
MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path,
44+
const char* device_name, int device_id,
45+
mm_handle_t* handle);
46+
47+
/**
48+
* @brief Apply rotated detector to batch images and get their inference results
49+
* @param[in] handle rotated detector's handle created by \ref
50+
* mmdeploy_rotated_detector_create_by_path
51+
* @param[in] mats a batch of images
52+
* @param[in] mat_count number of images in the batch
53+
* @param[out] results a linear buffer to save detection results of each image. It must be released
54+
* by \ref mmdeploy_rotated_detector_release_result
55+
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of
56+
* detection results of each image. And it must be released by \ref
57+
* mmdeploy_rotated_detector_release_result
58+
* @return status of inference
59+
*/
60+
MMDEPLOY_API int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats,
61+
int mat_count, mm_rotated_detect_t** results,
62+
int** result_count);
63+
64+
/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply
65+
* @param[in] results rotated detection results buffer
66+
* @param[in] result_count \p results size buffer
67+
*/
68+
MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
69+
const int* result_count);
70+
71+
/**
72+
* @brief Destroy rotated detector's handle
73+
* @param[in] handle rotated detector's handle created by \ref
74+
* mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create
75+
*/
76+
MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mm_handle_t handle);
77+
78+
#ifdef __cplusplus
79+
}
80+
#endif
81+
82+
#endif // MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_

csrc/apis/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ mmdeploy_python_add_module(text_detector)
2525
mmdeploy_python_add_module(text_recognizer)
2626
mmdeploy_python_add_module(restorer)
2727
mmdeploy_python_add_module(pose_detector)
28+
mmdeploy_python_add_module(rotated_detector)
2829

2930
pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_PYTHON_SRCS})
3031

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include "rotated_detector.h"
4+
5+
#include "common.h"
6+
#include "core/logger.h"
7+
8+
namespace mmdeploy {
9+
10+
class PyRotatedDetector {
11+
public:
12+
PyRotatedDetector(const char *model_path, const char *device_name, int device_id) {
13+
MMDEPLOY_INFO("{}, {}, {}", model_path, device_name, device_id);
14+
auto status =
15+
mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &handle_);
16+
if (status != MM_SUCCESS) {
17+
throw std::runtime_error("failed to create rotated detector");
18+
}
19+
}
20+
py::list Apply(const std::vector<PyImage> &imgs) {
21+
std::vector<mm_mat_t> mats;
22+
mats.reserve(imgs.size());
23+
for (const auto &img : imgs) {
24+
auto mat = GetMat(img);
25+
mats.push_back(mat);
26+
}
27+
28+
mm_rotated_detect_t *rbboxes{};
29+
int *res_count{};
30+
auto status = mmdeploy_rotated_detector_apply(handle_, mats.data(), (int)mats.size(), &rbboxes,
31+
&res_count);
32+
if (status != MM_SUCCESS) {
33+
throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status));
34+
}
35+
auto output = py::list{};
36+
auto result = rbboxes;
37+
auto counts = res_count;
38+
for (int i = 0; i < mats.size(); i++) {
39+
auto _dets = py::array_t<float>({*counts, 6});
40+
auto _labels = py::array_t<int>({*counts});
41+
auto dets = _dets.mutable_data();
42+
auto labels = _labels.mutable_data();
43+
for (int j = 0; j < *counts; j++) {
44+
for (int k = 0; k < 5; k++) {
45+
*dets++ = result->rbbox[k];
46+
}
47+
*dets++ = result->score;
48+
*labels++ = result->label_id;
49+
result++;
50+
}
51+
counts++;
52+
output.append(py::make_tuple(std::move(_dets), std::move(_labels)));
53+
}
54+
mmdeploy_rotated_detector_release_result(rbboxes, res_count);
55+
return output;
56+
}
57+
~PyRotatedDetector() {
58+
mmdeploy_rotated_detector_destroy(handle_);
59+
handle_ = {};
60+
}
61+
62+
private:
63+
mm_handle_t handle_{};
64+
};
65+
66+
static void register_python_rotated_detector(py::module &m) {
67+
py::class_<PyRotatedDetector>(m, "RotatedDetector")
68+
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
69+
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
70+
}))
71+
.def("__call__", &PyRotatedDetector::Apply);
72+
}
73+
74+
class PythonRotatedDetectorRegisterer {
75+
public:
76+
PythonRotatedDetectorRegisterer() {
77+
gPythonBindings().emplace("rotated_detector", register_python_rotated_detector);
78+
}
79+
};
80+
81+
static PythonRotatedDetectorRegisterer python_rotated_detector_registerer;
82+
83+
} // namespace mmdeploy

csrc/codebase/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ if ("all" IN_LIST MMDEPLOY_CODEBASES)
1010
list(APPEND CODEBASES "mmocr")
1111
list(APPEND CODEBASES "mmedit")
1212
list(APPEND CODEBASES "mmpose")
13+
list(APPEND CODEBASES "mmrotate")
1314
else ()
1415
set(CODEBASES ${MMDEPLOY_CODEBASES})
1516
endif ()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
cmake_minimum_required(VERSION 3.14)
3+
project(mmdeploy_mmrotate)
4+
5+
include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
6+
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
7+
8+
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
9+
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
10+
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
11+
add_library(mmdeploy::mmrotate ALIAS ${PROJECT_NAME})

0 commit comments

Comments
 (0)