Skip to content

Commit 8b500da

Browse files
committed
Add .get_capsule_for_scipy_LowLevelCallable() method in function_record_pyobject.h
1 parent d97186f commit 8b500da

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

include/pybind11/detail/function_record_pyobject.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ void tp_dealloc_impl(PyObject *self);
3131
void tp_free_impl(void *self);
3232

3333
static PyObject *reduce_ex_impl(PyObject *self, PyObject *, PyObject *);
34+
static PyObject *
35+
get_capsule_for_scipy_LowLevelCallable_impl(PyObject *self, PyObject *, PyObject *);
3436

3537
PYBIND11_WARNING_PUSH
3638
#if defined(__GNUC__) && __GNUC__ >= 8
@@ -41,6 +43,10 @@ PYBIND11_WARNING_DISABLE_CLANG("-Wcast-function-type-mismatch")
4143
#endif
4244
static PyMethodDef tp_methods_impl[]
4345
= {{"__reduce_ex__", (PyCFunction) reduce_ex_impl, METH_VARARGS | METH_KEYWORDS, nullptr},
46+
{"get_capsule_for_scipy_LowLevelCallable",
47+
(PyCFunction) get_capsule_for_scipy_LowLevelCallable_impl,
48+
METH_VARARGS | METH_KEYWORDS,
49+
"for use with scipy.LowLevelCallable()"},
4450
{nullptr, nullptr, 0, nullptr}};
4551
PYBIND11_WARNING_POP
4652

@@ -202,6 +208,29 @@ inline PyObject *reduce_ex_impl(PyObject *self, PyObject *, PyObject *) {
202208
return nullptr;
203209
}
204210

211+
inline PyObject *
212+
get_capsule_for_scipy_LowLevelCallable_impl(PyObject *self, PyObject *args, PyObject *kwargs) {
213+
static const char *kwlist[] = {"signature", nullptr};
214+
const char *signature = nullptr;
215+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s", const_cast<char **>(kwlist), &signature)) {
216+
return nullptr;
217+
}
218+
function_record *rec = function_record_ptr_from_PyObject(self);
219+
if (rec == nullptr) {
220+
pybind11_fail("FATAL: get_capsule_for_scipy_LowLevelCallable_impl(): cannot obtain C++ "
221+
"function_record.");
222+
}
223+
if (!rec->is_stateless) {
224+
set_error(PyExc_TypeError, repr(self) + str(" is not a stateless function."));
225+
return nullptr;
226+
}
227+
struct capture {
228+
void *f; // DANGER: TYPE SAFETY IS LOST COMPLETELY.
229+
};
230+
auto cap = reinterpret_cast<capture *>(&rec->data);
231+
return capsule(cap->f, signature).release().ptr();
232+
}
233+
205234
PYBIND11_NAMESPACE_END(function_record_PyTypeObject_methods)
206235

207236
PYBIND11_NAMESPACE_END(detail)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include "pybind11_tests.h"
2+
3+
namespace pybind11_tests {
4+
namespace scipy_low_level_callable {
5+
6+
extern "C" double square(double x) { return x * x; }
7+
8+
} // namespace scipy_low_level_callable
9+
} // namespace pybind11_tests
10+
11+
TEST_SUBMODULE(scipy_low_level_callable, m) {
12+
using namespace pybind11_tests::scipy_low_level_callable;
13+
14+
m.def("square", square);
15+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pybind11_tests import scipy_low_level_callable as m
6+
7+
8+
def test_square():
9+
assert m.square(2.0) == 4.0
10+
11+
12+
def test_get_capsule_for_scipy_LowLevelCallable():
13+
cap = m.square.__self__.get_capsule_for_scipy_LowLevelCallable(
14+
signature="double (double)"
15+
)
16+
assert repr(cap).startswith('<capsule object "double (double)" at 0x')
17+
18+
19+
def test_with_scipy_LowLevelCallable():
20+
scipy = pytest.importorskip("scipy")
21+
llc = scipy.LowLevelCallable(
22+
m.square.__self__.get_capsule_for_scipy_LowLevelCallable(
23+
signature="double (double)"
24+
)
25+
)
26+
integral = scipy.integrate.quad(llc, 0, 1)
27+
assert integral[0] == pytest.approx(1 / 3, rel=1e-12)

0 commit comments

Comments
 (0)