diff options
author | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-03-22 01:21:56 +0100 |
---|---|---|
committer | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-03-22 19:24:40 +0100 |
commit | 781877bd26ed7ab01ae6cf952bf4691641593ed2 (patch) | |
tree | 1f64c53d7fdada9d751fb88c4fa35fbf03305221 /asl | |
parent | c692909ff332de6f2e32db844458ccd03a080e53 (diff) |
Add function_ref
Diffstat (limited to 'asl')
-rw-r--r-- | asl/testing/testing.hpp | 6 | ||||
-rw-r--r-- | asl/types/BUILD.bazel | 21 | ||||
-rw-r--r-- | asl/types/function.hpp | 10 | ||||
-rw-r--r-- | asl/types/function_ref.hpp | 67 | ||||
-rw-r--r-- | asl/types/function_ref_tests.cpp | 50 | ||||
-rw-r--r-- | asl/types/function_tests.cpp | 12 |
6 files changed, 158 insertions, 8 deletions
diff --git a/asl/testing/testing.hpp b/asl/testing/testing.hpp index 3b4a421..8ea73a3 100644 --- a/asl/testing/testing.hpp +++ b/asl/testing/testing.hpp @@ -46,6 +46,6 @@ struct Test if (EXPR) {} \ else { ::asl::testing::report_failure(#EXPR); return; } -#define ASL_TEST_EXPECT(EXPR) \ - if (EXPR) {} \ - else { ::asl::testing::report_failure(#EXPR); } +#define ASL_TEST_EXPECT(...) \ + if (__VA_ARGS__) {} \ + else { ::asl::testing::report_failure(#__VA_ARGS__); } diff --git a/asl/types/BUILD.bazel b/asl/types/BUILD.bazel index 58d7183..198d0a2 100644 --- a/asl/types/BUILD.bazel +++ b/asl/types/BUILD.bazel @@ -90,6 +90,17 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "function_ref", + hdrs = [ + "function_ref.hpp", + ], + deps = [ + "//asl/base", + ], + visibility = ["//visibility:public"], +) + cc_test( name = "function_tests", srcs = ["function_tests.cpp"], @@ -101,6 +112,16 @@ cc_test( ) cc_test( + name = "function_ref_tests", + srcs = ["function_ref_tests.cpp"], + deps = [ + "//asl/tests:utils", + "//asl/testing", + "//asl/types:function_ref", + ], +) + +cc_test( name = "box_tests", srcs = ["box_tests.cpp"], deps = [ diff --git a/asl/types/function.hpp b/asl/types/function.hpp index 6460711..40387ba 100644 --- a/asl/types/function.hpp +++ b/asl/types/function.hpp @@ -196,7 +196,7 @@ public: { if (m_op != nullptr) { - (*m_op)( + m_op( &m_storage, const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) function_detail::FunctionOp::kCopyFromOtherToThisUninit); @@ -209,7 +209,7 @@ public: { if (m_op != nullptr) { - (*m_op)( + m_op( &m_storage, &other.m_storage, function_detail::FunctionOp::kMoveFromOtherToThisUninit); @@ -230,7 +230,7 @@ public: m_invoke = other.m_invoke; m_op = other.m_op; - (*m_op)( + m_op( &m_storage, const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) function_detail::FunctionOp::kCopyFromOtherToThisUninit); @@ -247,7 +247,7 @@ public: m_invoke = asl::exchange(other.m_invoke, nullptr); m_op = asl::exchange(other.m_op, nullptr); - (*m_op)( + m_op( &m_storage, &other.m_storage, function_detail::FunctionOp::kMoveFromOtherToThisUninit); @@ -277,7 +277,7 @@ public: constexpr R operator()(Args... args) const { ASL_ASSERT(m_invoke); - return (*m_invoke)(args..., m_storage); + return m_invoke(args..., m_storage); } }; diff --git a/asl/types/function_ref.hpp b/asl/types/function_ref.hpp new file mode 100644 index 0000000..2bd27dc --- /dev/null +++ b/asl/types/function_ref.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "asl/base/utility.hpp" +#include "asl/base/meta.hpp" +#include "asl/base/functional.hpp" + +namespace asl +{ + +template<typename T> +class function_ref; + +template<typename R, typename... Args> +class function_ref<R(Args...)> +{ + using InvokeFn = R (*)(Args..., void*); + + void* m_obj; + InvokeFn m_invoke; + + template<typename T> + static R invoke(Args... args, void* obj) + { + // NOLINTNEXTLINE(*-reinterpret-cast) + return asl::invoke(*reinterpret_cast<T*>(obj), std::forward<Args>(args)...); + } + +public: + function_ref() = delete; + + ASL_DEFAULT_COPY_MOVE(function_ref); + ~function_ref() = default; + + template<typename T> + function_ref(T&& t) // NOLINT(*-missing-std-forward, *explicit*) + requires ( + !same_as<un_cvref_t<T>, function_ref> + && invocable<T, Args...> + && same_as<invoke_result_t<T, Args...>, R> + ) + // NOLINTNEXTLINE(*cast*) + : m_obj{const_cast<void*>(reinterpret_cast<const void*>(&t))} + , m_invoke{invoke<un_ref_t<T>>} + {} + + template<typename T> + function_ref& operator=(T&& t) // NOLINT(*-missing-std-forward) + requires ( + !same_as<un_cvref_t<T>, function_ref> + && invocable<T, Args...> + && same_as<invoke_result_t<T, Args...>, R> + ) + { + // NOLINTNEXTLINE(*cast*) + m_obj = const_cast<void*>(reinterpret_cast<const void*>(&t)); + m_invoke = invoke<un_ref_t<T>>; + + return *this; + } + + constexpr R operator()(this function_ref self, Args... args) + { + return self.m_invoke(std::forward<Args>(args)..., self.m_obj); + } +}; + +} // namespace asl diff --git a/asl/types/function_ref_tests.cpp b/asl/types/function_ref_tests.cpp new file mode 100644 index 0000000..37cb382 --- /dev/null +++ b/asl/types/function_ref_tests.cpp @@ -0,0 +1,50 @@ +#include "asl/testing/testing.hpp" +#include "asl/types/function_ref.hpp" + +static int add(int a, int b) +{ + return a + b; +} + +struct Functor +{ + int state = 0; + + int operator()(int x, int) + { + state += x; + return state; + } +}; + +static int invoke_fn_ref(asl::function_ref<int(int, int)> fn, int a, int b) +{ + return fn(a, b); +} + +ASL_TEST(function_ref) +{ + const asl::function_ref<int(int, int)> fn(add); + ASL_TEST_EXPECT(invoke_fn_ref(fn, 4, 5) == 9); + + ASL_TEST_EXPECT(invoke_fn_ref(add, 4, 5) == 9); + + ASL_TEST_EXPECT(invoke_fn_ref([](int a, int b) { return a * b; }, 4, 5) == 20); + + Functor fun; + ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 4); + ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 8); + ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 12); + + asl::function_ref<int(int, int)> fn2 = fn; + ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9); + + fn2 = [](int a, int b) { return a - b; }; + ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == -1); + + fn2 = fn; + ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9); + + fn2 = add; + ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9); +} diff --git a/asl/types/function_tests.cpp b/asl/types/function_tests.cpp index 5a55885..c9849d1 100644 --- a/asl/types/function_tests.cpp +++ b/asl/types/function_tests.cpp @@ -177,3 +177,15 @@ ASL_TEST(replace) fn = [](int x) { return x + 3; }; ASL_TEST_EXPECT(fn(5) == 8); } + +static int foo(const asl::function<int(int, int)>& fn) +{ + return fn(5, 5); +} + +ASL_TEST(function_parameter) +{ + ASL_TEST_EXPECT(foo(add) == 10); + ASL_TEST_EXPECT(foo([](int a, int b) { return a + b; }) == 10); +} + |