summaryrefslogtreecommitdiff
path: root/asl
diff options
context:
space:
mode:
authorSteven Le Rouzic <steven.lerouzic@gmail.com>2025-03-22 01:21:56 +0100
committerSteven Le Rouzic <steven.lerouzic@gmail.com>2025-03-22 19:24:40 +0100
commit781877bd26ed7ab01ae6cf952bf4691641593ed2 (patch)
tree1f64c53d7fdada9d751fb88c4fa35fbf03305221 /asl
parentc692909ff332de6f2e32db844458ccd03a080e53 (diff)
Add function_ref
Diffstat (limited to 'asl')
-rw-r--r--asl/testing/testing.hpp6
-rw-r--r--asl/types/BUILD.bazel21
-rw-r--r--asl/types/function.hpp10
-rw-r--r--asl/types/function_ref.hpp67
-rw-r--r--asl/types/function_ref_tests.cpp50
-rw-r--r--asl/types/function_tests.cpp12
6 files changed, 158 insertions, 8 deletions
diff --git a/asl/testing/testing.hpp b/asl/testing/testing.hpp
index 3b4a421..8ea73a3 100644
--- a/asl/testing/testing.hpp
+++ b/asl/testing/testing.hpp
@@ -46,6 +46,6 @@ struct Test
if (EXPR) {} \
else { ::asl::testing::report_failure(#EXPR); return; }
-#define ASL_TEST_EXPECT(EXPR) \
- if (EXPR) {} \
- else { ::asl::testing::report_failure(#EXPR); }
+#define ASL_TEST_EXPECT(...) \
+ if (__VA_ARGS__) {} \
+ else { ::asl::testing::report_failure(#__VA_ARGS__); }
diff --git a/asl/types/BUILD.bazel b/asl/types/BUILD.bazel
index 58d7183..198d0a2 100644
--- a/asl/types/BUILD.bazel
+++ b/asl/types/BUILD.bazel
@@ -90,6 +90,17 @@ cc_library(
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "function_ref",
+ hdrs = [
+ "function_ref.hpp",
+ ],
+ deps = [
+ "//asl/base",
+ ],
+ visibility = ["//visibility:public"],
+)
+
cc_test(
name = "function_tests",
srcs = ["function_tests.cpp"],
@@ -101,6 +112,16 @@ cc_test(
)
cc_test(
+ name = "function_ref_tests",
+ srcs = ["function_ref_tests.cpp"],
+ deps = [
+ "//asl/tests:utils",
+ "//asl/testing",
+ "//asl/types:function_ref",
+ ],
+)
+
+cc_test(
name = "box_tests",
srcs = ["box_tests.cpp"],
deps = [
diff --git a/asl/types/function.hpp b/asl/types/function.hpp
index 6460711..40387ba 100644
--- a/asl/types/function.hpp
+++ b/asl/types/function.hpp
@@ -196,7 +196,7 @@ public:
{
if (m_op != nullptr)
{
- (*m_op)(
+ m_op(
&m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@@ -209,7 +209,7 @@ public:
{
if (m_op != nullptr)
{
- (*m_op)(
+ m_op(
&m_storage,
&other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@@ -230,7 +230,7 @@ public:
m_invoke = other.m_invoke;
m_op = other.m_op;
- (*m_op)(
+ m_op(
&m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@@ -247,7 +247,7 @@ public:
m_invoke = asl::exchange(other.m_invoke, nullptr);
m_op = asl::exchange(other.m_op, nullptr);
- (*m_op)(
+ m_op(
&m_storage,
&other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@@ -277,7 +277,7 @@ public:
constexpr R operator()(Args... args) const
{
ASL_ASSERT(m_invoke);
- return (*m_invoke)(args..., m_storage);
+ return m_invoke(args..., m_storage);
}
};
diff --git a/asl/types/function_ref.hpp b/asl/types/function_ref.hpp
new file mode 100644
index 0000000..2bd27dc
--- /dev/null
+++ b/asl/types/function_ref.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "asl/base/utility.hpp"
+#include "asl/base/meta.hpp"
+#include "asl/base/functional.hpp"
+
+namespace asl
+{
+
+template<typename T>
+class function_ref;
+
+template<typename R, typename... Args>
+class function_ref<R(Args...)>
+{
+ using InvokeFn = R (*)(Args..., void*);
+
+ void* m_obj;
+ InvokeFn m_invoke;
+
+ template<typename T>
+ static R invoke(Args... args, void* obj)
+ {
+ // NOLINTNEXTLINE(*-reinterpret-cast)
+ return asl::invoke(*reinterpret_cast<T*>(obj), std::forward<Args>(args)...);
+ }
+
+public:
+ function_ref() = delete;
+
+ ASL_DEFAULT_COPY_MOVE(function_ref);
+ ~function_ref() = default;
+
+ template<typename T>
+ function_ref(T&& t) // NOLINT(*-missing-std-forward, *explicit*)
+ requires (
+ !same_as<un_cvref_t<T>, function_ref>
+ && invocable<T, Args...>
+ && same_as<invoke_result_t<T, Args...>, R>
+ )
+ // NOLINTNEXTLINE(*cast*)
+ : m_obj{const_cast<void*>(reinterpret_cast<const void*>(&t))}
+ , m_invoke{invoke<un_ref_t<T>>}
+ {}
+
+ template<typename T>
+ function_ref& operator=(T&& t) // NOLINT(*-missing-std-forward)
+ requires (
+ !same_as<un_cvref_t<T>, function_ref>
+ && invocable<T, Args...>
+ && same_as<invoke_result_t<T, Args...>, R>
+ )
+ {
+ // NOLINTNEXTLINE(*cast*)
+ m_obj = const_cast<void*>(reinterpret_cast<const void*>(&t));
+ m_invoke = invoke<un_ref_t<T>>;
+
+ return *this;
+ }
+
+ constexpr R operator()(this function_ref self, Args... args)
+ {
+ return self.m_invoke(std::forward<Args>(args)..., self.m_obj);
+ }
+};
+
+} // namespace asl
diff --git a/asl/types/function_ref_tests.cpp b/asl/types/function_ref_tests.cpp
new file mode 100644
index 0000000..37cb382
--- /dev/null
+++ b/asl/types/function_ref_tests.cpp
@@ -0,0 +1,50 @@
+#include "asl/testing/testing.hpp"
+#include "asl/types/function_ref.hpp"
+
+static int add(int a, int b)
+{
+ return a + b;
+}
+
+struct Functor
+{
+ int state = 0;
+
+ int operator()(int x, int)
+ {
+ state += x;
+ return state;
+ }
+};
+
+static int invoke_fn_ref(asl::function_ref<int(int, int)> fn, int a, int b)
+{
+ return fn(a, b);
+}
+
+ASL_TEST(function_ref)
+{
+ const asl::function_ref<int(int, int)> fn(add);
+ ASL_TEST_EXPECT(invoke_fn_ref(fn, 4, 5) == 9);
+
+ ASL_TEST_EXPECT(invoke_fn_ref(add, 4, 5) == 9);
+
+ ASL_TEST_EXPECT(invoke_fn_ref([](int a, int b) { return a * b; }, 4, 5) == 20);
+
+ Functor fun;
+ ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 4);
+ ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 8);
+ ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 12);
+
+ asl::function_ref<int(int, int)> fn2 = fn;
+ ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+
+ fn2 = [](int a, int b) { return a - b; };
+ ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == -1);
+
+ fn2 = fn;
+ ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+
+ fn2 = add;
+ ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+}
diff --git a/asl/types/function_tests.cpp b/asl/types/function_tests.cpp
index 5a55885..c9849d1 100644
--- a/asl/types/function_tests.cpp
+++ b/asl/types/function_tests.cpp
@@ -177,3 +177,15 @@ ASL_TEST(replace)
fn = [](int x) { return x + 3; };
ASL_TEST_EXPECT(fn(5) == 8);
}
+
+static int foo(const asl::function<int(int, int)>& fn)
+{
+ return fn(5, 5);
+}
+
+ASL_TEST(function_parameter)
+{
+ ASL_TEST_EXPECT(foo(add) == 10);
+ ASL_TEST_EXPECT(foo([](int a, int b) { return a + b; }) == 10);
+}
+