diff options
Diffstat (limited to 'asl/types/function.hpp')
-rw-r--r-- | asl/types/function.hpp | 284 |
1 files changed, 284 insertions, 0 deletions
diff --git a/asl/types/function.hpp b/asl/types/function.hpp new file mode 100644 index 0000000..6460711 --- /dev/null +++ b/asl/types/function.hpp @@ -0,0 +1,284 @@ +#pragma once + +#include "asl/base/utility.hpp" +#include "asl/base/meta.hpp" +#include "asl/base/integers.hpp" +#include "asl/memory/allocator.hpp" +#include "asl/base/functional.hpp" + +namespace asl +{ + +namespace function_detail +{ + +static constexpr isize_t kStorageSize = size_of<void*> * 2; + +struct Storage +{ + alignas(align_of<void*>) byte raw[kStorageSize]; + + [[nodiscard]] + void* get_ptr() const + { + // NOLINTNEXTLINE(*-const-cast) + return const_cast<void*>(static_cast<const void*>(raw)); + } +}; + +template<typename T> +concept can_be_stored_inline = + size_of<T> <= size_of<Storage> && + align_of<Storage> % align_of<T> == 0; + +enum class FunctionOp : uint8_t +{ + kDestroyThis, + kCopyFromOtherToThisUninit, + kMoveFromOtherToThisUninit, +}; + +template<typename Functor, bool kStoreInline = can_be_stored_inline<Functor>> +struct FunctionImplBase +{ + using Allocator = DefaultAllocator; + + template<typename T> + static void create(Storage* storage, T&& t) + { + Allocator allocator{}; + auto* ptr = alloc_new<Functor>(allocator, std::forward<T>(t)); + asl::memcpy(storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>); + } + + static Functor** get_functor_ptr(const Storage* storage) + { + // NOLINTNEXTLINE(*-reinterpret-cast) + return std::launder(reinterpret_cast<Functor**>(storage->get_ptr())); + } + + static Functor* get_functor(const Storage* storage) + { + return *get_functor_ptr(storage); + } + + static void op(Storage* this_storage, Storage* other_storage, FunctionOp op) + { + switch (op) + { + using enum FunctionOp; + case kDestroyThis: + { + Allocator allocator{}; + alloc_delete(allocator, get_functor(this_storage)); + break; + } + case kCopyFromOtherToThisUninit: + { + create(this_storage, *static_cast<const Functor*>(get_functor(other_storage))); + break; + } + case kMoveFromOtherToThisUninit: + { + auto* ptr = asl::exchange(*get_functor_ptr(other_storage), nullptr); + asl::memcpy(this_storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>); + break; + } + default: break; + } + } +}; + +template<typename Functor> +struct FunctionImplBase<Functor, true> +{ + template<typename T> + static void create(Storage* storage, T&& t) + { + new (storage->get_ptr()) Functor(std::forward<T>(t)); + } + + static Functor* get_functor(const Storage* storage) + { + // NOLINTNEXTLINE(*-reinterpret-cast) + return std::launder(reinterpret_cast<Functor*>(storage->get_ptr())); + } + + static void op(Storage* this_storage, Storage* other_storage, FunctionOp op) + { + switch (op) + { + using enum FunctionOp; + case kDestroyThis: + { + destroy(get_functor(this_storage)); + break; + } + case kCopyFromOtherToThisUninit: + { + create(this_storage, *static_cast<const Functor*>(get_functor(other_storage))); + break; + } + case kMoveFromOtherToThisUninit: + { + auto* other_functor = get_functor(other_storage); + create(this_storage, std::move(*static_cast<const Functor*>(other_functor))); + destroy(other_functor); + break; + } + default: break; + } + } +}; + +template<typename Functor, typename R, typename... Args> +struct FunctionImpl : FunctionImplBase<Functor> +{ + static R invoke(Args... args, const Storage& storage) + { + auto* functor = FunctionImplBase<Functor>::get_functor(&storage); + return asl::invoke(*functor, std::forward<Args>(args)...); + } +}; + + +template<typename T, typename R, typename... Args> +concept valid_functor = + copy_constructible<T> + && move_constructible<T> + && invocable<T, Args...> + && same_as<R, invoke_result_t<T, Args...>>; + +} // namespace function_detail + +template<typename T> +class function; + +template<typename R, typename... Args> +class function<R(Args...)> // NOLINT(*-member-init) +{ + using InvokeFn = R (*)(Args..., const function_detail::Storage&); + using OpFn = void (*)(function_detail::Storage*, function_detail::Storage*, function_detail::FunctionOp); + + function_detail::Storage m_storage; + InvokeFn m_invoke{}; + OpFn m_op{}; + + void destroy() + { + if (m_op != nullptr) + { + (*m_op)(&m_storage, nullptr, function_detail::FunctionOp::kDestroyThis); + } + } + +public: + function() = default; + + template<typename T> + function(T&& func) // NOLINT(*explicit*,*-member-init) + requires ( + !same_as<function, un_cvref_t<T>> + && function_detail::valid_functor<T, R, Args...> + ) + { + using Functor = decay_t<T>; + using Impl = function_detail::FunctionImpl<Functor, R, Args...>; + + Impl::create(&m_storage, std::forward<T>(func)); + m_invoke = &Impl::invoke; // NOLINT(*-member-initializer) + m_op = &Impl::op; // NOLINT(*-member-initializer) + } + + function(const function& other) // NOLINT(*-member-init) + : m_invoke{other.m_invoke} + , m_op{other.m_op} + { + if (m_op != nullptr) + { + (*m_op)( + &m_storage, + const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) + function_detail::FunctionOp::kCopyFromOtherToThisUninit); + } + } + + function(function&& other) // NOLINT(*-member-init) + : m_invoke{asl::exchange(other.m_invoke, nullptr)} + , m_op{asl::exchange(other.m_op, nullptr)} + { + if (m_op != nullptr) + { + (*m_op)( + &m_storage, + &other.m_storage, + function_detail::FunctionOp::kMoveFromOtherToThisUninit); + } + } + + ~function() + { + destroy(); + } + + function& operator=(const function& other) + { + if (this != &other) + { + destroy(); + + m_invoke = other.m_invoke; + m_op = other.m_op; + + (*m_op)( + &m_storage, + const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast) + function_detail::FunctionOp::kCopyFromOtherToThisUninit); + } + return *this; + } + + function& operator=(function&& other) + { + if (this != &other) + { + destroy(); + + m_invoke = asl::exchange(other.m_invoke, nullptr); + m_op = asl::exchange(other.m_op, nullptr); + + (*m_op)( + &m_storage, + &other.m_storage, + function_detail::FunctionOp::kMoveFromOtherToThisUninit); + } + return *this; + } + + template<typename T> + function& operator=(T&& func) + requires ( + !same_as<function, un_cvref_t<T>> + && function_detail::valid_functor<T, R, Args...> + ) + { + destroy(); + + using Functor = decay_t<T>; + using Impl = function_detail::FunctionImpl<Functor, R, Args...>; + + Impl::create(&m_storage, std::forward<T>(func)); + m_invoke = &Impl::invoke; + m_op = &Impl::op; + + return *this; + } + + constexpr R operator()(Args... args) const + { + ASL_ASSERT(m_invoke); + return (*m_invoke)(args..., m_storage); + } +}; + +} // namespace asl |