Add function_ref

This commit is contained in:
2025-03-22 01:21:56 +01:00
parent c692909ff3
commit 781877bd26
6 changed files with 158 additions and 8 deletions

View File

@ -46,6 +46,6 @@ struct Test
if (EXPR) {} \ if (EXPR) {} \
else { ::asl::testing::report_failure(#EXPR); return; } else { ::asl::testing::report_failure(#EXPR); return; }
#define ASL_TEST_EXPECT(EXPR) \ #define ASL_TEST_EXPECT(...) \
if (EXPR) {} \ if (__VA_ARGS__) {} \
else { ::asl::testing::report_failure(#EXPR); } else { ::asl::testing::report_failure(#__VA_ARGS__); }

View File

@ -90,6 +90,17 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
cc_library(
name = "function_ref",
hdrs = [
"function_ref.hpp",
],
deps = [
"//asl/base",
],
visibility = ["//visibility:public"],
)
cc_test( cc_test(
name = "function_tests", name = "function_tests",
srcs = ["function_tests.cpp"], srcs = ["function_tests.cpp"],
@ -100,6 +111,16 @@ cc_test(
], ],
) )
cc_test(
name = "function_ref_tests",
srcs = ["function_ref_tests.cpp"],
deps = [
"//asl/tests:utils",
"//asl/testing",
"//asl/types:function_ref",
],
)
cc_test( cc_test(
name = "box_tests", name = "box_tests",
srcs = ["box_tests.cpp"], srcs = ["box_tests.cpp"],

View File

@ -196,7 +196,7 @@ public:
{ {
if (m_op != nullptr) if (m_op != nullptr)
{ {
(*m_op)( m_op(
&m_storage, &m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit); function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@ -209,7 +209,7 @@ public:
{ {
if (m_op != nullptr) if (m_op != nullptr)
{ {
(*m_op)( m_op(
&m_storage, &m_storage,
&other.m_storage, &other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit); function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@ -230,7 +230,7 @@ public:
m_invoke = other.m_invoke; m_invoke = other.m_invoke;
m_op = other.m_op; m_op = other.m_op;
(*m_op)( m_op(
&m_storage, &m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit); function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@ -247,7 +247,7 @@ public:
m_invoke = asl::exchange(other.m_invoke, nullptr); m_invoke = asl::exchange(other.m_invoke, nullptr);
m_op = asl::exchange(other.m_op, nullptr); m_op = asl::exchange(other.m_op, nullptr);
(*m_op)( m_op(
&m_storage, &m_storage,
&other.m_storage, &other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit); function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@ -277,7 +277,7 @@ public:
constexpr R operator()(Args... args) const constexpr R operator()(Args... args) const
{ {
ASL_ASSERT(m_invoke); ASL_ASSERT(m_invoke);
return (*m_invoke)(args..., m_storage); return m_invoke(args..., m_storage);
} }
}; };

View File

@ -0,0 +1,67 @@
#pragma once
#include "asl/base/utility.hpp"
#include "asl/base/meta.hpp"
#include "asl/base/functional.hpp"
namespace asl
{
template<typename T>
class function_ref;
template<typename R, typename... Args>
class function_ref<R(Args...)>
{
using InvokeFn = R (*)(Args..., void*);
void* m_obj;
InvokeFn m_invoke;
template<typename T>
static R invoke(Args... args, void* obj)
{
// NOLINTNEXTLINE(*-reinterpret-cast)
return asl::invoke(*reinterpret_cast<T*>(obj), std::forward<Args>(args)...);
}
public:
function_ref() = delete;
ASL_DEFAULT_COPY_MOVE(function_ref);
~function_ref() = default;
template<typename T>
function_ref(T&& t) // NOLINT(*-missing-std-forward, *explicit*)
requires (
!same_as<un_cvref_t<T>, function_ref>
&& invocable<T, Args...>
&& same_as<invoke_result_t<T, Args...>, R>
)
// NOLINTNEXTLINE(*cast*)
: m_obj{const_cast<void*>(reinterpret_cast<const void*>(&t))}
, m_invoke{invoke<un_ref_t<T>>}
{}
template<typename T>
function_ref& operator=(T&& t) // NOLINT(*-missing-std-forward)
requires (
!same_as<un_cvref_t<T>, function_ref>
&& invocable<T, Args...>
&& same_as<invoke_result_t<T, Args...>, R>
)
{
// NOLINTNEXTLINE(*cast*)
m_obj = const_cast<void*>(reinterpret_cast<const void*>(&t));
m_invoke = invoke<un_ref_t<T>>;
return *this;
}
constexpr R operator()(this function_ref self, Args... args)
{
return self.m_invoke(std::forward<Args>(args)..., self.m_obj);
}
};
} // namespace asl

View File

@ -0,0 +1,50 @@
#include "asl/testing/testing.hpp"
#include "asl/types/function_ref.hpp"
static int add(int a, int b)
{
return a + b;
}
struct Functor
{
int state = 0;
int operator()(int x, int)
{
state += x;
return state;
}
};
static int invoke_fn_ref(asl::function_ref<int(int, int)> fn, int a, int b)
{
return fn(a, b);
}
ASL_TEST(function_ref)
{
const asl::function_ref<int(int, int)> fn(add);
ASL_TEST_EXPECT(invoke_fn_ref(fn, 4, 5) == 9);
ASL_TEST_EXPECT(invoke_fn_ref(add, 4, 5) == 9);
ASL_TEST_EXPECT(invoke_fn_ref([](int a, int b) { return a * b; }, 4, 5) == 20);
Functor fun;
ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 4);
ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 8);
ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 12);
asl::function_ref<int(int, int)> fn2 = fn;
ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
fn2 = [](int a, int b) { return a - b; };
ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == -1);
fn2 = fn;
ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
fn2 = add;
ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
}

View File

@ -177,3 +177,15 @@ ASL_TEST(replace)
fn = [](int x) { return x + 3; }; fn = [](int x) { return x + 3; };
ASL_TEST_EXPECT(fn(5) == 8); ASL_TEST_EXPECT(fn(5) == 8);
} }
static int foo(const asl::function<int(int, int)>& fn)
{
return fn(5, 5);
}
ASL_TEST(function_parameter)
{
ASL_TEST_EXPECT(foo(add) == 10);
ASL_TEST_EXPECT(foo([](int a, int b) { return a + b; }) == 10);
}