summaryrefslogtreecommitdiff
path: root/asl
diff options
context:
space:
mode:
authorSteven Le Rouzic <steven.lerouzic@gmail.com>2025-01-18 19:59:36 +0100
committerSteven Le Rouzic <steven.lerouzic@gmail.com>2025-01-18 19:59:36 +0100
commit74082720c42c5d6b06b71cefbad4b794ff1b8c3c (patch)
treedc7dc49959b0bcc5e2980a950adcf89273d1c2c3 /asl
parent41454a09c6d73fcecffc1f7d6e3754c60cc49e31 (diff)
Finish the hash_map
Diffstat (limited to 'asl')
-rw-r--r--asl/BUILD.bazel2
-rw-r--r--asl/hash_map.hpp77
-rw-r--r--asl/hash_set.hpp205
-rw-r--r--asl/tests/hash_map_tests.cpp48
-rw-r--r--asl/utility.hpp2
5 files changed, 238 insertions, 96 deletions
diff --git a/asl/BUILD.bazel b/asl/BUILD.bazel
index 468ad58..b5bb68f 100644
--- a/asl/BUILD.bazel
+++ b/asl/BUILD.bazel
@@ -12,6 +12,7 @@ cc_library(
"format.hpp",
"functional.hpp",
"hash.hpp",
+ "hash_map.hpp",
"hash_set.hpp",
"integers.hpp",
"io.hpp",
@@ -60,6 +61,7 @@ cc_library(
"format",
"functional",
"hash",
+ "hash_map",
"hash_set",
"integers",
"maybe_uninit",
diff --git a/asl/hash_map.hpp b/asl/hash_map.hpp
index 310b532..300ffdb 100644
--- a/asl/hash_map.hpp
+++ b/asl/hash_map.hpp
@@ -1,12 +1,9 @@
#pragma once
-#include "asl/annotations.hpp"
#include "asl/meta.hpp"
#include "asl/utility.hpp"
-#include "asl/maybe_uninit.hpp"
#include "asl/hash.hpp"
#include "asl/allocator.hpp"
-#include "asl/memory.hpp"
#include "asl/hash_set.hpp"
namespace asl
@@ -59,7 +56,7 @@ template<
key_comparator<K> KeyComparator = default_key_comparator<K>
>
requires moveable<K> && moveable<V>
-class hash_map : hash_set<
+class hash_map : protected hash_set<
hash_map_internal::Slot<K, V>,
Allocator,
hash_map_internal::SlotHasher<K, V, KeyHasher>,
@@ -95,11 +92,73 @@ public:
using Base::size;
- // @Todo insert
- // @Todo contains
- // @Todo remove
- // @Todo get
- // @Todo tests
+ using Base::remove;
+
+ using Base::contains;
+
+ template<typename U, typename Arg0, typename... Args1>
+ requires
+ key_hasher<KeyHasher, U> &&
+ key_comparator<KeyComparator, K, U> &&
+ constructible_from<K, U&&> &&
+ constructible_from<V, Arg0&&, Args1&&...>
+ void insert(U&& key, Arg0&& arg0, Args1&&... args1)
+ {
+ Base::maybe_grow_to_fit_one_more();
+
+ auto result = Base::find_slot_insert(key);
+
+ // NOLINTBEGIN(*-pointer-arithmetic)
+
+ ASL_ASSERT(result.first_available_index >= 0);
+
+ if (result.already_present_index >= 0)
+ {
+ if (result.already_present_index != result.first_available_index)
+ {
+ ASL_ASSERT((Base::m_tags[result.first_available_index] & Base::kHasValue) == 0);
+
+ Base::m_values[result.first_available_index].construct_unsafe(ASL_MOVE(Base::m_values[result.already_present_index].as_init_unsafe()));
+ Base::m_values[result.already_present_index].destroy_unsafe();
+
+ Base::m_tags[result.first_available_index] = result.tag;
+ Base::m_tags[result.already_present_index] = Base::kTombstone;
+ }
+
+ ASL_ASSERT(Base::m_tags[result.first_available_index] == result.tag);
+
+ if constexpr (sizeof...(Args1) == 0 && assignable_from<V&, Arg0&&>)
+ {
+ Base::m_values[result.first_available_index].as_init_unsafe().value = ASL_FWD(arg0);
+ }
+ else
+ {
+ Base::m_values[result.first_available_index].as_init_unsafe().value = ASL_MOVE(V{ASL_FWD(arg0), ASL_FWD(args1)...});
+ }
+ }
+ else
+ {
+ ASL_ASSERT((Base::m_tags[result.first_available_index] & Base::kHasValue) == 0);
+ Base::m_values[result.first_available_index].construct_unsafe(ASL_FWD(key), V{ASL_FWD(arg0), ASL_FWD(args1)...});
+ Base::m_tags[result.first_available_index] = result.tag;
+ Base::m_size += 1;
+ }
+
+ // NOLINTEND(*-pointer-arithmetic)
+ }
+
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, K, U>
+ V* get(const U& value) const
+ {
+ isize_t index = Base::find_slot_lookup(value);
+ if (index >= 0)
+ {
+ // NOLINTNEXTLINE(*-pointer-arithmetic)
+ return &Base::m_values[index].as_init_unsafe().value;
+ }
+ return nullptr;
+ }
};
} // namespace asl
diff --git a/asl/hash_set.hpp b/asl/hash_set.hpp
index c3fb38d..979235d 100644
--- a/asl/hash_set.hpp
+++ b/asl/hash_set.hpp
@@ -50,6 +50,7 @@ template<
requires moveable<T>
class hash_set
{
+protected:
static constexpr uint8_t kHasValue = 0x80;
static constexpr uint8_t kHashMask = 0x7f;
static constexpr uint8_t kEmpty = 0x00;
@@ -80,7 +81,7 @@ class hash_set
kMinCapacity,
static_cast<isize_t>(round_up_pow2((static_cast<uint64_t>(size) * 4 + 2) / 3)));
}
-
+
static void insert_inner(
T&& value,
uint8_t* tags,
@@ -89,102 +90,33 @@ class hash_set
isize_t* size)
{
ASL_ASSERT(*size < capacity);
- ASL_ASSERT(is_pow2(capacity));
-
- const isize_t capacity_mask = capacity - 1;
- const uint64_t hash = KeyHasher::hash(value);
- const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
- const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
- isize_t first_available_index = -1;
- isize_t already_present_index = -1;
+ const auto result = find_slot_insert(value, tags, values, capacity);
// NOLINTBEGIN(*-pointer-arithmetic)
- for (
- isize_t i = starting_index;
- i != starting_index || first_available_index < 0;
- i = (i + 1) & capacity_mask)
- {
- uint8_t t = tags[i];
-
- if ((t & kHasValue) == 0 && first_available_index < 0)
- {
- first_available_index = i;
- }
+ ASL_ASSERT(result.first_available_index >= 0);
- if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
- {
- ASL_ASSERT(already_present_index < 0);
- already_present_index = i;
- if (first_available_index < 0)
- {
- first_available_index = i;
- }
- break;
- }
-
- if (t == kEmpty) { break; }
- }
-
- ASL_ASSERT(first_available_index >= 0 || already_present_index >= 0);
-
- if (already_present_index == first_available_index)
+ if (result.already_present_index != result.first_available_index)
{
- ASL_ASSERT((tags[already_present_index] & kHasValue) != 0);
- values[already_present_index].assign_unsafe(ASL_MOVE(value));
- }
- else
- {
- ASL_ASSERT((tags[first_available_index] & kHasValue) == 0);
- if (already_present_index >= 0)
+ ASL_ASSERT((tags[result.first_available_index] & kHasValue) == 0);
+ if (result.already_present_index >= 0)
{
- ASL_ASSERT((tags[already_present_index] & kHasValue) != 0);
- values[already_present_index].destroy_unsafe();
- tags[already_present_index] = kTombstone;
+ ASL_ASSERT((tags[result.already_present_index] & kHasValue) != 0);
+ values[result.already_present_index].destroy_unsafe();
+ tags[result.already_present_index] = kTombstone;
}
else
{
*size += 1;
}
- values[first_available_index].construct_unsafe(ASL_MOVE(value));
- tags[first_available_index] = tag;
+ values[result.first_available_index].construct_unsafe(ASL_MOVE(value));
+ tags[result.first_available_index] = result.tag;
}
// NOLINTEND(*-pointer-arithmetic)
}
-
- template<typename U>
- requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
- isize_t find_slot(const U& value) const
- {
- if (m_size <= 0) { return -1; };
-
- ASL_ASSERT(is_pow2(m_capacity));
-
- const isize_t capacity_mask = m_capacity - 1;
- const uint64_t hash = KeyHasher::hash(value);
- const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
- const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
-
- // NOLINTBEGIN(*-pointer-arithmetic)
-
- isize_t i = starting_index;
- do
- {
- const uint8_t t = m_tags[i];
-
- if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
- if (t == kEmpty) { break; }
-
- i = (i + 1) & capacity_mask;
- } while (i != starting_index);
-
- // NOLINTEND(*-pointer-arithmetic)
-
- return -1;
- }
void grow_and_rehash()
{
@@ -266,6 +198,110 @@ class hash_set
}
}
+ struct FindSlotResult
+ {
+ uint8_t tag{};
+ isize_t first_available_index = -1;
+ isize_t already_present_index = -1;
+ };
+
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ static FindSlotResult find_slot_insert(
+ const U& value,
+ const uint8_t* tags,
+ const maybe_uninit<T>* values,
+ isize_t capacity)
+ {
+ ASL_ASSERT(is_pow2(capacity));
+
+ FindSlotResult result{};
+
+ const isize_t capacity_mask = capacity - 1;
+ const uint64_t hash = KeyHasher::hash(value);
+ const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
+
+ result.tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
+
+ // NOLINTBEGIN(*-pointer-arithmetic)
+
+ for (
+ isize_t i = starting_index;
+ i != starting_index || result.first_available_index < 0;
+ i = (i + 1) & capacity_mask)
+ {
+ uint8_t t = tags[i];
+
+ if ((t & kHasValue) == 0 && result.first_available_index < 0)
+ {
+ result.first_available_index = i;
+ }
+
+ if (t == result.tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
+ {
+ ASL_ASSERT(result.already_present_index < 0);
+ result.already_present_index = i;
+ if (result.first_available_index < 0)
+ {
+ result.first_available_index = i;
+ }
+ break;
+ }
+
+ if (t == kEmpty) { break; }
+ }
+
+ // NOLINTEND(*-pointer-arithmetic)
+
+ return result;
+ }
+
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ isize_t find_slot_lookup(const U& value) const
+ {
+ if (m_size <= 0) { return -1; };
+
+ ASL_ASSERT(is_pow2(m_capacity));
+
+ const isize_t capacity_mask = m_capacity - 1;
+ const uint64_t hash = KeyHasher::hash(value);
+ const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
+ const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
+
+ // NOLINTBEGIN(*-pointer-arithmetic)
+
+ isize_t i = starting_index;
+ do
+ {
+ const uint8_t t = m_tags[i];
+
+ if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
+ if (t == kEmpty) { break; }
+
+ i = (i + 1) & capacity_mask;
+ } while (i != starting_index);
+
+ // NOLINTEND(*-pointer-arithmetic)
+
+ return -1;
+ }
+
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ FindSlotResult find_slot_insert(const U& value)
+ {
+ return find_slot_insert(value, m_tags, m_values, m_capacity);
+ }
+
+ void maybe_grow_to_fit_one_more()
+ {
+ if (m_size >= max_size())
+ {
+ grow_and_rehash();
+ }
+ }
+
public:
constexpr hash_set() requires default_constructible<Allocator>
: m_allocator{}
@@ -351,10 +387,7 @@ public:
void insert(Args&&... args)
requires constructible_from<T, Args&&...>
{
- if (m_size >= max_size())
- {
- grow_and_rehash();
- }
+ maybe_grow_to_fit_one_more();
ASL_ASSERT(m_size < max_size());
insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
}
@@ -363,14 +396,14 @@ public:
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool contains(const U& value) const
{
- return find_slot(value) >= 0;
+ return find_slot_lookup(value) >= 0;
}
template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool remove(const U& value)
{
- isize_t slot = find_slot(value);
+ isize_t slot = find_slot_lookup(value);
if (slot < 0) { return false; }
m_values[slot].destroy_unsafe(); // NOLINT(*-pointer-arithmetic)
diff --git a/asl/tests/hash_map_tests.cpp b/asl/tests/hash_map_tests.cpp
new file mode 100644
index 0000000..53c419c
--- /dev/null
+++ b/asl/tests/hash_map_tests.cpp
@@ -0,0 +1,48 @@
+#include "asl/testing/testing.hpp"
+#include "asl/hash_map.hpp"
+
+ASL_TEST(default)
+{
+ asl::hash_map<int, int> map;
+
+ ASL_TEST_EXPECT(!map.contains(45));
+ ASL_TEST_EXPECT(!map.contains(46));
+
+ map.insert(45, 5);
+ map.insert(46, 6);
+
+ ASL_TEST_EXPECT(map.size() == 2);
+
+ ASL_TEST_EXPECT(map.contains(45));
+ ASL_TEST_EXPECT(map.contains(46));
+ ASL_TEST_EXPECT(!map.contains(47));
+
+ ASL_TEST_EXPECT(*map.get(45) == 5);
+ ASL_TEST_EXPECT(*map.get(46) == 6);
+ ASL_TEST_EXPECT(map.get(47) == nullptr);
+
+ ASL_TEST_EXPECT(map.remove(45));
+ ASL_TEST_EXPECT(!map.remove(45));
+
+ ASL_TEST_EXPECT(map.size() == 1);
+
+ ASL_TEST_EXPECT(!map.contains(45));
+ ASL_TEST_EXPECT(map.contains(46));
+ ASL_TEST_EXPECT(!map.contains(47));
+
+ ASL_TEST_EXPECT(map.get(45) == nullptr);
+ ASL_TEST_EXPECT(*map.get(46) == 6);
+ ASL_TEST_EXPECT(map.get(47) == nullptr);
+
+ map.insert(46, 460);
+
+ ASL_TEST_EXPECT(map.size() == 1);
+
+ ASL_TEST_EXPECT(!map.contains(45));
+ ASL_TEST_EXPECT(map.contains(46));
+ ASL_TEST_EXPECT(!map.contains(47));
+
+ ASL_TEST_EXPECT(map.get(45) == nullptr);
+ ASL_TEST_EXPECT(*map.get(46) == 460);
+ ASL_TEST_EXPECT(map.get(47) == nullptr);
+}
diff --git a/asl/utility.hpp b/asl/utility.hpp
index 6a43852..205b583 100644
--- a/asl/utility.hpp
+++ b/asl/utility.hpp
@@ -4,7 +4,7 @@
#include "asl/layout.hpp"
#include "asl/assert.hpp"
-#define ASL_MOVE(expr_) (static_cast<::asl::un_ref_t<decltype(expr_)>&&>(expr_))
+#define ASL_MOVE(...) (static_cast<::asl::un_ref_t<decltype(__VA_ARGS__)>&&>(__VA_ARGS__))
#define ASL_FWD(expr_) (static_cast<decltype(expr_)&&>(expr_))