Skip to content

Commit 79f42eb

Browse files
author
Cody Griffin
authored
Merge pull request #1 from EricCousineau-TRI/issue/1328
numpy: Provide concrete size aliases, test equivalence for `dtype(...).num
2 parents f7bc18f + cd496ce commit 79f42eb

File tree

3 files changed

+128
-3
lines changed

3 files changed

+128
-3
lines changed

include/pybind11/numpy.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <numeric>
1515
#include <algorithm>
1616
#include <array>
17+
#include <cstdint>
1718
#include <cstdlib>
1819
#include <cstring>
1920
#include <sstream>
@@ -108,6 +109,18 @@ inline numpy_internals& get_numpy_internals() {
108109
return *ptr;
109110
}
110111

112+
template <typename T> struct same_size {
113+
template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
114+
};
115+
116+
// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
117+
template <typename Concrete, typename... Check>
118+
constexpr int platform_lookup(const std::array<int, sizeof...(Check)> codes) {
119+
using code_index = std::integral_constant<int, constexpr_first<same_size<Concrete>::template as, Check...>()>;
120+
static_assert(code_index::value != sizeof...(Check), "Unable to match type on this platform");
121+
return codes[code_index::value];
122+
}
123+
111124
struct npy_api {
112125
enum constants {
113126
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
@@ -126,7 +139,23 @@ struct npy_api {
126139
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
127140
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
128141
NPY_OBJECT_ = 17,
129-
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
142+
NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
143+
// Platform-dependent normalization
144+
NPY_INT8_ = NPY_BYTE_,
145+
NPY_UINT8_ = NPY_UBYTE_,
146+
NPY_INT16_ = NPY_SHORT_,
147+
NPY_UINT16_ = NPY_USHORT_,
148+
// `npy_common.h` defines the integer aliases. In order, it checks:
149+
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
150+
// and assigns the alias to the first matching size, so we should check in this order.
151+
NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>({{
152+
NPY_LONG_, NPY_INT_, NPY_SHORT_}}),
153+
NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>({{
154+
NPY_ULONG_, NPY_UINT_, NPY_USHORT_}}),
155+
NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>({{
156+
NPY_LONG_, NPY_LONGLONG_, NPY_INT_}}),
157+
NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>({{
158+
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_}}),
130159
};
131160

132161
typedef struct {
@@ -1004,8 +1033,8 @@ struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmet
10041033
// NB: the order here must match the one in common.h
10051034
constexpr static const int values[15] = {
10061035
npy_api::NPY_BOOL_,
1007-
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
1008-
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
1036+
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_INT16_, npy_api::NPY_UINT16_,
1037+
npy_api::NPY_INT32_, npy_api::NPY_UINT32_, npy_api::NPY_INT64_, npy_api::NPY_UINT64_,
10091038
npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
10101039
npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
10111040
};

tests/test_numpy_array.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,67 @@
1414

1515
#include <cstdint>
1616

17+
// Size / dtype checks.
18+
struct DtypeCheck {
19+
py::dtype numpy{};
20+
py::dtype pybind11{};
21+
};
22+
23+
template <typename T>
24+
DtypeCheck get_dtype_check(const char* name) {
25+
py::module np = py::module::import("numpy");
26+
DtypeCheck check{};
27+
check.numpy = np.attr("dtype")(np.attr(name));
28+
check.pybind11 = py::dtype::of<T>();
29+
return check;
30+
}
31+
32+
std::vector<DtypeCheck> get_concrete_dtype_checks() {
33+
return {
34+
// Normalization
35+
get_dtype_check<std::int8_t>("int8"),
36+
get_dtype_check<std::uint8_t>("uint8"),
37+
get_dtype_check<std::int16_t>("int16"),
38+
get_dtype_check<std::uint16_t>("uint16"),
39+
get_dtype_check<std::int32_t>("int32"),
40+
get_dtype_check<std::uint32_t>("uint32"),
41+
get_dtype_check<std::int64_t>("int64"),
42+
get_dtype_check<std::uint64_t>("uint64")
43+
};
44+
}
45+
46+
struct DtypeSizeCheck {
47+
std::string name{};
48+
int size_cpp{};
49+
int size_numpy{};
50+
// For debugging.
51+
py::dtype dtype{};
52+
};
53+
54+
template <typename T>
55+
DtypeSizeCheck get_dtype_size_check() {
56+
DtypeSizeCheck check{};
57+
check.name = py::type_id<T>();
58+
check.size_cpp = sizeof(T);
59+
check.dtype = py::dtype::of<T>();
60+
check.size_numpy = check.dtype.attr("itemsize").template cast<int>();
61+
return check;
62+
}
63+
64+
std::vector<DtypeSizeCheck> get_platform_dtype_size_checks() {
65+
return {
66+
get_dtype_size_check<short>(),
67+
get_dtype_size_check<unsigned short>(),
68+
get_dtype_size_check<int>(),
69+
get_dtype_size_check<unsigned int>(),
70+
get_dtype_size_check<long>(),
71+
get_dtype_size_check<unsigned long>(),
72+
get_dtype_size_check<long long>(),
73+
get_dtype_size_check<unsigned long long>(),
74+
};
75+
}
76+
77+
// Arrays.
1778
using arr = py::array;
1879
using arr_t = py::array_t<uint16_t, 0>;
1980
static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
@@ -72,6 +133,26 @@ TEST_SUBMODULE(numpy_array, sm) {
72133
try { py::module::import("numpy"); }
73134
catch (...) { return; }
74135

136+
// test_dtypes
137+
py::class_<DtypeCheck>(sm, "DtypeCheck")
138+
.def_readonly("numpy", &DtypeCheck::numpy)
139+
.def_readonly("pybind11", &DtypeCheck::pybind11)
140+
.def("__repr__", [](const DtypeCheck& self) {
141+
return py::str("<DtypeCheck numpy={} pybind11={}>").format(
142+
self.numpy, self.pybind11);
143+
});
144+
sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks);
145+
146+
py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck")
147+
.def_readonly("name", &DtypeSizeCheck::name)
148+
.def_readonly("size_cpp", &DtypeSizeCheck::size_cpp)
149+
.def_readonly("size_numpy", &DtypeSizeCheck::size_numpy)
150+
.def("__repr__", [](const DtypeSizeCheck& self) {
151+
return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>").format(
152+
self.name, self.size_cpp, self.size_numpy, self.dtype);
153+
});
154+
sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks);
155+
75156
// test_array_attributes
76157
sm.def("ndim", [](const arr& a) { return a.ndim(); });
77158
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });

tests/test_numpy_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77
import numpy as np
88

99

10+
def test_dtypes():
11+
# See issue #1328.
12+
# - Platform-dependent sizes.
13+
for size_check in m.get_platform_dtype_size_checks():
14+
print(size_check)
15+
assert size_check.size_cpp == size_check.size_numpy, size_check
16+
# - Concrete sizes.
17+
for check in m.get_concrete_dtype_checks():
18+
print(check)
19+
assert check.numpy == check.pybind11, check
20+
if check.numpy.num != check.pybind11.num:
21+
print("NOTE: typenum mismatch for {}: {} != {}".format(
22+
check, check.numpy.num, check.pybind11.num))
23+
24+
1025
@pytest.fixture(scope='function')
1126
def arr():
1227
return np.array([[1, 2, 3], [4, 5, 6]], '=u2')

0 commit comments

Comments
 (0)