From 74082720c42c5d6b06b71cefbad4b794ff1b8c3c Mon Sep 17 00:00:00 2001 From: Steven Le Rouzic Date: Sat, 18 Jan 2025 19:59:36 +0100 Subject: Finish the hash_map --- asl/BUILD.bazel | 2 + asl/hash_map.hpp | 77 ++++++++++++++-- asl/hash_set.hpp | 205 +++++++++++++++++++++++++------------------ asl/tests/hash_map_tests.cpp | 48 ++++++++++ asl/utility.hpp | 2 +- 5 files changed, 238 insertions(+), 96 deletions(-) create mode 100644 asl/tests/hash_map_tests.cpp (limited to 'asl') diff --git a/asl/BUILD.bazel b/asl/BUILD.bazel index 468ad58..b5bb68f 100644 --- a/asl/BUILD.bazel +++ b/asl/BUILD.bazel @@ -12,6 +12,7 @@ cc_library( "format.hpp", "functional.hpp", "hash.hpp", + "hash_map.hpp", "hash_set.hpp", "integers.hpp", "io.hpp", @@ -60,6 +61,7 @@ cc_library( "format", "functional", "hash", + "hash_map", "hash_set", "integers", "maybe_uninit", diff --git a/asl/hash_map.hpp b/asl/hash_map.hpp index 310b532..300ffdb 100644 --- a/asl/hash_map.hpp +++ b/asl/hash_map.hpp @@ -1,12 +1,9 @@ #pragma once -#include "asl/annotations.hpp" #include "asl/meta.hpp" #include "asl/utility.hpp" -#include "asl/maybe_uninit.hpp" #include "asl/hash.hpp" #include "asl/allocator.hpp" -#include "asl/memory.hpp" #include "asl/hash_set.hpp" namespace asl @@ -59,7 +56,7 @@ template< key_comparator KeyComparator = default_key_comparator > requires moveable && moveable -class hash_map : hash_set< +class hash_map : protected hash_set< hash_map_internal::Slot, Allocator, hash_map_internal::SlotHasher, @@ -95,11 +92,73 @@ public: using Base::size; - // @Todo insert - // @Todo contains - // @Todo remove - // @Todo get - // @Todo tests + using Base::remove; + + using Base::contains; + + template + requires + key_hasher && + key_comparator && + constructible_from && + constructible_from + void insert(U&& key, Arg0&& arg0, Args1&&... args1) + { + Base::maybe_grow_to_fit_one_more(); + + auto result = Base::find_slot_insert(key); + + // NOLINTBEGIN(*-pointer-arithmetic) + + ASL_ASSERT(result.first_available_index >= 0); + + if (result.already_present_index >= 0) + { + if (result.already_present_index != result.first_available_index) + { + ASL_ASSERT((Base::m_tags[result.first_available_index] & Base::kHasValue) == 0); + + Base::m_values[result.first_available_index].construct_unsafe(ASL_MOVE(Base::m_values[result.already_present_index].as_init_unsafe())); + Base::m_values[result.already_present_index].destroy_unsafe(); + + Base::m_tags[result.first_available_index] = result.tag; + Base::m_tags[result.already_present_index] = Base::kTombstone; + } + + ASL_ASSERT(Base::m_tags[result.first_available_index] == result.tag); + + if constexpr (sizeof...(Args1) == 0 && assignable_from) + { + Base::m_values[result.first_available_index].as_init_unsafe().value = ASL_FWD(arg0); + } + else + { + Base::m_values[result.first_available_index].as_init_unsafe().value = ASL_MOVE(V{ASL_FWD(arg0), ASL_FWD(args1)...}); + } + } + else + { + ASL_ASSERT((Base::m_tags[result.first_available_index] & Base::kHasValue) == 0); + Base::m_values[result.first_available_index].construct_unsafe(ASL_FWD(key), V{ASL_FWD(arg0), ASL_FWD(args1)...}); + Base::m_tags[result.first_available_index] = result.tag; + Base::m_size += 1; + } + + // NOLINTEND(*-pointer-arithmetic) + } + + template + requires key_hasher && key_comparator + V* get(const U& value) const + { + isize_t index = Base::find_slot_lookup(value); + if (index >= 0) + { + // NOLINTNEXTLINE(*-pointer-arithmetic) + return &Base::m_values[index].as_init_unsafe().value; + } + return nullptr; + } }; } // namespace asl diff --git a/asl/hash_set.hpp b/asl/hash_set.hpp index c3fb38d..979235d 100644 --- a/asl/hash_set.hpp +++ b/asl/hash_set.hpp @@ -50,6 +50,7 @@ template< requires moveable class hash_set { +protected: static constexpr uint8_t kHasValue = 0x80; static constexpr uint8_t kHashMask = 0x7f; static constexpr uint8_t kEmpty = 0x00; @@ -80,7 +81,7 @@ class hash_set kMinCapacity, static_cast(round_up_pow2((static_cast(size) * 4 + 2) / 3))); } - + static void insert_inner( T&& value, uint8_t* tags, @@ -89,102 +90,33 @@ class hash_set isize_t* size) { ASL_ASSERT(*size < capacity); - ASL_ASSERT(is_pow2(capacity)); - - const isize_t capacity_mask = capacity - 1; - const uint64_t hash = KeyHasher::hash(value); - const uint8_t tag = static_cast(hash & kHashMask) | kHasValue; - const auto starting_index = static_cast(hash >> 7) & capacity_mask; - isize_t first_available_index = -1; - isize_t already_present_index = -1; + const auto result = find_slot_insert(value, tags, values, capacity); // NOLINTBEGIN(*-pointer-arithmetic) - for ( - isize_t i = starting_index; - i != starting_index || first_available_index < 0; - i = (i + 1) & capacity_mask) - { - uint8_t t = tags[i]; - - if ((t & kHasValue) == 0 && first_available_index < 0) - { - first_available_index = i; - } + ASL_ASSERT(result.first_available_index >= 0); - if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value)) - { - ASL_ASSERT(already_present_index < 0); - already_present_index = i; - if (first_available_index < 0) - { - first_available_index = i; - } - break; - } - - if (t == kEmpty) { break; } - } - - ASL_ASSERT(first_available_index >= 0 || already_present_index >= 0); - - if (already_present_index == first_available_index) + if (result.already_present_index != result.first_available_index) { - ASL_ASSERT((tags[already_present_index] & kHasValue) != 0); - values[already_present_index].assign_unsafe(ASL_MOVE(value)); - } - else - { - ASL_ASSERT((tags[first_available_index] & kHasValue) == 0); - if (already_present_index >= 0) + ASL_ASSERT((tags[result.first_available_index] & kHasValue) == 0); + if (result.already_present_index >= 0) { - ASL_ASSERT((tags[already_present_index] & kHasValue) != 0); - values[already_present_index].destroy_unsafe(); - tags[already_present_index] = kTombstone; + ASL_ASSERT((tags[result.already_present_index] & kHasValue) != 0); + values[result.already_present_index].destroy_unsafe(); + tags[result.already_present_index] = kTombstone; } else { *size += 1; } - values[first_available_index].construct_unsafe(ASL_MOVE(value)); - tags[first_available_index] = tag; + values[result.first_available_index].construct_unsafe(ASL_MOVE(value)); + tags[result.first_available_index] = result.tag; } // NOLINTEND(*-pointer-arithmetic) } - - template - requires key_hasher && key_comparator - isize_t find_slot(const U& value) const - { - if (m_size <= 0) { return -1; }; - - ASL_ASSERT(is_pow2(m_capacity)); - - const isize_t capacity_mask = m_capacity - 1; - const uint64_t hash = KeyHasher::hash(value); - const uint8_t tag = static_cast(hash & kHashMask) | kHasValue; - const auto starting_index = static_cast(hash >> 7) & capacity_mask; - - // NOLINTBEGIN(*-pointer-arithmetic) - - isize_t i = starting_index; - do - { - const uint8_t t = m_tags[i]; - - if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; } - if (t == kEmpty) { break; } - - i = (i + 1) & capacity_mask; - } while (i != starting_index); - - // NOLINTEND(*-pointer-arithmetic) - - return -1; - } void grow_and_rehash() { @@ -266,6 +198,110 @@ class hash_set } } + struct FindSlotResult + { + uint8_t tag{}; + isize_t first_available_index = -1; + isize_t already_present_index = -1; + }; + + template + requires key_hasher && key_comparator + static FindSlotResult find_slot_insert( + const U& value, + const uint8_t* tags, + const maybe_uninit* values, + isize_t capacity) + { + ASL_ASSERT(is_pow2(capacity)); + + FindSlotResult result{}; + + const isize_t capacity_mask = capacity - 1; + const uint64_t hash = KeyHasher::hash(value); + const auto starting_index = static_cast(hash >> 7) & capacity_mask; + + result.tag = static_cast(hash & kHashMask) | kHasValue; + + // NOLINTBEGIN(*-pointer-arithmetic) + + for ( + isize_t i = starting_index; + i != starting_index || result.first_available_index < 0; + i = (i + 1) & capacity_mask) + { + uint8_t t = tags[i]; + + if ((t & kHasValue) == 0 && result.first_available_index < 0) + { + result.first_available_index = i; + } + + if (t == result.tag && KeyComparator::eq(values[i].as_init_unsafe(), value)) + { + ASL_ASSERT(result.already_present_index < 0); + result.already_present_index = i; + if (result.first_available_index < 0) + { + result.first_available_index = i; + } + break; + } + + if (t == kEmpty) { break; } + } + + // NOLINTEND(*-pointer-arithmetic) + + return result; + } + + template + requires key_hasher && key_comparator + isize_t find_slot_lookup(const U& value) const + { + if (m_size <= 0) { return -1; }; + + ASL_ASSERT(is_pow2(m_capacity)); + + const isize_t capacity_mask = m_capacity - 1; + const uint64_t hash = KeyHasher::hash(value); + const uint8_t tag = static_cast(hash & kHashMask) | kHasValue; + const auto starting_index = static_cast(hash >> 7) & capacity_mask; + + // NOLINTBEGIN(*-pointer-arithmetic) + + isize_t i = starting_index; + do + { + const uint8_t t = m_tags[i]; + + if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; } + if (t == kEmpty) { break; } + + i = (i + 1) & capacity_mask; + } while (i != starting_index); + + // NOLINTEND(*-pointer-arithmetic) + + return -1; + } + + template + requires key_hasher && key_comparator + FindSlotResult find_slot_insert(const U& value) + { + return find_slot_insert(value, m_tags, m_values, m_capacity); + } + + void maybe_grow_to_fit_one_more() + { + if (m_size >= max_size()) + { + grow_and_rehash(); + } + } + public: constexpr hash_set() requires default_constructible : m_allocator{} @@ -351,10 +387,7 @@ public: void insert(Args&&... args) requires constructible_from { - if (m_size >= max_size()) - { - grow_and_rehash(); - } + maybe_grow_to_fit_one_more(); ASL_ASSERT(m_size < max_size()); insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size); } @@ -363,14 +396,14 @@ public: requires key_hasher && key_comparator bool contains(const U& value) const { - return find_slot(value) >= 0; + return find_slot_lookup(value) >= 0; } template requires key_hasher && key_comparator bool remove(const U& value) { - isize_t slot = find_slot(value); + isize_t slot = find_slot_lookup(value); if (slot < 0) { return false; } m_values[slot].destroy_unsafe(); // NOLINT(*-pointer-arithmetic) diff --git a/asl/tests/hash_map_tests.cpp b/asl/tests/hash_map_tests.cpp new file mode 100644 index 0000000..53c419c --- /dev/null +++ b/asl/tests/hash_map_tests.cpp @@ -0,0 +1,48 @@ +#include "asl/testing/testing.hpp" +#include "asl/hash_map.hpp" + +ASL_TEST(default) +{ + asl::hash_map map; + + ASL_TEST_EXPECT(!map.contains(45)); + ASL_TEST_EXPECT(!map.contains(46)); + + map.insert(45, 5); + map.insert(46, 6); + + ASL_TEST_EXPECT(map.size() == 2); + + ASL_TEST_EXPECT(map.contains(45)); + ASL_TEST_EXPECT(map.contains(46)); + ASL_TEST_EXPECT(!map.contains(47)); + + ASL_TEST_EXPECT(*map.get(45) == 5); + ASL_TEST_EXPECT(*map.get(46) == 6); + ASL_TEST_EXPECT(map.get(47) == nullptr); + + ASL_TEST_EXPECT(map.remove(45)); + ASL_TEST_EXPECT(!map.remove(45)); + + ASL_TEST_EXPECT(map.size() == 1); + + ASL_TEST_EXPECT(!map.contains(45)); + ASL_TEST_EXPECT(map.contains(46)); + ASL_TEST_EXPECT(!map.contains(47)); + + ASL_TEST_EXPECT(map.get(45) == nullptr); + ASL_TEST_EXPECT(*map.get(46) == 6); + ASL_TEST_EXPECT(map.get(47) == nullptr); + + map.insert(46, 460); + + ASL_TEST_EXPECT(map.size() == 1); + + ASL_TEST_EXPECT(!map.contains(45)); + ASL_TEST_EXPECT(map.contains(46)); + ASL_TEST_EXPECT(!map.contains(47)); + + ASL_TEST_EXPECT(map.get(45) == nullptr); + ASL_TEST_EXPECT(*map.get(46) == 460); + ASL_TEST_EXPECT(map.get(47) == nullptr); +} diff --git a/asl/utility.hpp b/asl/utility.hpp index 6a43852..205b583 100644 --- a/asl/utility.hpp +++ b/asl/utility.hpp @@ -4,7 +4,7 @@ #include "asl/layout.hpp" #include "asl/assert.hpp" -#define ASL_MOVE(expr_) (static_cast<::asl::un_ref_t&&>(expr_)) +#define ASL_MOVE(...) (static_cast<::asl::un_ref_t&&>(__VA_ARGS__)) #define ASL_FWD(expr_) (static_cast(expr_)) -- cgit