Add function

This commit is contained in:
2025-03-20 00:23:28 +01:00
parent a665c590d5
commit c692909ff3
4 changed files with 489 additions and 2 deletions

View File

@ -50,13 +50,15 @@ void alloc_delete(allocator auto& a, T* ptr)
template<typename T>
constexpr T* alloc_new_default(auto&&... args)
{
return alloc_new<T>(DefaultAllocator{}, std::forward<decltype(args)>(args)...);
DefaultAllocator allocator{};
return alloc_new<T>(allocator, std::forward<decltype(args)>(args)...);
}
template<typename T>
void alloc_delete_default(T* ptr)
{
alloc_delete(DefaultAllocator{}, ptr);
DefaultAllocator allocator{};
alloc_delete(allocator, ptr);
}
} // namespace asl

View File

@ -78,6 +78,28 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "function",
hdrs = [
"function.hpp",
],
deps = [
"//asl/base",
"//asl/memory",
],
visibility = ["//visibility:public"],
)
cc_test(
name = "function_tests",
srcs = ["function_tests.cpp"],
deps = [
"//asl/tests:utils",
"//asl/testing",
"//asl/types:function",
],
)
cc_test(
name = "box_tests",
srcs = ["box_tests.cpp"],

284
asl/types/function.hpp Normal file
View File

@ -0,0 +1,284 @@
#pragma once
#include "asl/base/utility.hpp"
#include "asl/base/meta.hpp"
#include "asl/base/integers.hpp"
#include "asl/memory/allocator.hpp"
#include "asl/base/functional.hpp"
namespace asl
{
namespace function_detail
{
static constexpr isize_t kStorageSize = size_of<void*> * 2;
struct Storage
{
alignas(align_of<void*>) byte raw[kStorageSize];
[[nodiscard]]
void* get_ptr() const
{
// NOLINTNEXTLINE(*-const-cast)
return const_cast<void*>(static_cast<const void*>(raw));
}
};
template<typename T>
concept can_be_stored_inline =
size_of<T> <= size_of<Storage> &&
align_of<Storage> % align_of<T> == 0;
enum class FunctionOp : uint8_t
{
kDestroyThis,
kCopyFromOtherToThisUninit,
kMoveFromOtherToThisUninit,
};
template<typename Functor, bool kStoreInline = can_be_stored_inline<Functor>>
struct FunctionImplBase
{
using Allocator = DefaultAllocator;
template<typename T>
static void create(Storage* storage, T&& t)
{
Allocator allocator{};
auto* ptr = alloc_new<Functor>(allocator, std::forward<T>(t));
asl::memcpy(storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
}
static Functor** get_functor_ptr(const Storage* storage)
{
// NOLINTNEXTLINE(*-reinterpret-cast)
return std::launder(reinterpret_cast<Functor**>(storage->get_ptr()));
}
static Functor* get_functor(const Storage* storage)
{
return *get_functor_ptr(storage);
}
static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
{
switch (op)
{
using enum FunctionOp;
case kDestroyThis:
{
Allocator allocator{};
alloc_delete(allocator, get_functor(this_storage));
break;
}
case kCopyFromOtherToThisUninit:
{
create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
break;
}
case kMoveFromOtherToThisUninit:
{
auto* ptr = asl::exchange(*get_functor_ptr(other_storage), nullptr);
asl::memcpy(this_storage->get_ptr(), static_cast<void*>(&ptr), size_of<void*>);
break;
}
default: break;
}
}
};
template<typename Functor>
struct FunctionImplBase<Functor, true>
{
template<typename T>
static void create(Storage* storage, T&& t)
{
new (storage->get_ptr()) Functor(std::forward<T>(t));
}
static Functor* get_functor(const Storage* storage)
{
// NOLINTNEXTLINE(*-reinterpret-cast)
return std::launder(reinterpret_cast<Functor*>(storage->get_ptr()));
}
static void op(Storage* this_storage, Storage* other_storage, FunctionOp op)
{
switch (op)
{
using enum FunctionOp;
case kDestroyThis:
{
destroy(get_functor(this_storage));
break;
}
case kCopyFromOtherToThisUninit:
{
create(this_storage, *static_cast<const Functor*>(get_functor(other_storage)));
break;
}
case kMoveFromOtherToThisUninit:
{
auto* other_functor = get_functor(other_storage);
create(this_storage, std::move(*static_cast<const Functor*>(other_functor)));
destroy(other_functor);
break;
}
default: break;
}
}
};
template<typename Functor, typename R, typename... Args>
struct FunctionImpl : FunctionImplBase<Functor>
{
static R invoke(Args... args, const Storage& storage)
{
auto* functor = FunctionImplBase<Functor>::get_functor(&storage);
return asl::invoke(*functor, std::forward<Args>(args)...);
}
};
template<typename T, typename R, typename... Args>
concept valid_functor =
copy_constructible<T>
&& move_constructible<T>
&& invocable<T, Args...>
&& same_as<R, invoke_result_t<T, Args...>>;
} // namespace function_detail
template<typename T>
class function;
template<typename R, typename... Args>
class function<R(Args...)> // NOLINT(*-member-init)
{
using InvokeFn = R (*)(Args..., const function_detail::Storage&);
using OpFn = void (*)(function_detail::Storage*, function_detail::Storage*, function_detail::FunctionOp);
function_detail::Storage m_storage;
InvokeFn m_invoke{};
OpFn m_op{};
void destroy()
{
if (m_op != nullptr)
{
(*m_op)(&m_storage, nullptr, function_detail::FunctionOp::kDestroyThis);
}
}
public:
function() = default;
template<typename T>
function(T&& func) // NOLINT(*explicit*,*-member-init)
requires (
!same_as<function, un_cvref_t<T>>
&& function_detail::valid_functor<T, R, Args...>
)
{
using Functor = decay_t<T>;
using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
Impl::create(&m_storage, std::forward<T>(func));
m_invoke = &Impl::invoke; // NOLINT(*-member-initializer)
m_op = &Impl::op; // NOLINT(*-member-initializer)
}
function(const function& other) // NOLINT(*-member-init)
: m_invoke{other.m_invoke}
, m_op{other.m_op}
{
if (m_op != nullptr)
{
(*m_op)(
&m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit);
}
}
function(function&& other) // NOLINT(*-member-init)
: m_invoke{asl::exchange(other.m_invoke, nullptr)}
, m_op{asl::exchange(other.m_op, nullptr)}
{
if (m_op != nullptr)
{
(*m_op)(
&m_storage,
&other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit);
}
}
~function()
{
destroy();
}
function& operator=(const function& other)
{
if (this != &other)
{
destroy();
m_invoke = other.m_invoke;
m_op = other.m_op;
(*m_op)(
&m_storage,
const_cast<function_detail::Storage*>(&other.m_storage), // NOLINT(*-const-cast)
function_detail::FunctionOp::kCopyFromOtherToThisUninit);
}
return *this;
}
function& operator=(function&& other)
{
if (this != &other)
{
destroy();
m_invoke = asl::exchange(other.m_invoke, nullptr);
m_op = asl::exchange(other.m_op, nullptr);
(*m_op)(
&m_storage,
&other.m_storage,
function_detail::FunctionOp::kMoveFromOtherToThisUninit);
}
return *this;
}
template<typename T>
function& operator=(T&& func)
requires (
!same_as<function, un_cvref_t<T>>
&& function_detail::valid_functor<T, R, Args...>
)
{
destroy();
using Functor = decay_t<T>;
using Impl = function_detail::FunctionImpl<Functor, R, Args...>;
Impl::create(&m_storage, std::forward<T>(func));
m_invoke = &Impl::invoke;
m_op = &Impl::op;
return *this;
}
constexpr R operator()(Args... args) const
{
ASL_ASSERT(m_invoke);
return (*m_invoke)(args..., m_storage);
}
};
} // namespace asl

View File

@ -0,0 +1,179 @@
#include "asl/testing/testing.hpp"
#include "asl/types/function.hpp"
static_assert(asl::function_detail::can_be_stored_inline<int(*)(int, int, int)>);
static_assert(asl::function_detail::can_be_stored_inline<decltype([](){})>);
static_assert(asl::function_detail::can_be_stored_inline<decltype([]() static {})>);
static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL](){ return a + b; })>); // NOLINT
static_assert(asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL]() mutable { return a = b++; })>); // NOLINT
static_assert(!asl::function_detail::can_be_stored_inline<decltype([a = 1ULL, b = 2ULL, c = 3ULL](){ return a + b + c; })>); // NOLINT
static int add(int x, int y)
{
return x + y;
}
ASL_TEST(function_pointer)
{
const asl::function<int(int, int)> fn = add;
ASL_TEST_EXPECT(fn(4, 5) == 9);
ASL_TEST_EXPECT(fn(40, 50) == 90);
}
ASL_TEST(lambda)
{
const asl::function<int(int, int)> fn = [](int a, int b) {
return a + b;
};
ASL_TEST_EXPECT(fn(4, 5) == 9);
}
ASL_TEST(lambda_static)
{
const asl::function<int(int, int)> fn = [](int a, int b) static {
return a + b;
};
ASL_TEST_EXPECT(fn(4, 5) == 9);
}
ASL_TEST(lambda_static_state)
{
const asl::function<int(int)> fn = [state = 0](int b) mutable {
state += b;
return state;
};
ASL_TEST_EXPECT(fn(1) == 1);
ASL_TEST_EXPECT(fn(2) == 3);
ASL_TEST_EXPECT(fn(3) == 6);
ASL_TEST_EXPECT(fn(4) == 10);
}
ASL_TEST(lambda_state)
{
int state = 0;
const asl::function<void(int)> fn = [&state](int x) {
state += x;
};
ASL_TEST_EXPECT(state == 0);
fn(5);
ASL_TEST_EXPECT(state == 5);
fn(4);
ASL_TEST_EXPECT(state == 9);
}
ASL_TEST(lambda_big_state)
{
int s0 = 0;
int s1 = 0;
int s2 = 0;
int s3 = 0;
const asl::function<void(int)> fn = [&](int x) {
s0 += x;
s1 += x + 1;
s2 += x + 2;
s3 += x + 3;
};
ASL_TEST_EXPECT(s0 == 0);
ASL_TEST_EXPECT(s1 == 0);
ASL_TEST_EXPECT(s2 == 0);
ASL_TEST_EXPECT(s3 == 0);
fn(5);
ASL_TEST_EXPECT(s0 == 5);
ASL_TEST_EXPECT(s1 == 6);
ASL_TEST_EXPECT(s2 == 7);
ASL_TEST_EXPECT(s3 == 8);
fn(4);
ASL_TEST_EXPECT(s0 == 9);
ASL_TEST_EXPECT(s1 == 11);
ASL_TEST_EXPECT(s2 == 13);
ASL_TEST_EXPECT(s3 == 15);
}
struct Functor
{
int state{};
int operator()(int x)
{
state += x;
return state;
}
};
ASL_TEST(functor)
{
const asl::function<int(int)> fn = Functor{};
ASL_TEST_EXPECT(fn(1) == 1);
ASL_TEST_EXPECT(fn(2) == 3);
ASL_TEST_EXPECT(fn(3) == 6);
ASL_TEST_EXPECT(fn(4) == 10);
}
ASL_TEST(copy_move_construct_small)
{
asl::function<int(int, int)> fn = [x = 0](int a, int b) mutable { return x++ + a + b; };
ASL_TEST_EXPECT(fn(1, 3) == 4);
asl::function<int(int, int)> fn2 = fn;
ASL_TEST_EXPECT(fn(1, 3) == 5);
ASL_TEST_EXPECT(fn2(5, 3) == 9);
asl::function<int(int, int)> fn3 = std::move(fn2);
ASL_TEST_EXPECT(fn(1, 3) == 6);
ASL_TEST_EXPECT(fn3(5, 3) == 10);
fn2 = fn;
ASL_TEST_EXPECT(fn(1, 3) == 7);
ASL_TEST_EXPECT(fn2(5, 3) == 11);
ASL_TEST_EXPECT(fn3(5, 3) == 11);
fn3 = std::move(fn);
ASL_TEST_EXPECT(fn2(5, 3) == 12);
ASL_TEST_EXPECT(fn3(5, 3) == 12);
}
ASL_TEST(copy_move_construct_big)
{
const int64_t v1 = 1;
const int64_t v2 = 2;
const int64_t v3 = 3;
const int64_t v4 = 4;
asl::function<int64_t(int)> fn = [=](int x) { return v1 + v2 + v3 + v4 + x; };
ASL_TEST_EXPECT(fn(1) == 11);
asl::function<int64_t(int)> fn2 = fn;
ASL_TEST_EXPECT(fn(3) == 13);
ASL_TEST_EXPECT(fn2(5) == 15);
asl::function<int64_t(int)> fn3 = std::move(fn2);
ASL_TEST_EXPECT(fn(1) == 11);
ASL_TEST_EXPECT(fn3(3) == 13);
fn2 = fn;
ASL_TEST_EXPECT(fn(1) == 11);
ASL_TEST_EXPECT(fn2(5) == 15);
ASL_TEST_EXPECT(fn3(3) == 13);
fn3 = std::move(fn);
ASL_TEST_EXPECT(fn2(5) == 15);
ASL_TEST_EXPECT(fn3(3) == 13);
}
ASL_TEST(replace)
{
asl::function<int(int)> fn = [](int x) { return x + 1; };
ASL_TEST_EXPECT(fn(5) == 6);
fn = [](int x) { return x + 3; };
ASL_TEST_EXPECT(fn(5) == 8);
}