From 0c73438b00fcbc42a861704d18b2436cbe688880 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 24 Apr 2025 13:15:44 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- extension/pytree/function_ref.h | 117 ++++++-------------- extension/pytree/test/function_ref_test.cpp | 39 ++----- 2 files changed, 39 insertions(+), 117 deletions(-) diff --git a/extension/pytree/function_ref.h b/extension/pytree/function_ref.h index 0458610c4db..ca5e78effe5 100644 --- a/extension/pytree/function_ref.h +++ b/extension/pytree/function_ref.h @@ -30,7 +30,10 @@ /// a FunctionRef. // torch::executor: modified from llvm::function_ref -// see https://www.foonathan.net/2017/01/function-ref-implementation/ +// - renamed to FunctionRef +// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses +// - use namespaced internal::remove_cvref_t + #pragma once @@ -64,99 +67,43 @@ class FunctionRef; template class FunctionRef { - Ret (*callback_)(const void* memory, Params... params) = nullptr; - union Storage { - void* callable; - Ret (*function)(Params...); - } storage_; + Ret (*callback)(intptr_t callable, Params ...params) = nullptr; + intptr_t callable; - public: - FunctionRef() = default; - explicit FunctionRef(std::nullptr_t) {} - - /** - * Case 1: A callable object passed by lvalue reference. - * Taking rvalue reference is error prone because the object will be always - * be destroyed immediately. - */ - template < - typename Callable, - // This is not the copy-constructor. - typename std::enable_if< - !std::is_same, FunctionRef>::value, - int32_t>::type = 0, - // Avoid lvalue reference to non-capturing lambda. - typename std::enable_if< - !std::is_convertible::value, - int32_t>::type = 0, - // Functor must be callable and return a suitable type. - // To make this container type safe, we need to ensure either: - // 1. The return type is void. - // 2. Or the resulting type from calling the callable is convertible to - // the declared return type. - typename std::enable_if< - std::is_void::value || - std::is_convertible< - decltype(std::declval()(std::declval()...)), - Ret>::value, - int32_t>::type = 0> - explicit FunctionRef(Callable& callable) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - auto& callable = *static_cast(storage.callable); - return static_cast(callable(std::forward(params)...)); - }) { - storage_.callable = &callable; + template + static Ret callback_fn(intptr_t callable, Params ...params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); } - /** - * Case 2: A plain function pointer. - * Instead of storing an opaque pointer to underlying callable object, - * store a function pointer directly. - * Note that in the future a variant which coerces compatible function - * pointers could be implemented by erasing the storage type. - */ - /* implicit */ FunctionRef(Ret (*ptr)(Params...)) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - return storage.function(std::forward(params)...); - }) { - storage_.function = ptr; - } +public: + FunctionRef() = default; + FunctionRef(std::nullptr_t) {} - /** - * Case 3: Implicit conversion from lambda to FunctionRef. - * A common use pattern is like: - * void foo(FunctionRef<...>) {...} - * foo([](...){...}) - * Here constructors for non const lvalue reference or function pointer - * would not work because they do not cover implicit conversion from rvalue - * lambda. - * We need to define a constructor for capturing temporary callables and - * always try to convert the lambda to a function pointer behind the scene. - */ - template < - typename Function, + template + FunctionRef( + Callable &&callable, // This is not the copy-constructor. - typename std::enable_if< - !std::is_same::value, - int32_t>::type = 0, - // Function is convertible to pointer of (Params...) -> Ret. - typename std::enable_if< - std::is_convertible::value, - int32_t>::type = 0> - /* implicit */ FunctionRef(const Function& function) - : FunctionRef(static_cast(function)) {} - - Ret operator()(Params... params) const { - return callback_(&storage_, std::forward(params)...); + std::enable_if_t, + FunctionRef>::value> * = nullptr, + // Functor must be callable and return a suitable type. + std::enable_if_t::value || + std::is_convertible()( + std::declval()...)), + Ret>::value> * = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + Ret operator()(Params ...params) const { + return callback(callable, std::forward(params)...); } - explicit operator bool() const { - return callback_; + explicit operator bool() const { return callback; } + + bool operator==(const FunctionRef &Other) const { + return callable == Other.callable; } }; - } // namespace pytree } // namespace extension } // namespace executorch diff --git a/extension/pytree/test/function_ref_test.cpp b/extension/pytree/test/function_ref_test.cpp index a3cdbd824bf..b15377b946a 100644 --- a/extension/pytree/test/function_ref_test.cpp +++ b/extension/pytree/test/function_ref_test.cpp @@ -15,21 +15,6 @@ using namespace ::testing; using ::executorch::extension::pytree::FunctionRef; namespace { -class Item { - private: - int32_t val_; - FunctionRef ref_; - - public: - /* implicit */ Item(int32_t val, FunctionRef ref) - : val_(val), ref_(ref) {} - - int32_t get() { - ref_(val_); - return val_; - } -}; - void one(int32_t& i) { i = 1; } @@ -39,8 +24,9 @@ void one(int32_t& i) { TEST(FunctionRefTest, CapturingLambda) { auto one = 1; auto f = [&](int32_t& i) { i = one; }; - Item item(0, FunctionRef{f}); - EXPECT_EQ(item.get(), 1); + int32_t val = 0; + FunctionRef{f}(val); + EXPECT_EQ(val, 1); // ERROR: // Item item1(0, f); // Item item2(0, [&](int32_t& i) { i = 2; }); @@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) { FunctionRef ref1(lambda); ref1(val); EXPECT_EQ(val, 1); - - Item item(0, [](int32_t& i) { i = 1; }); - EXPECT_EQ(item.get(), 1); - - auto f = [](int32_t& i) { i = 1; }; - Item item1(0, f); - EXPECT_EQ(item1.get(), 1); - - Item item2(0, std::move(f)); - EXPECT_EQ(item2.get(), 1); } TEST(FunctionRefTest, FunctionPointer) { @@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) { ref(val); EXPECT_EQ(val, 1); - Item item(0, one); - EXPECT_EQ(item.get(), 1); - - Item item1(0, &one); - EXPECT_EQ(item1.get(), 1); + val = 0; + FunctionRef ref2(one); + ref(val); + EXPECT_EQ(val, 1); } From fb432fa87c05314ab12093d9939974b3f2884fcb Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 24 Apr 2025 13:15:48 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- extension/pytree/function_ref.h | 110 ++---------------- extension/pytree/test/CMakeLists.txt | 2 +- extension/pytree/test/TARGETS | 6 - runtime/core/function_ref.h | 105 +++++++++++++++++ runtime/core/targets.bzl | 1 + runtime/core/test/CMakeLists.txt | 9 +- .../core}/test/function_ref_test.cpp | 0 runtime/core/test/targets.bzl | 10 ++ test/utils/OSSTestConfig.json | 10 +- 9 files changed, 134 insertions(+), 119 deletions(-) create mode 100644 runtime/core/function_ref.h rename {extension/pytree => runtime/core}/test/function_ref_test.cpp (100%) diff --git a/extension/pytree/function_ref.h b/extension/pytree/function_ref.h index ca5e78effe5..236a8baa5a7 100644 --- a/extension/pytree/function_ref.h +++ b/extension/pytree/function_ref.h @@ -6,114 +6,18 @@ * LICENSE file in the root directory of this source tree. */ -//===- llvm/ADT/STLFunctionalExtras.h - Extras for -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains some extension to . -// -// No library is required when using these functions. -// -//===----------------------------------------------------------------------===// -// Extra additions to -//===----------------------------------------------------------------------===// - -/// An efficient, type-erasing, non-owning reference to a callable. This is -/// intended for use as the type of a function parameter that is not used -/// after the function in question returns. -/// -/// This class does not own the callable, so it is not in general safe to store -/// a FunctionRef. - -// torch::executor: modified from llvm::function_ref -// - renamed to FunctionRef -// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses -// - use namespaced internal::remove_cvref_t - - #pragma once -#include -#include -#include - -namespace executorch { -namespace extension { -namespace pytree { - -//===----------------------------------------------------------------------===// -// Features from C++20 -//===----------------------------------------------------------------------===// - -namespace internal { - -template -struct remove_cvref { - using type = - typename std::remove_cv::type>::type; -}; - -template -using remove_cvref_t = typename remove_cvref::type; - -} // namespace internal - -template -class FunctionRef; - -template -class FunctionRef { - Ret (*callback)(intptr_t callable, Params ...params) = nullptr; - intptr_t callable; - - template - static Ret callback_fn(intptr_t callable, Params ...params) { - return (*reinterpret_cast(callable))( - std::forward(params)...); - } - -public: - FunctionRef() = default; - FunctionRef(std::nullptr_t) {} - - template - FunctionRef( - Callable &&callable, - // This is not the copy-constructor. - std::enable_if_t, - FunctionRef>::value> * = nullptr, - // Functor must be callable and return a suitable type. - std::enable_if_t::value || - std::is_convertible()( - std::declval()...)), - Ret>::value> * = nullptr) - : callback(callback_fn>), - callable(reinterpret_cast(&callable)) {} - - Ret operator()(Params ...params) const { - return callback(callable, std::forward(params)...); - } +#include - explicit operator bool() const { return callback; } +/// This header is DEPRECATED; use executorch/runtime/core/function_ref.h directly instead. - bool operator==(const FunctionRef &Other) const { - return callable == Other.callable; - } -}; -} // namespace pytree -} // namespace extension -} // namespace executorch +namespace executorch::extension::pytree { +using executorch::runtime::FunctionRef; +} // namespace executorch::extension::pytree -namespace torch { -namespace executor { -namespace pytree { +namespace torch::executor::pytree { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. using ::executorch::extension::pytree::FunctionRef; -} // namespace pytree -} // namespace executor -} // namespace torch +} // namespace torch::executor::pytree diff --git a/extension/pytree/test/CMakeLists.txt b/extension/pytree/test/CMakeLists.txt index 5d99bad1339..ce9b2cec6ec 100644 --- a/extension/pytree/test/CMakeLists.txt +++ b/extension/pytree/test/CMakeLists.txt @@ -19,6 +19,6 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) -set(_test_srcs function_ref_test.cpp test_pytree.cpp) +set(_test_srcs test_pytree.cpp) et_cxx_test(extension_pytree_test SOURCES ${_test_srcs} EXTRA_LIBS) diff --git a/extension/pytree/test/TARGETS b/extension/pytree/test/TARGETS index 190bdb0bc67..e49e8cd2791 100644 --- a/extension/pytree/test/TARGETS +++ b/extension/pytree/test/TARGETS @@ -10,12 +10,6 @@ cpp_unittest( deps = ["//executorch/extension/pytree:pytree"], ) -cpp_unittest( - name = "function_ref_test", - srcs = ["function_ref_test.cpp"], - deps = ["//executorch/extension/pytree:pytree"], -) - python_unittest( name = "pybindings_test", srcs = [ diff --git a/runtime/core/function_ref.h b/runtime/core/function_ref.h new file mode 100644 index 00000000000..a9ded49ab74 --- /dev/null +++ b/runtime/core/function_ref.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===- llvm/ADT/STLFunctionalExtras.h - Extras for -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some extension to . +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// +// Extra additions to +//===----------------------------------------------------------------------===// + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a FunctionRef. + +// torch::executor: modified from llvm::function_ref +// - renamed to FunctionRef +// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses +// - use namespaced internal::remove_cvref_t + + +#pragma once + +#include +#include +#include + +namespace executorch::runtime { + +//===----------------------------------------------------------------------===// +// Features from C++20 +//===----------------------------------------------------------------------===// + +namespace internal { + +template +struct remove_cvref { + using type = + typename std::remove_cv::type>::type; +}; + +template +using remove_cvref_t = typename remove_cvref::type; + +} // namespace internal + +template +class FunctionRef; + +template +class FunctionRef { + Ret (*callback)(intptr_t callable, Params ...params) = nullptr; + intptr_t callable; + + template + static Ret callback_fn(intptr_t callable, Params ...params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); + } + +public: + FunctionRef() = default; + FunctionRef(std::nullptr_t) {} + + template + FunctionRef( + Callable &&callable, + // This is not the copy-constructor. + std::enable_if_t, + FunctionRef>::value> * = nullptr, + // Functor must be callable and return a suitable type. + std::enable_if_t::value || + std::is_convertible()( + std::declval()...)), + Ret>::value> * = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + Ret operator()(Params ...params) const { + return callback(callable, std::forward(params)...); + } + + explicit operator bool() const { return callback; } + + bool operator==(const FunctionRef &Other) const { + return callable == Other.callable; + } +}; +} // namespace executorch::runtime diff --git a/runtime/core/targets.bzl b/runtime/core/targets.bzl index d3e02b1afb5..efc7853f3c1 100644 --- a/runtime/core/targets.bzl +++ b/runtime/core/targets.bzl @@ -41,6 +41,7 @@ def define_common_targets(): "defines.h", "error.h", "freeable_buffer.h", + "function_ref.h", "result.h", "span.h", ], diff --git a/runtime/core/test/CMakeLists.txt b/runtime/core/test/CMakeLists.txt index 70f7cbf4bfd..bdc427baf7d 100644 --- a/runtime/core/test/CMakeLists.txt +++ b/runtime/core/test/CMakeLists.txt @@ -20,14 +20,15 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs - span_test.cpp + array_ref_test.cpp error_handling_test.cpp + evalue_test.cpp event_tracer_test.cpp freeable_buffer_test.cpp - array_ref_test.cpp - memory_allocator_test.cpp + function_ref_test.cpp hierarchical_allocator_test.cpp - evalue_test.cpp + memory_allocator_test.cpp + span_test.cpp ) et_cxx_test(runtime_core_test SOURCES ${_test_srcs} EXTRA_LIBS) diff --git a/extension/pytree/test/function_ref_test.cpp b/runtime/core/test/function_ref_test.cpp similarity index 100% rename from extension/pytree/test/function_ref_test.cpp rename to runtime/core/test/function_ref_test.cpp diff --git a/runtime/core/test/targets.bzl b/runtime/core/test/targets.bzl index 7db74475c92..1ad0940c62e 100644 --- a/runtime/core/test/targets.bzl +++ b/runtime/core/test/targets.bzl @@ -33,6 +33,16 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "function_ref_test", + srcs = [ + "function_ref_test.cpp", + ], + deps = [ + "//executorch/runtime/core:core", + ], + ) + runtime.cxx_test( name = "event_tracer_test", srcs = [ diff --git a/test/utils/OSSTestConfig.json b/test/utils/OSSTestConfig.json index be594f9d5f4..2cfc4b8a995 100644 --- a/test/utils/OSSTestConfig.json +++ b/test/utils/OSSTestConfig.json @@ -45,7 +45,6 @@ { "directory": "extension/pytree/test", "sources": [ - "function_ref_test.cpp", "test_pytree.cpp" ] }, @@ -96,14 +95,15 @@ { "directory": "runtime/core/test", "sources": [ - "span_test.cpp", + "array_ref_test.cpp", "error_handling_test.cpp", + "evalue_test.cpp", "event_tracer_test.cpp", "freeable_buffer_test.cpp", - "array_ref_test.cpp", - "memory_allocator_test.cpp", + "function_ref_test.cpp", "hierarchical_allocator_test.cpp", - "evalue_test.cpp" + "memory_allocator_test.cpp", + "span_test.cpp" ] }, {