From f19d93a69a0ec5c7a89dcb4c064c984aac90ba71 Mon Sep 17 00:00:00 2001
From: Steven Le Rouzic <steven.lerouzic@gmail.com>
Date: Wed, 26 Mar 2025 18:54:54 +0100
Subject: Improve implementation of invoke

---
 asl/base/functional.hpp       | 61 +++++++++++++++++++------------------------
 asl/base/functional_tests.cpp | 26 +++++++++++++++---
 asl/base/meta.hpp             | 15 ++++++-----
 asl/base/meta_tests.cpp       | 14 ++++++++++
 4 files changed, 73 insertions(+), 43 deletions(-)

(limited to 'asl/base')

diff --git a/asl/base/functional.hpp b/asl/base/functional.hpp
index 509a2b2..5649bf8 100644
--- a/asl/base/functional.hpp
+++ b/asl/base/functional.hpp
@@ -9,56 +9,49 @@
 
 namespace asl {
 
-template<typename... Args, typename C>
-constexpr auto invoke(is_func auto C::* f, auto&& self, Args&&... args)
-    -> decltype((self.*f)(std::forward<Args>(args)...))
-    requires requires {
-        (self.*f)(std::forward<Args>(args)...);
+template<typename F, typename... Args>
+constexpr auto invoke(F&& f, Args&&... args)
+    -> decltype(std::forward<F>(f)(std::forward<Args>(args)...))
+    requires (!is_member_ptr<un_cvref_t<F>>) && requires {
+        f(std::forward<Args>(args)...);
     }
 {
-    return (std::forward<decltype(self)>(self).*f)(std::forward<Args>(args)...);
+    return std::forward<F>(f)(std::forward<Args>(args)...);
 }
 
-template<typename... Args, typename C>
-constexpr auto invoke(is_func auto C::* f, auto* self, Args&&... args)
-    -> decltype((self->*f)(std::forward<Args>(args)...))
-    requires requires {
-        (self->*f)(std::forward<Args>(args)...);
-    }
+template<typename C>
+constexpr auto&& invoke(auto C::* f, same_or_derived_from<C> auto&& arg)
 {
-    return (self->*f)(std::forward<Args>(args)...);
+    return std::forward<decltype(arg)>(arg).*f;
 }
 
-template<typename... Args, typename C>
-constexpr auto invoke(is_object auto C::* m, auto&& self, Args&&...)
-    -> decltype(self.*m)
+template<typename C>
+constexpr auto&& invoke(auto C::* f, auto&& arg)
     requires (
-        sizeof...(Args) == 0 &&
-        requires { self.*m; }
+        !same_or_derived_from<decltype(arg), C>
+        && requires { (*arg).*f; }
     )
 {
-    return std::forward<decltype(self)>(self).*m;
+    return (*std::forward<decltype(arg)>(arg)).*f;
 }
 
-template<typename... Args, typename C>
-constexpr auto invoke(is_object auto C::* m, auto* self, Args&&...)
-    -> decltype(self->*m)
-    requires (
-        sizeof...(Args) == 0 &&
-        requires { self->*m; }
-    )
+template<typename C, typename... Args>
+constexpr auto invoke(is_func auto C::* f, same_or_derived_from<C> auto&& self, Args&&... args)
+    -> decltype((std::forward<decltype(self)>(self).*f)(std::forward<Args>(args)...))
+    requires requires { (self.*f)(std::forward<Args>(args)...); }
 {
-    return self->*m;
+    return (std::forward<decltype(self)>(self).*f)(std::forward<Args>(args)...);
 }
 
-template<typename... Args>
-constexpr auto invoke(auto&& f, Args&&... args)
-    -> decltype(f(std::forward<Args>(args)...))
-    requires requires {
-        f(std::forward<Args>(args)...);
-    }
+template<typename C, typename... Args>
+constexpr auto invoke(is_func auto C::* f, auto&& self, Args&&... args)
+    -> decltype(((*std::forward<decltype(self)>(self)).*f)(std::forward<Args>(args)...))
+    requires (
+        !same_or_derived_from<decltype(self), C>
+        && requires { ((*self).*f)(std::forward<Args>(args)...); }
+    )
 {
-    return std::forward<decltype(f)>(f)(std::forward<Args>(args)...);
+    return ((*std::forward<decltype(self)>(self)).*f)(std::forward<Args>(args)...);
 }
 
 template<typename Void, typename F, typename... Args>
diff --git a/asl/base/functional_tests.cpp b/asl/base/functional_tests.cpp
index 5e7b052..6332784 100644
--- a/asl/base/functional_tests.cpp
+++ b/asl/base/functional_tests.cpp
@@ -7,9 +7,12 @@
 
 struct HasFunction
 {
-    void do_something(int, float) {}
+    void do_something(int, float) const {}
+    int& do_something2(int, float) &;
 };
 
+struct HasFunction2 : public HasFunction {};
+
 struct HasMember
 {
     int member{};
@@ -17,6 +20,8 @@ struct HasMember
     void (*member_fn)(){};
 };
 
+struct HasMember2 : public HasMember {};
+
 struct Functor
 {
     int64_t operator()() { return 35; }
@@ -34,7 +39,13 @@ static_assert(asl::same_as<asl::invoke_result_t<Functor>, int64_t>);
 static_assert(asl::same_as<asl::invoke_result_t<Functor, int>, int>);
 static_assert(asl::same_as<asl::invoke_result_t<decltype(static_cast<float(*)(float)>(some_func1)), float>, float>);
 static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasFunction::do_something), HasFunction, int, float>, void>);
-static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasMember::member), const HasMember>, const int&>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasFunction::do_something), const HasFunction2&, int, float>, void>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasFunction::do_something), HasFunction*, int, float>, void>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasFunction::do_something2), HasFunction2&, int, float>, int&>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasFunction::do_something2), HasFunction*, int, float>, int&>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasMember::member), HasMember>, int&&>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasMember::member), const HasMember&>, const int&>);
+static_assert(asl::same_as<asl::invoke_result_t<decltype(&HasMember::member), const HasMember2*>, const int&>);
 
 static_assert(asl::invocable<int()>);
 static_assert(!asl::invocable<int(), int>);
@@ -45,8 +56,17 @@ static_assert(asl::invocable<Functor, int>);
 static_assert(!asl::invocable<Functor, void*>);
 static_assert(asl::invocable<decltype(static_cast<float(*)(float)>(some_func1)), float>);
 static_assert(asl::invocable<decltype(&HasFunction::do_something), HasFunction, int, float>);
+static_assert(asl::invocable<decltype(&HasFunction::do_something), const HasFunction2&, int, float>);
+static_assert(asl::invocable<decltype(&HasFunction::do_something), HasFunction*, int, float>);
 static_assert(!asl::invocable<decltype(&HasFunction::do_something), HasFunction, int, int*>);
-static_assert(asl::invocable<decltype(&HasMember::member), const HasMember>);
+static_assert(!asl::invocable<decltype(&HasFunction::do_something2), HasFunction, int, float>);
+static_assert(!asl::invocable<decltype(&HasFunction::do_something2), const HasFunction2&, int, float>);
+static_assert(asl::invocable<decltype(&HasFunction::do_something2), HasFunction2&, int, float>);
+static_assert(asl::invocable<decltype(&HasFunction::do_something2), HasFunction*, int, float>);
+static_assert(!asl::invocable<decltype(&HasFunction::do_something2), HasFunction, int, int*>);
+static_assert(asl::invocable<decltype(&HasMember::member), const HasMember2>);
+static_assert(asl::invocable<decltype(&HasMember::member), const HasMember&>);
+static_assert(asl::invocable<decltype(&HasMember::member), const HasMember2*>);
 
 ASL_TEST(invoke_member_function)
 {
diff --git a/asl/base/meta.hpp b/asl/base/meta.hpp
index b17f05c..1d57367 100644
--- a/asl/base/meta.hpp
+++ b/asl/base/meta.hpp
@@ -92,12 +92,6 @@ template<typename T> concept trivially_destructible = __is_trivially_destructibl
 template<typename T> concept copyable = copy_constructible<T> && copy_assignable<T>;
 template<typename T> concept moveable = move_constructible<T> && move_assignable<T>;
 
-template<typename From, typename To>
-concept convertible_to = __is_convertible(From, To);
-
-template<typename Derived, class Base>
-concept derived_from = __is_class(Derived) && __is_class(Base) && convertible_to<const volatile Derived*, const volatile Base*>;
-
 using nullptr_t = decltype(nullptr);
 
 template<typename T> struct _un_const_helper          { using type = T; };
@@ -154,6 +148,15 @@ template<typename T> struct _is_ptr_helper<T*> : true_type {};
 
 template<typename T> concept is_ptr = _is_ptr_helper<un_cv_t<T>>::value;
 
+template<typename From, typename To>
+concept convertible_to = __is_convertible(From, To);
+
+template<typename Derived, typename Base>
+concept derived_from = __is_class(Derived) && __is_class(Base) && convertible_to<const volatile Derived*, const volatile Base*>;
+
+template<typename Derived, typename Base>
+concept same_or_derived_from = same_as<un_cvref_t<Derived>, Base> || derived_from<un_cvref_t<Derived>, Base>;
+
 template<typename T> struct _tame_helper { using type = T; };
 
 #define TAME_HELPER_IMPL(TRAILING)                                  \
diff --git a/asl/base/meta_tests.cpp b/asl/base/meta_tests.cpp
index 65f367a..c221dcf 100644
--- a/asl/base/meta_tests.cpp
+++ b/asl/base/meta_tests.cpp
@@ -168,6 +168,8 @@ static_assert(!asl::is_member_func_ptr<void()>);
 static_assert(!asl::is_member_func_ptr<void() const &&>);
 static_assert(!asl::is_member_func_ptr<int MyClass::*>);
 static_assert(asl::is_member_func_ptr<int (MyClass::*)(int)>);
+static_assert(asl::is_member_func_ptr<int (MyClass::*)(int) const>);
+static_assert(asl::is_member_func_ptr<int (MyClass::*)(int) volatile &&>);
 
 static_assert(asl::same_as<int, asl::tame_t<int>>);
 static_assert(asl::same_as<int(), asl::tame_t<int()>>);
@@ -250,6 +252,8 @@ static_assert(!asl::convertible_to<const int16_t(&)[], int16_t(&)[]>);
 static_assert(!asl::convertible_to<D(&)[], C(&)[]>);
 
 static_assert(asl::derived_from<Derived, Base>);
+static_assert(asl::derived_from<Derived, Derived>);
+static_assert(asl::derived_from<Base, Base>);
 static_assert(!asl::derived_from<Base, Derived>);
 static_assert(!asl::derived_from<D, C>);
 static_assert(!asl::derived_from<C, D>);
@@ -257,6 +261,16 @@ static_assert(!asl::derived_from<uint8_t, uint16_t>);
 static_assert(!asl::derived_from<uint16_t, uint8_t>);
 static_assert(!asl::derived_from<int, int>);
 
+static_assert(asl::same_or_derived_from<Derived, Base>);
+static_assert(asl::same_or_derived_from<Derived, Derived>);
+static_assert(asl::same_or_derived_from<Base, Base>);
+static_assert(!asl::same_or_derived_from<Base, Derived>);
+static_assert(!asl::same_or_derived_from<D, C>);
+static_assert(!asl::same_or_derived_from<C, D>);
+static_assert(!asl::same_or_derived_from<uint8_t, uint16_t>);
+static_assert(!asl::same_or_derived_from<uint16_t, uint8_t>);
+static_assert(asl::same_or_derived_from<int, int>);
+
 static_assert(!asl::is_const<int>);
 static_assert(asl::is_const<const int>);
 static_assert(!asl::is_const<const int*>);
-- 
cgit