From 781877bd26ed7ab01ae6cf952bf4691641593ed2 Mon Sep 17 00:00:00 2001
From: Steven Le Rouzic <steven.lerouzic@gmail.com>
Date: Sat, 22 Mar 2025 01:21:56 +0100
Subject: Add function_ref

---
 asl/testing/testing.hpp          |  6 ++--
 asl/types/BUILD.bazel            | 21 +++++++++++++
 asl/types/function.hpp           | 10 +++---
 asl/types/function_ref.hpp       | 67 ++++++++++++++++++++++++++++++++++++++++
 asl/types/function_ref_tests.cpp | 50 ++++++++++++++++++++++++++++++
 asl/types/function_tests.cpp     | 12 +++++++
 6 files changed, 158 insertions(+), 8 deletions(-)
 create mode 100644 asl/types/function_ref.hpp
 create mode 100644 asl/types/function_ref_tests.cpp

diff --git a/asl/testing/testing.hpp b/asl/testing/testing.hpp
index 3b4a421..8ea73a3 100644
--- a/asl/testing/testing.hpp
+++ b/asl/testing/testing.hpp
@@ -46,6 +46,6 @@ struct Test
     if (EXPR) {}              \
     else { ::asl::testing::report_failure(#EXPR); return; }
 
-#define ASL_TEST_EXPECT(EXPR) \
-    if (EXPR) {}              \
-    else { ::asl::testing::report_failure(#EXPR); }
+#define ASL_TEST_EXPECT(...) \
+    if (__VA_ARGS__) {}      \
+    else { ::asl::testing::report_failure(#__VA_ARGS__); }
diff --git a/asl/types/BUILD.bazel b/asl/types/BUILD.bazel
index 58d7183..198d0a2 100644
--- a/asl/types/BUILD.bazel
+++ b/asl/types/BUILD.bazel
@@ -90,6 +90,17 @@ cc_library(
     visibility = ["//visibility:public"],
 )
 
+cc_library(
+    name = "function_ref",
+    hdrs = [
+        "function_ref.hpp",
+    ],
+    deps = [
+        "//asl/base",
+    ],
+    visibility = ["//visibility:public"],
+)
+
 cc_test(
     name = "function_tests",
     srcs = ["function_tests.cpp"],
@@ -100,6 +111,16 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "function_ref_tests",
+    srcs = ["function_ref_tests.cpp"],
+    deps = [
+        "//asl/tests:utils",
+        "//asl/testing",
+        "//asl/types:function_ref",
+    ],
+)
+
 cc_test(
     name = "box_tests",
     srcs = ["box_tests.cpp"],
diff --git a/asl/types/function.hpp b/asl/types/function.hpp
index 6460711..40387ba 100644
--- a/asl/types/function.hpp
+++ b/asl/types/function.hpp
@@ -196,7 +196,7 @@ public:
     {
         if (m_op != nullptr)
         {
-            (*m_op)(
+            m_op(
                 &m_storage,
                 const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
                 function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@@ -209,7 +209,7 @@ public:
     {
         if (m_op != nullptr)
         {
-            (*m_op)(
+            m_op(
                 &m_storage,
                 &other.m_storage,
                 function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@@ -230,7 +230,7 @@ public:
             m_invoke = other.m_invoke;
             m_op = other.m_op;
 
-            (*m_op)(
+            m_op(
                 &m_storage,
                 const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
                 function_detail::FunctionOp::kCopyFromOtherToThisUninit);
@@ -247,7 +247,7 @@ public:
             m_invoke = asl::exchange(other.m_invoke, nullptr);
             m_op = asl::exchange(other.m_op, nullptr);
 
-            (*m_op)(
+            m_op(
                 &m_storage,
                 &other.m_storage,
                 function_detail::FunctionOp::kMoveFromOtherToThisUninit);
@@ -277,7 +277,7 @@ public:
     constexpr R operator()(Args... args) const
     {
         ASL_ASSERT(m_invoke);
-        return (*m_invoke)(args..., m_storage);
+        return m_invoke(args..., m_storage);
     }
 };
 
diff --git a/asl/types/function_ref.hpp b/asl/types/function_ref.hpp
new file mode 100644
index 0000000..2bd27dc
--- /dev/null
+++ b/asl/types/function_ref.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "asl/base/utility.hpp"
+#include "asl/base/meta.hpp"
+#include "asl/base/functional.hpp"
+
+namespace asl
+{
+
+template<typename T>
+class function_ref;
+
+template<typename R, typename... Args>
+class function_ref<R(Args...)>
+{
+    using InvokeFn = R (*)(Args..., void*);
+
+    void*       m_obj;
+    InvokeFn    m_invoke;
+
+    template<typename T>
+    static R invoke(Args... args, void* obj)
+    {
+        // NOLINTNEXTLINE(*-reinterpret-cast)
+        return asl::invoke(*reinterpret_cast<T*>(obj), std::forward<Args>(args)...);
+    }
+
+public:
+    function_ref() = delete;
+
+    ASL_DEFAULT_COPY_MOVE(function_ref);
+    ~function_ref() = default;
+
+    template<typename T>
+    function_ref(T&& t) // NOLINT(*-missing-std-forward, *explicit*)
+        requires (
+            !same_as<un_cvref_t<T>, function_ref>
+            && invocable<T, Args...>
+            && same_as<invoke_result_t<T, Args...>, R>
+        )
+        // NOLINTNEXTLINE(*cast*)
+        : m_obj{const_cast<void*>(reinterpret_cast<const void*>(&t))}
+        , m_invoke{invoke<un_ref_t<T>>}
+    {}
+
+    template<typename T>
+    function_ref& operator=(T&& t) // NOLINT(*-missing-std-forward)
+        requires (
+            !same_as<un_cvref_t<T>, function_ref>
+            && invocable<T, Args...>
+            && same_as<invoke_result_t<T, Args...>, R>
+        )
+    {
+        // NOLINTNEXTLINE(*cast*)
+        m_obj = const_cast<void*>(reinterpret_cast<const void*>(&t));
+        m_invoke = invoke<un_ref_t<T>>;
+
+        return *this;
+    }
+
+    constexpr R operator()(this function_ref self, Args... args)
+    {
+        return self.m_invoke(std::forward<Args>(args)..., self.m_obj);
+    }
+};
+
+} // namespace asl
diff --git a/asl/types/function_ref_tests.cpp b/asl/types/function_ref_tests.cpp
new file mode 100644
index 0000000..37cb382
--- /dev/null
+++ b/asl/types/function_ref_tests.cpp
@@ -0,0 +1,50 @@
+#include "asl/testing/testing.hpp"
+#include "asl/types/function_ref.hpp"
+
+static int add(int a, int b)
+{
+    return a + b;
+}
+
+struct Functor
+{
+    int state = 0;
+
+    int operator()(int x, int)
+    {
+        state += x;
+        return state;
+    }
+};
+
+static int invoke_fn_ref(asl::function_ref<int(int, int)> fn, int a, int b)
+{
+    return fn(a, b);
+}
+
+ASL_TEST(function_ref)
+{
+    const asl::function_ref<int(int, int)> fn(add);
+    ASL_TEST_EXPECT(invoke_fn_ref(fn, 4, 5) == 9);
+
+    ASL_TEST_EXPECT(invoke_fn_ref(add, 4, 5) == 9);
+
+    ASL_TEST_EXPECT(invoke_fn_ref([](int a, int b) { return a * b; }, 4, 5) == 20);
+
+    Functor fun;
+    ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 4);
+    ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 8);
+    ASL_TEST_EXPECT(invoke_fn_ref(fun, 4, 5) == 12);
+
+    asl::function_ref<int(int, int)> fn2 = fn;
+    ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+
+    fn2 = [](int a, int b) { return a - b; };
+    ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == -1);
+
+    fn2 = fn;
+    ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+
+    fn2 = add;
+    ASL_TEST_EXPECT(invoke_fn_ref(fn2, 4, 5) == 9);
+}
diff --git a/asl/types/function_tests.cpp b/asl/types/function_tests.cpp
index 5a55885..c9849d1 100644
--- a/asl/types/function_tests.cpp
+++ b/asl/types/function_tests.cpp
@@ -177,3 +177,15 @@ ASL_TEST(replace)
     fn = [](int x) { return x + 3; };
     ASL_TEST_EXPECT(fn(5) == 8);
 }
+
+static int foo(const asl::function<int(int, int)>& fn)
+{
+    return fn(5, 5);
+}
+
+ASL_TEST(function_parameter)
+{
+    ASL_TEST_EXPECT(foo(add) == 10);
+    ASL_TEST_EXPECT(foo([](int a, int b) { return a + b; }) == 10);
+}
+
-- 
cgit