Skip to content

Move fully-featured FunctionRef from extension/pytree to ExecuTorch core #10441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 8 additions & 106 deletions extension/pytree/function_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,117 +6,19 @@
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extension to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t

#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

template <typename T>
struct remove_cvref {
using type =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
intptr_t callable;

template <typename Callable>
static Ret callback_fn(intptr_t callable, Params... params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}

public:
FunctionRef() = default;
FunctionRef(std::nullptr_t) {}

template <typename Callable>
FunctionRef(
Callable&& callable,
// This is not the copy-constructor.
std::enable_if_t<!std::is_same<
internal::remove_cvref_t<Callable>,
FunctionRef>::value>* = nullptr,
// Functor must be callable and return a suitable type.
std::enable_if_t<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value>* = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

Ret operator()(Params... params) const {
return callback(callable, std::forward<Params>(params)...);
}
#include <executorch/runtime/core/function_ref.h>

explicit operator bool() const {
return callback;
}
/// This header is DEPRECATED; use executorch/runtime/core/function_ref.h
/// directly instead.

bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
return callable == Other.callable;
}
};
} // namespace pytree
} // namespace extension
} // namespace executorch
namespace executorch::extension::pytree {
using executorch::runtime::FunctionRef;
} // namespace executorch::extension::pytree

namespace torch {
namespace executor {
namespace pytree {
namespace torch::executor::pytree {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::pytree::FunctionRef;
} // namespace pytree
} // namespace executor
} // namespace torch
} // namespace torch::executor::pytree
2 changes: 1 addition & 1 deletion extension/pytree/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

set(_test_srcs function_ref_test.cpp test_pytree.cpp)
set(_test_srcs test_pytree.cpp)

et_cxx_test(extension_pytree_test SOURCES ${_test_srcs} EXTRA_LIBS)
6 changes: 0 additions & 6 deletions extension/pytree/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ cpp_unittest(
deps = ["//executorch/extension/pytree:pytree"],
)

cpp_unittest(
name = "function_ref_test",
srcs = ["function_ref_test.cpp"],
deps = ["//executorch/extension/pytree:pytree"],
)

python_unittest(
name = "pybindings_test",
srcs = [
Expand Down
108 changes: 108 additions & 0 deletions runtime/core/function_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extension to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t

#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace executorch::runtime {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

template <typename T>
struct remove_cvref {
using type =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
intptr_t callable;

template <typename Callable>
static Ret callback_fn(intptr_t callable, Params... params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}

public:
FunctionRef() = default;
FunctionRef(std::nullptr_t) {}

template <typename Callable>
FunctionRef(
Callable&& callable,
// This is not the copy-constructor.
std::enable_if_t<!std::is_same<
internal::remove_cvref_t<Callable>,
FunctionRef>::value>* = nullptr,
// Functor must be callable and return a suitable type.
std::enable_if_t<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value>* = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

Ret operator()(Params... params) const {
return callback(callable, std::forward<Params>(params)...);
}

explicit operator bool() const {
return callback;
}

bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
return callable == Other.callable;
}
};
} // namespace executorch::runtime
1 change: 1 addition & 0 deletions runtime/core/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def define_common_targets():
"defines.h",
"error.h",
"freeable_buffer.h",
"function_ref.h",
"result.h",
"span.h",
],
Expand Down
9 changes: 5 additions & 4 deletions runtime/core/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

set(_test_srcs
span_test.cpp
array_ref_test.cpp
error_handling_test.cpp
evalue_test.cpp
event_tracer_test.cpp
freeable_buffer_test.cpp
array_ref_test.cpp
memory_allocator_test.cpp
function_ref_test.cpp
hierarchical_allocator_test.cpp
evalue_test.cpp
memory_allocator_test.cpp
span_test.cpp
)

et_cxx_test(runtime_core_test SOURCES ${_test_srcs} EXTRA_LIBS)
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/pytree/function_ref.h>
#include <executorch/runtime/core/function_ref.h>

#include <gtest/gtest.h>

using namespace ::testing;

using ::executorch::extension::pytree::FunctionRef;
using ::executorch::runtime::FunctionRef;

namespace {
void one(int32_t& i) {
Expand Down
10 changes: 10 additions & 0 deletions runtime/core/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def define_common_targets():
],
)

runtime.cxx_test(
name = "function_ref_test",
srcs = [
"function_ref_test.cpp",
],
deps = [
"//executorch/runtime/core:core",
],
)

runtime.cxx_test(
name = "event_tracer_test",
srcs = [
Expand Down
10 changes: 5 additions & 5 deletions test/utils/OSSTestConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
{
"directory": "extension/pytree/test",
"sources": [
"function_ref_test.cpp",
"test_pytree.cpp"
]
},
Expand Down Expand Up @@ -96,14 +95,15 @@
{
"directory": "runtime/core/test",
"sources": [
"span_test.cpp",
"array_ref_test.cpp",
"error_handling_test.cpp",
"evalue_test.cpp",
"event_tracer_test.cpp",
"freeable_buffer_test.cpp",
"array_ref_test.cpp",
"memory_allocator_test.cpp",
"function_ref_test.cpp",
"hierarchical_allocator_test.cpp",
"evalue_test.cpp"
"memory_allocator_test.cpp",
"span_test.cpp"
]
},
{
Expand Down
Loading