diff options
author | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-03-20 00:23:28 +0100 |
---|---|---|
committer | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-03-21 23:46:39 +0100 |
commit | c692909ff332de6f2e32db844458ccd03a080e53 (patch) | |
tree | e1d8f6997659052d3075868e27ec2b359e69dd90 | |
parent | a665c590d5089bb4bcb72193542b60ef571409a3 (diff) |
Add function
-rw-r--r-- | asl/memory/allocator.hpp | 6 | ||||
-rw-r--r-- | asl/types/BUILD.bazel | 22 | ||||
-rw-r--r-- | asl/types/function.hpp | 284 | ||||
-rw-r--r-- | asl/types/function_tests.cpp | 179 |
4 files changed, 489 insertions, 2 deletions
diff --git a/asl/memory/allocator.hpp b/asl/memory/allocator.hpp index bb6b992..a231558 100644 --- a/asl/memory/allocator.hpp +++ b/asl/memory/allocator.hpp @@ -50,13 +50,15 @@ void alloc_delete(allocator auto& a, T* ptr) template<typename T> constexpr T* alloc_new_default(auto&&... args) { - return alloc_new<T>(DefaultAllocator{}, std::forward<decltype(args)>(args)...); + DefaultAllocator allocator{}; + return alloc_new<T>(allocator, std::forward<decltype(args)>(args)...); } template<typename T> void alloc_delete_default(T* ptr) { - alloc_delete(DefaultAllocator{}, ptr); + DefaultAllocator allocator{}; + alloc_delete(allocator, ptr); } } // namespace asl diff --git a/asl/types/BUILD.bazel b/asl/types/BUILD.bazel index b60bb27..58d7183 100644 --- a/asl/types/BUILD.bazel +++ b/asl/types/BUILD.bazel @@ -78,6 +78,28 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "function", + hdrs = [ + "function.hpp", + ], + deps = [ + "//asl/base", + "//asl/memory", + ], + visibility = ["//visibility:public"], +) + +cc_test( + name = "function_tests", + srcs = ["function_tests.cpp"], + deps = [ + "//asl/tests:utils", + "//asl/testing", + "//asl/types:function", + ], +) + cc_test( name = "box_tests", srcs = ["box_tests.cpp"], diff --git a/asl/types/function.hpp b/asl/types/function.hpp new file mode 100644 index 0000000..6460711 --- /dev/null +++ b/asl/types/function.hpp @@ -0,0 +1,284 @@ +#pragma once + +#include "asl/base/utility.hpp" +#include "asl/base/meta.hpp" +#include "asl/base/integers.hpp" +#include "asl/memory/allocator.hpp" +#include "asl/base/functional.hpp" + +namespace asl +{ + +namespace function_detail +{ + +static constexpr isize_t kStorageSize = size_of<void*> * 2; + +struct Storage +{ + alignas(align_of<void*>) byte raw[kStorageSize]; + + [[nodiscard]] + void* get_ptr() const + { + // NOLINTNEXTLINE(*-const-cast) + return const_cast<void*>(static_cast<const void*>(raw)); + } +}; + +template<typename T> +concept can_be_stored_inline = + size_of<T> <= size_of<Storage> && + align_of<Storage> % align_of<T> == 0; + +enum class FunctionOp : uint8_t +{ + kDestroyThis, + kCopyFromOtherToThisUninit, + kMoveFromOtherToThisUninit, +}; + +template<typename Functor, bool kStoreInline = can_be_stored_inline<Functor>> +struct FunctionImplBase +{ + using Allocator = DefaultAllocator; + + template<typename T> + static void create(Storage* storage, T&& t) + { + Allocator allocator{}; + auto* ptr = alloc_new<Functor>(allocator, std::forward<T>(t)); + asl::memcpy(storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>); + } + + static Functor** get_functor_ptr(const Storage* storage) + { + // NOLINTNEXTLINE(*-reinterpret-cast) + return std::launder(reinterpret_cast<Functor**>(storage->get_ptr())); + } + + static Functor* get_functor(const Storage* storage) + { + return *get_functor_ptr(storage); + } + + static void op(Storage* this_storage, Storage* other_storage, FunctionOp op) + { + switch (op) + { + using enum FunctionOp; + case kDestroyThis: + { + Allocator allocator{}; + alloc_delete(allocator, get_functor(this_storage)); + break; + } + case kCopyFromOtherToThisUninit: + { + create(this_storage, *static_cast<const Functor*>(get_functor(other_storage))); + break; + } + case kMoveFromOtherToThisUninit: + { + auto* ptr = asl::exchange(*get_functor_ptr(other_storage), nullptr); + asl::memcpy(this_storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>); + break; + } + default: break; + } + } +}; + +template<typename Functor> +struct FunctionImplBase<Functor, true> +{ + template<typename T> + static void create(Storage* storage, T&& t) + { + new (storage->get_ptr()) Functor(std::forward<T>(t)); + } + + static Functor* get_functor(const Storage* storage) + { + // NOLINTNEXTLINE(*-reinterpret-cast) + return std::launder(reinterpret_cast<Functor*>(storage->get_ptr())); + } + + static void op(Storage* this_storage, Storage* other_storage, FunctionOp op) + { + switch (op) + { + using enum FunctionOp; + case kDestroyThis: + { + destroy(get_functor(this_storage)); + break; + } + case kCopyFromOtherToThisUninit: + { + create(this_storage, *static_cast<const Functor*>(get_functor(other_storage))); + break; + } + case kMoveFromOtherToThisUninit: + { + auto* other_functor = get_functor(other_storage); + create(this_storage, std::move(*static_cast<const Functor*>(other_functor))); + destroy(other_functor); + break; + } + default: break; + } + } +}; + +template<typename Functor, typename R, typename... Args> +struct FunctionImpl : FunctionImplBase<Functor> +{ + static R invoke(Args... args, const Storage& storage) + { + auto* functor = FunctionImplBase<Functor>::get_functor(&storage); + return asl::invoke(*functor, std::forward<Args>(args)...); + } +}; + + +template<typename T, typename R, typename... Args> +concept valid_functor = + copy_constructible<T> + && move_constructible<T> + && invocable<T, Args...> + && same_as<R, invoke_result_t<T, Args...>>; + +} // namespace function_detail + +template<typename T> +class function; + +template<typename R, typename... Args> +class function<R(Args...)> // NOLINT(*-member-init) +{ + using InvokeFn = R (*)(Args..., const function_detail::Storage&); + using OpFn = void (*)(function_detail::Storage*, function_detail::Storage*, function_detail::FunctionOp); + + function_detail::Storage m_storage; + InvokeFn m_invoke{}; + OpFn m_op{}; + + void destroy() + { + if (m_op != nullptr) + { + (*m_op)(&m_storage, nullptr, function_detail::FunctionOp::kDestroyThis); + } + } + +public: + function() = default; + + template<typename T> + function(T&& func) // NOLINT(*explicit*,*-member-init) + requires ( + !same_as<function, un_cvref_t<T>> + && function_detail::valid_functor<T, R, Args...> + ) + { + using Functor = decay_t<T>; + using Impl = function_detail::FunctionImpl<Functor, R, Args...>; + + Impl::create(&m_storage, std::forward<T>(func)); + m_invoke = &Impl::invoke; // NOLINT(*-member-initializer) + m_op = &Impl::op; // NOLINT(*-member-initializer) + } + + function(const function& other) // NOLINT(*-member-init) + : m_invoke{other.m_invoke} + , m_op{other.m_op} + { + if (m_op != nullptr) + { + (*m_op)( + &m_storage, + const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) + function_detail::FunctionOp::kCopyFromOtherToThisUninit); + } + } + + function(function&& other) // NOLINT(*-member-init) + : m_invoke{asl::exchange(other.m_invoke, nullptr)} + , m_op{asl::exchange(other.m_op, nullptr)} + { + if (m_op != nullptr) + { + (*m_op)( + &m_storage, + &other.m_storage, + function_detail::FunctionOp::kMoveFromOtherToThisUninit); + } + } + + ~function() + { + destroy(); + } + + function& operator=(const function& other) + { + if (this != &other) + { + destroy(); + + m_invoke = other.m_invoke; + m_op = other.m_op; + + (*m_op)( + &m_storage, + const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) + function_detail::FunctionOp::kCopyFromOtherToThisUninit); + } + return *this; + } + + function& operator=(function&& other) + { + if (this != &other) + { + destroy(); + + m_invoke = asl::exchange(other.m_invoke, nullptr); + m_op = asl::exchange(other.m_op, nullptr); + + (*m_op)( + &m_storage, + &other.m_storage, + function_detail::FunctionOp::kMoveFromOtherToThisUninit); + } + return *this; + } + + template<typename T> + function& operator=(T&& func) + requires ( + !same_as<function, un_cvref_t<T>> + && function_detail::valid_functor<T, R, Args...> + ) + { + destroy(); + + using Functor = decay_t<T>; + using Impl = function_detail::FunctionImpl<Functor, R, Args...>; + + Impl::create(&m_storage, std::forward<T>(func)); + m_invoke = &Impl::invoke; + m_op = &Impl::op; + + return *this; + } + + constexpr R operator()(Args... args) const + { + ASL_ASSERT(m_invoke); + return (*m_invoke)(args..., m_storage); + } +}; + +} // namespace asl diff --git a/asl/types/function_tests.cpp b/asl/types/function_tests.cpp new file mode 100644 index 0000000..5a55885 --- /dev/null +++ b/asl/types/function_tests.cpp @@ -0,0 +1,179 @@ +#include "asl/testing/testing.hpp" +#include "asl/types/function.hpp" + +static_assert(asl::function_detail::can_be_stored_inline<int(*)(int, int, int)>); +static_assert(asl::function_detail::can_be_stored_inline<decltype([](){})>); +static_assert(asl::function_detail::can_be_stored_inline<decltype([]() static {})>); +static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL](){ return a + b; })>); // NOLINT +static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL]() mutable { return a = b++; })>); // NOLINT +static_assert(!asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL, c = 3ULL](){ return a + b + c; })>); // NOLINT + +static int add(int x, int y) +{ + return x + y; +} + +ASL_TEST(function_pointer) +{ + const asl::function<int(int, int)> fn = add; + ASL_TEST_EXPECT(fn(4, 5) == 9); + ASL_TEST_EXPECT(fn(40, 50) == 90); +} + +ASL_TEST(lambda) +{ + const asl::function<int(int, int)> fn = [](int a, int b) { + return a + b; + }; + ASL_TEST_EXPECT(fn(4, 5) == 9); +} + +ASL_TEST(lambda_static) +{ + const asl::function<int(int, int)> fn = [](int a, int b) static { + return a + b; + }; + ASL_TEST_EXPECT(fn(4, 5) == 9); +} + +ASL_TEST(lambda_static_state) +{ + const asl::function<int(int)> fn = [state = 0](int b) mutable { + state += b; + return state; + }; + + ASL_TEST_EXPECT(fn(1) == 1); + ASL_TEST_EXPECT(fn(2) == 3); + ASL_TEST_EXPECT(fn(3) == 6); + ASL_TEST_EXPECT(fn(4) == 10); +} + +ASL_TEST(lambda_state) +{ + int state = 0; + const asl::function<void(int)> fn = [&state](int x) { + state += x; + }; + + ASL_TEST_EXPECT(state == 0); + + fn(5); + ASL_TEST_EXPECT(state == 5); + + fn(4); + ASL_TEST_EXPECT(state == 9); +} + +ASL_TEST(lambda_big_state) +{ + int s0 = 0; + int s1 = 0; + int s2 = 0; + int s3 = 0; + + const asl::function<void(int)> fn = [&](int x) { + s0 += x; + s1 += x + 1; + s2 += x + 2; + s3 += x + 3; + }; + + ASL_TEST_EXPECT(s0 == 0); + ASL_TEST_EXPECT(s1 == 0); + ASL_TEST_EXPECT(s2 == 0); + ASL_TEST_EXPECT(s3 == 0); + + fn(5); + ASL_TEST_EXPECT(s0 == 5); + ASL_TEST_EXPECT(s1 == 6); + ASL_TEST_EXPECT(s2 == 7); + ASL_TEST_EXPECT(s3 == 8); + + fn(4); + ASL_TEST_EXPECT(s0 == 9); + ASL_TEST_EXPECT(s1 == 11); + ASL_TEST_EXPECT(s2 == 13); + ASL_TEST_EXPECT(s3 == 15); +} + +struct Functor +{ + int state{}; + + int operator()(int x) + { + state += x; + return state; + } +}; + +ASL_TEST(functor) +{ + const asl::function<int(int)> fn = Functor{}; + + ASL_TEST_EXPECT(fn(1) == 1); + ASL_TEST_EXPECT(fn(2) == 3); + ASL_TEST_EXPECT(fn(3) == 6); + ASL_TEST_EXPECT(fn(4) == 10); +} + +ASL_TEST(copy_move_construct_small) +{ + asl::function<int(int, int)> fn = [x = 0](int a, int b) mutable { return x++ + a + b; }; + ASL_TEST_EXPECT(fn(1, 3) == 4); + + asl::function<int(int, int)> fn2 = fn; + ASL_TEST_EXPECT(fn(1, 3) == 5); + ASL_TEST_EXPECT(fn2(5, 3) == 9); + + asl::function<int(int, int)> fn3 = std::move(fn2); + ASL_TEST_EXPECT(fn(1, 3) == 6); + ASL_TEST_EXPECT(fn3(5, 3) == 10); + + fn2 = fn; + ASL_TEST_EXPECT(fn(1, 3) == 7); + ASL_TEST_EXPECT(fn2(5, 3) == 11); + ASL_TEST_EXPECT(fn3(5, 3) == 11); + + fn3 = std::move(fn); + ASL_TEST_EXPECT(fn2(5, 3) == 12); + ASL_TEST_EXPECT(fn3(5, 3) == 12); +} + +ASL_TEST(copy_move_construct_big) +{ + const int64_t v1 = 1; + const int64_t v2 = 2; + const int64_t v3 = 3; + const int64_t v4 = 4; + + asl::function<int64_t(int)> fn = [=](int x) { return v1 + v2 + v3 + v4 + x; }; + ASL_TEST_EXPECT(fn(1) == 11); + + asl::function<int64_t(int)> fn2 = fn; + ASL_TEST_EXPECT(fn(3) == 13); + ASL_TEST_EXPECT(fn2(5) == 15); + + asl::function<int64_t(int)> fn3 = std::move(fn2); + ASL_TEST_EXPECT(fn(1) == 11); + ASL_TEST_EXPECT(fn3(3) == 13); + + fn2 = fn; + ASL_TEST_EXPECT(fn(1) == 11); + ASL_TEST_EXPECT(fn2(5) == 15); + ASL_TEST_EXPECT(fn3(3) == 13); + + fn3 = std::move(fn); + ASL_TEST_EXPECT(fn2(5) == 15); + ASL_TEST_EXPECT(fn3(3) == 13); +} + +ASL_TEST(replace) +{ + asl::function<int(int)> fn = [](int x) { return x + 1; }; + ASL_TEST_EXPECT(fn(5) == 6); + + fn = [](int x) { return x + 3; }; + ASL_TEST_EXPECT(fn(5) == 8); +} |