summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteven Le Rouzic <steven.lerouzic@gmail.com>2025-03-20 00:23:28 +0100
committerSteven Le Rouzic <steven.lerouzic@gmail.com>2025-03-21 23:46:39 +0100
commitc692909ff332de6f2e32db844458ccd03a080e53 (patch)
treee1d8f6997659052d3075868e27ec2b359e69dd90
parenta665c590d5089bb4bcb72193542b60ef571409a3 (diff)
Add function
-rw-r--r--asl/memory/allocator.hpp6
-rw-r--r--asl/types/BUILD.bazel22
-rw-r--r--asl/types/function.hpp284
-rw-r--r--asl/types/function_tests.cpp179
4 files changed, 489 insertions, 2 deletions
diff --git a/asl/memory/allocator.hpp b/asl/memory/allocator.hpp
index bb6b992..a231558 100644
--- a/asl/memory/allocator.hpp
+++ b/asl/memory/allocator.hpp
@@ -50,13 +50,15 @@ void alloc_delete(allocator auto& a, T* ptr)
template<typename T>
constexpr T* alloc_new_default(auto&&... args)
{
- return alloc_new<T>(DefaultAllocator{}, std::forward<decltype(args)>(args)...);
+ DefaultAllocator allocator{};
+ return alloc_new<T>(allocator, std::forward<decltype(args)>(args)...);
}
template<typename T>
void alloc_delete_default(T* ptr)
{
- alloc_delete(DefaultAllocator{}, ptr);
+ DefaultAllocator allocator{};
+ alloc_delete(allocator, ptr);
}
} // namespace asl
diff --git a/asl/types/BUILD.bazel b/asl/types/BUILD.bazel
index b60bb27..58d7183 100644
--- a/asl/types/BUILD.bazel
+++ b/asl/types/BUILD.bazel
@@ -78,6 +78,28 @@ cc_library(
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "function",
+ hdrs = [
+ "function.hpp",
+ ],
+ deps = [
+ "//asl/base",
+ "//asl/memory",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_test(
+ name = "function_tests",
+ srcs = ["function_tests.cpp"],
+ deps = [
+ "//asl/tests:utils",
+ "//asl/testing",
+ "//asl/types:function",
+ ],
+)
+
cc_test(
name = "box_tests",
srcs = ["box_tests.cpp"],
diff --git a/asl/types/function.hpp b/asl/types/function.hpp
new file mode 100644
index 0000000..6460711
--- /dev/null
+++ b/asl/types/function.hpp
@@ -0,0 +1,284 @@
+#pragma once
+
+#include "asl/base/utility.hpp"
+#include "asl/base/meta.hpp"
+#include "asl/base/integers.hpp"
+#include "asl/memory/allocator.hpp"
+#include "asl/base/functional.hpp"
+
+namespace asl
+{
+
+namespace function_detail
+{
+
+static constexpr isize_t kStorageSize = size_of<void*> * 2;
+
+struct Storage
+{
+ alignas(align_of<void*>) byte raw[kStorageSize];
+
+ [[nodiscard]]
+ void* get_ptr() const
+ {
+ // NOLINTNEXTLINE(*-const-cast)
+ return const_cast<void*>(static_cast<const void*>(raw));
+ }
+};
+
+template<typename T>
+concept can_be_stored_inline =
+ size_of<T> <= size_of<Storage> &&
+ align_of<Storage> % align_of<T> == 0;
+
+enum class FunctionOp : uint8_t
+{
+ kDestroyThis,
+ kCopyFromOtherToThisUninit,
+ kMoveFromOtherToThisUninit,
+};
+
+template<typename Functor, bool kStoreInline = can_be_stored_inline<Functor>>
+struct FunctionImplBase
+{
+ using Allocator = DefaultAllocator;
+
+ template<typename T>
+ static void create(Storage* storage, T&& t)
+ {
+ Allocator allocator{};
+ auto* ptr = alloc_new<Functor>(allocator, std::forward<T>(t));
+ asl::memcpy(storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
+ }
+
+ static Functor** get_functor_ptr(const Storage* storage)
+ {
+ // NOLINTNEXTLINE(*-reinterpret-cast)
+ return std::launder(reinterpret_cast<Functor**>(storage->get_ptr()));
+ }
+
+ static Functor* get_functor(const Storage* storage)
+ {
+ return *get_functor_ptr(storage);
+ }
+
+ static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
+ {
+ switch (op)
+ {
+ using enum FunctionOp;
+ case kDestroyThis:
+ {
+ Allocator allocator{};
+ alloc_delete(allocator, get_functor(this_storage));
+ break;
+ }
+ case kCopyFromOtherToThisUninit:
+ {
+ create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
+ break;
+ }
+ case kMoveFromOtherToThisUninit:
+ {
+ auto* ptr = asl::exchange(*get_functor_ptr(other_storage), nullptr);
+ asl::memcpy(this_storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
+ break;
+ }
+ default: break;
+ }
+ }
+};
+
+template<typename Functor>
+struct FunctionImplBase<Functor, true>
+{
+ template<typename T>
+ static void create(Storage* storage, T&& t)
+ {
+ new (storage->get_ptr()) Functor(std::forward<T>(t));
+ }
+
+ static Functor* get_functor(const Storage* storage)
+ {
+ // NOLINTNEXTLINE(*-reinterpret-cast)
+ return std::launder(reinterpret_cast<Functor*>(storage->get_ptr()));
+ }
+
+ static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
+ {
+ switch (op)
+ {
+ using enum FunctionOp;
+ case kDestroyThis:
+ {
+ destroy(get_functor(this_storage));
+ break;
+ }
+ case kCopyFromOtherToThisUninit:
+ {
+ create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
+ break;
+ }
+ case kMoveFromOtherToThisUninit:
+ {
+ auto* other_functor = get_functor(other_storage);
+ create(this_storage, std::move(*static_cast<const Functor*>(other_functor)));
+ destroy(other_functor);
+ break;
+ }
+ default: break;
+ }
+ }
+};
+
+template<typename Functor, typename R, typename... Args>
+struct FunctionImpl : FunctionImplBase<Functor>
+{
+ static R invoke(Args... args, const Storage& storage)
+ {
+ auto* functor = FunctionImplBase<Functor>::get_functor(&storage);
+ return asl::invoke(*functor, std::forward<Args>(args)...);
+ }
+};
+
+
+template<typename T, typename R, typename... Args>
+concept valid_functor =
+ copy_constructible<T>
+ && move_constructible<T>
+ && invocable<T, Args...>
+ && same_as<R, invoke_result_t<T, Args...>>;
+
+} // namespace function_detail
+
+template<typename T>
+class function;
+
+template<typename R, typename... Args>
+class function<R(Args...)> // NOLINT(*-member-init)
+{
+ using InvokeFn = R (*)(Args..., const function_detail::Storage&);
+ using OpFn = void (*)(function_detail::Storage*, function_detail::Storage*, function_detail::FunctionOp);
+
+ function_detail::Storage m_storage;
+ InvokeFn m_invoke{};
+ OpFn m_op{};
+
+ void destroy()
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(&m_storage, nullptr, function_detail::FunctionOp::kDestroyThis);
+ }
+ }
+
+public:
+ function() = default;
+
+ template<typename T>
+ function(T&& func) // NOLINT(*explicit*,*-member-init)
+ requires (
+ !same_as<function, un_cvref_t<T>>
+ && function_detail::valid_functor<T, R, Args...>
+ )
+ {
+ using Functor = decay_t<T>;
+ using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
+
+ Impl::create(&m_storage, std::forward<T>(func));
+ m_invoke = &Impl::invoke; // NOLINT(*-member-initializer)
+ m_op = &Impl::op; // NOLINT(*-member-initializer)
+ }
+
+ function(const function& other) // NOLINT(*-member-init)
+ : m_invoke{other.m_invoke}
+ , m_op{other.m_op}
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(
+ &m_storage,
+ const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
+ function_detail::FunctionOp::kCopyFromOtherToThisUninit);
+ }
+ }
+
+ function(function&& other) // NOLINT(*-member-init)
+ : m_invoke{asl::exchange(other.m_invoke, nullptr)}
+ , m_op{asl::exchange(other.m_op, nullptr)}
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(
+ &m_storage,
+ &other.m_storage,
+ function_detail::FunctionOp::kMoveFromOtherToThisUninit);
+ }
+ }
+
+ ~function()
+ {
+ destroy();
+ }
+
+ function& operator=(const function& other)
+ {
+ if (this != &other)
+ {
+ destroy();
+
+ m_invoke = other.m_invoke;
+ m_op = other.m_op;
+
+ (*m_op)(
+ &m_storage,
+ const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
+ function_detail::FunctionOp::kCopyFromOtherToThisUninit);
+ }
+ return *this;
+ }
+
+ function& operator=(function&& other)
+ {
+ if (this != &other)
+ {
+ destroy();
+
+ m_invoke = asl::exchange(other.m_invoke, nullptr);
+ m_op = asl::exchange(other.m_op, nullptr);
+
+ (*m_op)(
+ &m_storage,
+ &other.m_storage,
+ function_detail::FunctionOp::kMoveFromOtherToThisUninit);
+ }
+ return *this;
+ }
+
+ template<typename T>
+ function& operator=(T&& func)
+ requires (
+ !same_as<function, un_cvref_t<T>>
+ && function_detail::valid_functor<T, R, Args...>
+ )
+ {
+ destroy();
+
+ using Functor = decay_t<T>;
+ using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
+
+ Impl::create(&m_storage, std::forward<T>(func));
+ m_invoke = &Impl::invoke;
+ m_op = &Impl::op;
+
+ return *this;
+ }
+
+ constexpr R operator()(Args... args) const
+ {
+ ASL_ASSERT(m_invoke);
+ return (*m_invoke)(args..., m_storage);
+ }
+};
+
+} // namespace asl
diff --git a/asl/types/function_tests.cpp b/asl/types/function_tests.cpp
new file mode 100644
index 0000000..5a55885
--- /dev/null
+++ b/asl/types/function_tests.cpp
@@ -0,0 +1,179 @@
+#include "asl/testing/testing.hpp"
+#include "asl/types/function.hpp"
+
+static_assert(asl::function_detail::can_be_stored_inline<int(*)(int, int, int)>);
+static_assert(asl::function_detail::can_be_stored_inline<decltype([](){})>);
+static_assert(asl::function_detail::can_be_stored_inline<decltype([]() static {})>);
+static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL](){ return a + b; })>); // NOLINT
+static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL]() mutable { return a = b++; })>); // NOLINT
+static_assert(!asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL, c = 3ULL](){ return a + b + c; })>); // NOLINT
+
+static int add(int x, int y)
+{
+ return x + y;
+}
+
+ASL_TEST(function_pointer)
+{
+ const asl::function<int(int, int)> fn = add;
+ ASL_TEST_EXPECT(fn(4, 5) == 9);
+ ASL_TEST_EXPECT(fn(40, 50) == 90);
+}
+
+ASL_TEST(lambda)
+{
+ const asl::function<int(int, int)> fn = [](int a, int b) {
+ return a + b;
+ };
+ ASL_TEST_EXPECT(fn(4, 5) == 9);
+}
+
+ASL_TEST(lambda_static)
+{
+ const asl::function<int(int, int)> fn = [](int a, int b) static {
+ return a + b;
+ };
+ ASL_TEST_EXPECT(fn(4, 5) == 9);
+}
+
+ASL_TEST(lambda_static_state)
+{
+ const asl::function<int(int)> fn = [state = 0](int b) mutable {
+ state += b;
+ return state;
+ };
+
+ ASL_TEST_EXPECT(fn(1) == 1);
+ ASL_TEST_EXPECT(fn(2) == 3);
+ ASL_TEST_EXPECT(fn(3) == 6);
+ ASL_TEST_EXPECT(fn(4) == 10);
+}
+
+ASL_TEST(lambda_state)
+{
+ int state = 0;
+ const asl::function<void(int)> fn = [&state](int x) {
+ state += x;
+ };
+
+ ASL_TEST_EXPECT(state == 0);
+
+ fn(5);
+ ASL_TEST_EXPECT(state == 5);
+
+ fn(4);
+ ASL_TEST_EXPECT(state == 9);
+}
+
+ASL_TEST(lambda_big_state)
+{
+ int s0 = 0;
+ int s1 = 0;
+ int s2 = 0;
+ int s3 = 0;
+
+ const asl::function<void(int)> fn = [&](int x) {
+ s0 += x;
+ s1 += x + 1;
+ s2 += x + 2;
+ s3 += x + 3;
+ };
+
+ ASL_TEST_EXPECT(s0 == 0);
+ ASL_TEST_EXPECT(s1 == 0);
+ ASL_TEST_EXPECT(s2 == 0);
+ ASL_TEST_EXPECT(s3 == 0);
+
+ fn(5);
+ ASL_TEST_EXPECT(s0 == 5);
+ ASL_TEST_EXPECT(s1 == 6);
+ ASL_TEST_EXPECT(s2 == 7);
+ ASL_TEST_EXPECT(s3 == 8);
+
+ fn(4);
+ ASL_TEST_EXPECT(s0 == 9);
+ ASL_TEST_EXPECT(s1 == 11);
+ ASL_TEST_EXPECT(s2 == 13);
+ ASL_TEST_EXPECT(s3 == 15);
+}
+
+struct Functor
+{
+ int state{};
+
+ int operator()(int x)
+ {
+ state += x;
+ return state;
+ }
+};
+
+ASL_TEST(functor)
+{
+ const asl::function<int(int)> fn = Functor{};
+
+ ASL_TEST_EXPECT(fn(1) == 1);
+ ASL_TEST_EXPECT(fn(2) == 3);
+ ASL_TEST_EXPECT(fn(3) == 6);
+ ASL_TEST_EXPECT(fn(4) == 10);
+}
+
+ASL_TEST(copy_move_construct_small)
+{
+ asl::function<int(int, int)> fn = [x = 0](int a, int b) mutable { return x++ + a + b; };
+ ASL_TEST_EXPECT(fn(1, 3) == 4);
+
+ asl::function<int(int, int)> fn2 = fn;
+ ASL_TEST_EXPECT(fn(1, 3) == 5);
+ ASL_TEST_EXPECT(fn2(5, 3) == 9);
+
+ asl::function<int(int, int)> fn3 = std::move(fn2);
+ ASL_TEST_EXPECT(fn(1, 3) == 6);
+ ASL_TEST_EXPECT(fn3(5, 3) == 10);
+
+ fn2 = fn;
+ ASL_TEST_EXPECT(fn(1, 3) == 7);
+ ASL_TEST_EXPECT(fn2(5, 3) == 11);
+ ASL_TEST_EXPECT(fn3(5, 3) == 11);
+
+ fn3 = std::move(fn);
+ ASL_TEST_EXPECT(fn2(5, 3) == 12);
+ ASL_TEST_EXPECT(fn3(5, 3) == 12);
+}
+
+ASL_TEST(copy_move_construct_big)
+{
+ const int64_t v1 = 1;
+ const int64_t v2 = 2;
+ const int64_t v3 = 3;
+ const int64_t v4 = 4;
+
+ asl::function<int64_t(int)> fn = [=](int x) { return v1 + v2 + v3 + v4 + x; };
+ ASL_TEST_EXPECT(fn(1) == 11);
+
+ asl::function<int64_t(int)> fn2 = fn;
+ ASL_TEST_EXPECT(fn(3) == 13);
+ ASL_TEST_EXPECT(fn2(5) == 15);
+
+ asl::function<int64_t(int)> fn3 = std::move(fn2);
+ ASL_TEST_EXPECT(fn(1) == 11);
+ ASL_TEST_EXPECT(fn3(3) == 13);
+
+ fn2 = fn;
+ ASL_TEST_EXPECT(fn(1) == 11);
+ ASL_TEST_EXPECT(fn2(5) == 15);
+ ASL_TEST_EXPECT(fn3(3) == 13);
+
+ fn3 = std::move(fn);
+ ASL_TEST_EXPECT(fn2(5) == 15);
+ ASL_TEST_EXPECT(fn3(3) == 13);
+}
+
+ASL_TEST(replace)
+{
+ asl::function<int(int)> fn = [](int x) { return x + 1; };
+ ASL_TEST_EXPECT(fn(5) == 6);
+
+ fn = [](int x) { return x + 3; };
+ ASL_TEST_EXPECT(fn(5) == 8);
+}