diff --git a/CMakeLists.txt b/CMakeLists.txt index 8063303938..3edd13cbbe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,6 +188,7 @@ set(PYBIND11_HEADERS include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h include/pybind11/detail/exception_translation.h include/pybind11/detail/function_record_pyobject.h + include/pybind11/detail/function_ref.h include/pybind11/detail/holder_caster_foreign_helpers.h include/pybind11/detail/init.h include/pybind11/detail/internals.h diff --git a/include/pybind11/detail/common.h b/include/pybind11/detail/common.h index 16952c5829..75a7125314 100644 --- a/include/pybind11/detail/common.h +++ b/include/pybind11/detail/common.h @@ -163,6 +163,14 @@ # define PYBIND11_NOINLINE __attribute__((noinline)) inline #endif +#if defined(_MSC_VER) +# define PYBIND11_ALWAYS_INLINE __forceinline +#elif defined(__GNUC__) +# define PYBIND11_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +# define PYBIND11_ALWAYS_INLINE inline +#endif + #if defined(__MINGW32__) // For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared // whether it is used or not diff --git a/include/pybind11/detail/function_ref.h b/include/pybind11/detail/function_ref.h new file mode 100644 index 0000000000..f99724cbfc --- /dev/null +++ b/include/pybind11/detail/function_ref.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===- llvm/ADT/STLFunctionalExtras.h - Extras for -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some extension to . +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// +// Extra additions to +//===----------------------------------------------------------------------===// + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a FunctionRef. + +// pybind11: modified again from executorch::runtime::FunctionRef +// - renamed back to function_ref +// - use pybind11 enable_if_t, remove_cvref_t, and remove_reference_t +// - lint suppressions + +// torch::executor: modified from llvm::function_ref +// - renamed to FunctionRef +// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses +// - use namespaced internal::remove_cvref_t + +#pragma once + +#include + +#include +#include +#include + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +//===----------------------------------------------------------------------===// +// Features from C++20 +//===----------------------------------------------------------------------===// + +template +class function_ref; + +template +class function_ref { + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable; + + template + // NOLINTNEXTLINE(performance-unnecessary-value-param) + static Ret callback_fn(intptr_t callable, Params... params) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + return (*reinterpret_cast(callable))(std::forward(params)...); + } + +public: + function_ref() = default; + // NOLINTNEXTLINE(google-explicit-constructor) + function_ref(std::nullptr_t) {} + + template + // NOLINTNEXTLINE(google-explicit-constructor) + function_ref( + Callable &&callable, + // This is not the copy-constructor. + enable_if_t, function_ref>::value> * = nullptr, + // Functor must be callable and return a suitable type. + enable_if_t< + std::is_void::value + || std::is_convertible()(std::declval()...)), + Ret>::value> * = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + // NOLINTNEXTLINE(performance-unnecessary-value-param) + Ret operator()(Params... params) const { + return callback(callable, std::forward(params)...); + } + + explicit operator bool() const { return callback; } + + bool operator==(const function_ref &Other) const { + return callable == Other.callable; + } +}; +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 8ab4681c76..7214d482d5 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -13,6 +13,7 @@ #include "detail/dynamic_raw_ptr_cast_if_possible.h" #include "detail/exception_translation.h" #include "detail/function_record_pyobject.h" +#include "detail/function_ref.h" #include "detail/init.h" #include "detail/native_enum_data.h" #include "detail/using_smart_holder.h" @@ -379,6 +380,40 @@ class cpp_function : public function { return unique_function_record(new detail::function_record()); } +private: + // This is outlined from the dispatch lambda in initialize to save + // on code size. Crucially, we use function_ref to type-erase the + // actual function lambda so that we can get code reuse for + // functions with the same Return, Args, and Guard. + template + static handle call_impl(detail::function_call &call, detail::function_ref f) { + using namespace detail; + using cast_out + = make_caster::value, void_type, Return>>; + + ArgsConverter args_converter; + if (!args_converter.load_args(call)) { + return PYBIND11_TRY_NEXT_OVERLOAD; + } + + /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */ + return_value_policy policy + = return_value_policy_override::policy(call.func.policy); + + /* Perform the function call */ + handle result; + if (call.func.is_setter) { + (void) std::move(args_converter).template call(f); + result = none().release(); + } else { + result = cast_out::cast( + std::move(args_converter).template call(f), policy, call.parent); + } + + return result; + } + +protected: /// Special internal constructor for functors, lambda functions, etc. template void initialize(Func &&f, Return (*)(Args...), const Extra &...extra) { @@ -441,13 +476,6 @@ class cpp_function : public function { /* Dispatch code which converts function arguments and performs the actual function call */ rec->impl = [](function_call &call) -> handle { - cast_in args_converter; - - /* Try to cast the function arguments into the C++ domain */ - if (!args_converter.load_args(call)) { - return PYBIND11_TRY_NEXT_OVERLOAD; - } - /* Invoke call policy pre-call hook */ process_attributes::precall(call); @@ -456,24 +484,11 @@ class cpp_function : public function { : call.func.data[0]); auto *cap = const_cast(reinterpret_cast(data)); - /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */ - return_value_policy policy - = return_value_policy_override::policy(call.func.policy); - - /* Function scope guard -- defaults to the compile-to-nothing `void_type` */ - using Guard = extract_guard_t; - - /* Perform the function call */ - handle result; - if (call.func.is_setter) { - (void) std::move(args_converter).template call(cap->f); - result = none().release(); - } else { - result = cast_out::cast( - std::move(args_converter).template call(cap->f), - policy, - call.parent); - } + auto result = call_impl, + cast_in>(call, detail::function_ref(cap->f)); /* Invoke call policy post-call hook */ process_attributes::postcall(call, result); @@ -2218,7 +2233,7 @@ class class_ : public detail::generic_type { static void add_base(detail::type_record &) {} template - class_ &def(const char *name_, Func &&f, const Extra &...extra) { + PYBIND11_ALWAYS_INLINE class_ &def(const char *name_, Func &&f, const Extra &...extra) { cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), @@ -2797,38 +2812,13 @@ struct enum_base { pos_only()) if (is_convertible) { - PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b)); - PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b)); - if (is_arithmetic) { - PYBIND11_ENUM_OP_CONV("__lt__", a < b); - PYBIND11_ENUM_OP_CONV("__gt__", a > b); - PYBIND11_ENUM_OP_CONV("__le__", a <= b); - PYBIND11_ENUM_OP_CONV("__ge__", a >= b); - PYBIND11_ENUM_OP_CONV("__and__", a & b); - PYBIND11_ENUM_OP_CONV("__rand__", a & b); - PYBIND11_ENUM_OP_CONV("__or__", a | b); - PYBIND11_ENUM_OP_CONV("__ror__", a | b); - PYBIND11_ENUM_OP_CONV("__xor__", a ^ b); - PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); m_base.attr("__invert__") = cpp_function([](const object &arg) { return ~(int_(arg)); }, name("__invert__"), is_method(m_base), pos_only()); } - } else { - PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false); - PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true); - - if (is_arithmetic) { -#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!"); - PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW); -#undef PYBIND11_THROW - } } #undef PYBIND11_ENUM_OP_CONV_LHS @@ -2944,6 +2934,61 @@ class enum_ : public class_ { def(init([](Scalar i) { return static_cast(i); }), arg("value")); def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only()); +#define PYBIND11_ENUM_OP_SAME_TYPE(op, expr) \ + def(op, [](Type a, Type b) { return expr; }, pybind11::name(op), arg("other"), pos_only()) +#define PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE(op, expr) \ + def(op, [](Type a, Type *b_ptr) { return expr; }, pybind11::name(op), arg("other"), pos_only()) +#define PYBIND11_ENUM_OP_SCALAR(op, op_expr) \ + def( \ + op, \ + [](Type a, Scalar b) { return static_cast(a) op_expr b; }, \ + pybind11::name(op), \ + arg("other"), \ + pos_only()) +#define PYBIND11_ENUM_OP_CONV_ARITHMETIC(op, op_expr) \ + /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ + PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast(a) op_expr static_cast(b)); \ + PYBIND11_ENUM_OP_SCALAR(op, op_expr) +#define PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior) \ + def( \ + op, \ + [](Type, const object &) { strict_behavior; }, \ + pybind11::name(op), \ + arg("other"), \ + pos_only()) +#define PYBIND11_ENUM_OP_STRICT_ARITHMETIC(op, op_expr, strict_behavior) \ + /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ + PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast(a) op_expr static_cast(b)); \ + PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior); + + PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__eq__", b_ptr && a == *b_ptr); + PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__ne__", !b_ptr || a != *b_ptr); + if (std::is_convertible::value) { + PYBIND11_ENUM_OP_SCALAR("__eq__", ==); + PYBIND11_ENUM_OP_SCALAR("__ne__", !=); + if (is_arithmetic) { + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__lt__", <); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__gt__", >); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__le__", <=); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ge__", >=); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__and__", &); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rand__", &); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__or__", |); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ror__", |); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__xor__", ^); + PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rxor__", ^); + } + } else if (is_arithmetic) { +#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!"); + PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__lt__", <, PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__gt__", >, PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__le__", <=, PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__ge__", >=, PYBIND11_THROW); +#undef PYBIND11_THROW + } + PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__eq__", return false); + PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__ne__", return true); + def("__int__", [](Type value) { return (Scalar) value; }, pos_only()); def("__index__", [](Type value) { return (Scalar) value; }, pos_only()); attr("__setstate__") = cpp_function( diff --git a/tests/extra_python_package/test_files.py b/tests/extra_python_package/test_files.py index 1539b171a2..d96e9afc1f 100644 --- a/tests/extra_python_package/test_files.py +++ b/tests/extra_python_package/test_files.py @@ -83,6 +83,7 @@ "include/pybind11/detail/descr.h", "include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h", "include/pybind11/detail/function_record_pyobject.h", + "include/pybind11/detail/function_ref.h", "include/pybind11/detail/holder_caster_foreign_helpers.h", "include/pybind11/detail/init.h", "include/pybind11/detail/internals.h", diff --git a/tests/test_copy_move.py b/tests/test_copy_move.py index 3a3f293414..d843793350 100644 --- a/tests/test_copy_move.py +++ b/tests/test_copy_move.py @@ -70,12 +70,12 @@ def test_move_and_copy_loads(): assert c_m.copy_assignments + c_m.copy_constructions == 0 assert c_m.move_assignments == 6 - assert c_m.move_constructions == 9 + assert c_m.move_constructions == 21 assert c_mc.copy_assignments + c_mc.copy_constructions == 0 assert c_mc.move_assignments == 5 - assert c_mc.move_constructions == 8 + assert c_mc.move_constructions == 18 assert c_c.copy_assignments == 4 - assert c_c.copy_constructions == 6 + assert c_c.copy_constructions == 14 assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 @@ -103,12 +103,12 @@ def test_move_and_copy_load_optional(): assert c_m.copy_assignments + c_m.copy_constructions == 0 assert c_m.move_assignments == 2 - assert c_m.move_constructions == 5 + assert c_m.move_constructions == 9 assert c_mc.copy_assignments + c_mc.copy_constructions == 0 assert c_mc.move_assignments == 2 - assert c_mc.move_constructions == 5 + assert c_mc.move_constructions == 9 assert c_c.copy_assignments == 2 - assert c_c.copy_constructions == 5 + assert c_c.copy_constructions == 9 assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 diff --git a/tests/test_enum.py b/tests/test_enum.py index 99d4a88c8a..160708ef53 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -296,9 +296,19 @@ def test_generated_dunder_methods_pos_only(): ]: method = getattr(enum_type, binary_op, None) if method is not None: + # 1) The docs must start with the name of the op. assert ( re.match( - rf"^{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)", + rf"^{binary_op}\(", + method.__doc__, + ) + is not None + ) + # 2) The docs must contain the op's signature. This is a separate check + # and not anchored at the start because the op may be overloaded. + assert ( + re.search( + rf"{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)", method.__doc__, ) is not None