summaryrefslogtreecommitdiff
path: root/asl/types/function.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'asl/types/function.hpp')
-rw-r--r--asl/types/function.hpp284
1 files changed, 284 insertions, 0 deletions
diff --git a/asl/types/function.hpp b/asl/types/function.hpp
new file mode 100644
index 0000000..6460711
--- /dev/null
+++ b/asl/types/function.hpp
@@ -0,0 +1,284 @@
+#pragma once
+
+#include "asl/base/utility.hpp"
+#include "asl/base/meta.hpp"
+#include "asl/base/integers.hpp"
+#include "asl/memory/allocator.hpp"
+#include "asl/base/functional.hpp"
+
+namespace asl
+{
+
+namespace function_detail
+{
+
+static constexpr isize_t kStorageSize = size_of<void*> * 2;
+
+struct Storage
+{
+ alignas(align_of<void*>) byte raw[kStorageSize];
+
+ [[nodiscard]]
+ void* get_ptr() const
+ {
+ // NOLINTNEXTLINE(*-const-cast)
+ return const_cast<void*>(static_cast<const void*>(raw));
+ }
+};
+
+template<typename T>
+concept can_be_stored_inline =
+ size_of<T> <= size_of<Storage> &&
+ align_of<Storage> % align_of<T> == 0;
+
+enum class FunctionOp : uint8_t
+{
+ kDestroyThis,
+ kCopyFromOtherToThisUninit,
+ kMoveFromOtherToThisUninit,
+};
+
+template<typename Functor, bool kStoreInline = can_be_stored_inline<Functor>>
+struct FunctionImplBase
+{
+ using Allocator = DefaultAllocator;
+
+ template<typename T>
+ static void create(Storage* storage, T&& t)
+ {
+ Allocator allocator{};
+ auto* ptr = alloc_new<Functor>(allocator, std::forward<T>(t));
+ asl::memcpy(storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
+ }
+
+ static Functor** get_functor_ptr(const Storage* storage)
+ {
+ // NOLINTNEXTLINE(*-reinterpret-cast)
+ return std::launder(reinterpret_cast<Functor**>(storage->get_ptr()));
+ }
+
+ static Functor* get_functor(const Storage* storage)
+ {
+ return *get_functor_ptr(storage);
+ }
+
+ static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
+ {
+ switch (op)
+ {
+ using enum FunctionOp;
+ case kDestroyThis:
+ {
+ Allocator allocator{};
+ alloc_delete(allocator, get_functor(this_storage));
+ break;
+ }
+ case kCopyFromOtherToThisUninit:
+ {
+ create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
+ break;
+ }
+ case kMoveFromOtherToThisUninit:
+ {
+ auto* ptr = asl::exchange(*get_functor_ptr(other_storage), nullptr);
+ asl::memcpy(this_storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
+ break;
+ }
+ default: break;
+ }
+ }
+};
+
+template<typename Functor>
+struct FunctionImplBase<Functor, true>
+{
+ template<typename T>
+ static void create(Storage* storage, T&& t)
+ {
+ new (storage->get_ptr()) Functor(std::forward<T>(t));
+ }
+
+ static Functor* get_functor(const Storage* storage)
+ {
+ // NOLINTNEXTLINE(*-reinterpret-cast)
+ return std::launder(reinterpret_cast<Functor*>(storage->get_ptr()));
+ }
+
+ static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
+ {
+ switch (op)
+ {
+ using enum FunctionOp;
+ case kDestroyThis:
+ {
+ destroy(get_functor(this_storage));
+ break;
+ }
+ case kCopyFromOtherToThisUninit:
+ {
+ create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
+ break;
+ }
+ case kMoveFromOtherToThisUninit:
+ {
+ auto* other_functor = get_functor(other_storage);
+ create(this_storage, std::move(*static_cast<const Functor*>(other_functor)));
+ destroy(other_functor);
+ break;
+ }
+ default: break;
+ }
+ }
+};
+
+template<typename Functor, typename R, typename... Args>
+struct FunctionImpl : FunctionImplBase<Functor>
+{
+ static R invoke(Args... args, const Storage& storage)
+ {
+ auto* functor = FunctionImplBase<Functor>::get_functor(&storage);
+ return asl::invoke(*functor, std::forward<Args>(args)...);
+ }
+};
+
+
+template<typename T, typename R, typename... Args>
+concept valid_functor =
+ copy_constructible<T>
+ && move_constructible<T>
+ && invocable<T, Args...>
+ && same_as<R, invoke_result_t<T, Args...>>;
+
+} // namespace function_detail
+
+template<typename T>
+class function;
+
+template<typename R, typename... Args>
+class function<R(Args...)> // NOLINT(*-member-init)
+{
+ using InvokeFn = R (*)(Args..., const function_detail::Storage&);
+ using OpFn = void (*)(function_detail::Storage*, function_detail::Storage*, function_detail::FunctionOp);
+
+ function_detail::Storage m_storage;
+ InvokeFn m_invoke{};
+ OpFn m_op{};
+
+ void destroy()
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(&m_storage, nullptr, function_detail::FunctionOp::kDestroyThis);
+ }
+ }
+
+public:
+ function() = default;
+
+ template<typename T>
+ function(T&& func) // NOLINT(*explicit*,*-member-init)
+ requires (
+ !same_as<function, un_cvref_t<T>>
+ && function_detail::valid_functor<T, R, Args...>
+ )
+ {
+ using Functor = decay_t<T>;
+ using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
+
+ Impl::create(&m_storage, std::forward<T>(func));
+ m_invoke = &Impl::invoke; // NOLINT(*-member-initializer)
+ m_op = &Impl::op; // NOLINT(*-member-initializer)
+ }
+
+ function(const function& other) // NOLINT(*-member-init)
+ : m_invoke{other.m_invoke}
+ , m_op{other.m_op}
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(
+ &m_storage,
+ const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
+ function_detail::FunctionOp::kCopyFromOtherToThisUninit);
+ }
+ }
+
+ function(function&& other) // NOLINT(*-member-init)
+ : m_invoke{asl::exchange(other.m_invoke, nullptr)}
+ , m_op{asl::exchange(other.m_op, nullptr)}
+ {
+ if (m_op != nullptr)
+ {
+ (*m_op)(
+ &m_storage,
+ &other.m_storage,
+ function_detail::FunctionOp::kMoveFromOtherToThisUninit);
+ }
+ }
+
+ ~function()
+ {
+ destroy();
+ }
+
+ function& operator=(const function& other)
+ {
+ if (this != &other)
+ {
+ destroy();
+
+ m_invoke = other.m_invoke;
+ m_op = other.m_op;
+
+ (*m_op)(
+ &m_storage,
+ const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
+ function_detail::FunctionOp::kCopyFromOtherToThisUninit);
+ }
+ return *this;
+ }
+
+ function& operator=(function&& other)
+ {
+ if (this != &other)
+ {
+ destroy();
+
+ m_invoke = asl::exchange(other.m_invoke, nullptr);
+ m_op = asl::exchange(other.m_op, nullptr);
+
+ (*m_op)(
+ &m_storage,
+ &other.m_storage,
+ function_detail::FunctionOp::kMoveFromOtherToThisUninit);
+ }
+ return *this;
+ }
+
+ template<typename T>
+ function& operator=(T&& func)
+ requires (
+ !same_as<function, un_cvref_t<T>>
+ && function_detail::valid_functor<T, R, Args...>
+ )
+ {
+ destroy();
+
+ using Functor = decay_t<T>;
+ using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
+
+ Impl::create(&m_storage, std::forward<T>(func));
+ m_invoke = &Impl::invoke;
+ m_op = &Impl::op;
+
+ return *this;
+ }
+
+ constexpr R operator()(Args... args) const
+ {
+ ASL_ASSERT(m_invoke);
+ return (*m_invoke)(args..., m_storage);
+ }
+};
+
+} // namespace asl