From 74082720c42c5d6b06b71cefbad4b794ff1b8c3c Mon Sep 17 00:00:00 2001
From: Steven Le Rouzic <steven.lerouzic@gmail.com>
Date: Sat, 18 Jan 2025 19:59:36 +0100
Subject: Finish the hash_map

---
 asl/hash_set.hpp | 205 ++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 119 insertions(+), 86 deletions(-)

(limited to 'asl/hash_set.hpp')

diff --git a/asl/hash_set.hpp b/asl/hash_set.hpp
index c3fb38d..979235d 100644
--- a/asl/hash_set.hpp
+++ b/asl/hash_set.hpp
@@ -50,6 +50,7 @@ template<
 requires moveable<T>
 class hash_set
 {
+protected:
     static constexpr uint8_t kHasValue  = 0x80;
     static constexpr uint8_t kHashMask  = 0x7f;
     static constexpr uint8_t kEmpty     = 0x00;
@@ -80,7 +81,7 @@ class hash_set
             kMinCapacity,
             static_cast<isize_t>(round_up_pow2((static_cast<uint64_t>(size) * 4 + 2) / 3)));
     }
-
+    
     static void insert_inner(
         T&& value,
         uint8_t* tags,
@@ -89,102 +90,33 @@ class hash_set
         isize_t* size)
     {
         ASL_ASSERT(*size < capacity);
-        ASL_ASSERT(is_pow2(capacity));
-
-        const isize_t capacity_mask = capacity - 1;
-        const uint64_t hash = KeyHasher::hash(value);
-        const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
-        const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
 
-        isize_t first_available_index = -1;
-        isize_t already_present_index = -1;
+        const auto result = find_slot_insert(value, tags, values, capacity);
         
         // NOLINTBEGIN(*-pointer-arithmetic)
 
-        for (
-            isize_t i = starting_index;
-            i != starting_index || first_available_index < 0;
-            i = (i + 1) & capacity_mask)
-        {
-            uint8_t t = tags[i];
-            
-            if ((t & kHasValue) == 0 && first_available_index < 0)
-            {
-                first_available_index = i;
-            }
+        ASL_ASSERT(result.first_available_index >= 0);
 
-            if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
-            {
-                ASL_ASSERT(already_present_index < 0);
-                already_present_index = i;
-                if (first_available_index < 0)
-                {
-                    first_available_index = i;
-                }
-                break;
-            }
-
-            if (t == kEmpty) { break; }
-        }
-
-        ASL_ASSERT(first_available_index >= 0 || already_present_index >= 0);
-
-        if (already_present_index == first_available_index)
+        if (result.already_present_index != result.first_available_index)
         {
-            ASL_ASSERT((tags[already_present_index] & kHasValue) != 0);
-            values[already_present_index].assign_unsafe(ASL_MOVE(value));
-        }
-        else
-        {
-            ASL_ASSERT((tags[first_available_index] & kHasValue) == 0);
-            if (already_present_index >= 0)
+            ASL_ASSERT((tags[result.first_available_index] & kHasValue) == 0);
+            if (result.already_present_index >= 0)
             {
-                ASL_ASSERT((tags[already_present_index] & kHasValue) != 0);
-                values[already_present_index].destroy_unsafe();
-                tags[already_present_index] = kTombstone;
+                ASL_ASSERT((tags[result.already_present_index] & kHasValue) != 0);
+                values[result.already_present_index].destroy_unsafe();
+                tags[result.already_present_index] = kTombstone;
             }
             else
             {
                 *size += 1;
             }
 
-            values[first_available_index].construct_unsafe(ASL_MOVE(value));
-            tags[first_available_index] = tag;
+            values[result.first_available_index].construct_unsafe(ASL_MOVE(value));
+            tags[result.first_available_index] = result.tag;
         }
         
         // NOLINTEND(*-pointer-arithmetic)
     }
-
-    template<typename U>
-    requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
-    isize_t find_slot(const U& value) const
-    {
-        if (m_size <= 0) { return -1; };
-
-        ASL_ASSERT(is_pow2(m_capacity));
-
-        const isize_t capacity_mask = m_capacity - 1;
-        const uint64_t hash = KeyHasher::hash(value);
-        const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
-        const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
-
-        // NOLINTBEGIN(*-pointer-arithmetic)
-
-        isize_t i = starting_index;
-        do
-        {
-            const uint8_t t = m_tags[i];
-            
-            if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
-            if (t == kEmpty) { break; }
-            
-            i = (i + 1) & capacity_mask;
-        } while (i != starting_index);
-        
-        // NOLINTEND(*-pointer-arithmetic)
-
-        return -1;
-    }
     
     void grow_and_rehash()
     {
@@ -266,6 +198,110 @@ class hash_set
         }
     }
 
+    struct FindSlotResult
+    {
+        uint8_t tag{};
+        isize_t first_available_index = -1;
+        isize_t already_present_index = -1;
+    };
+
+    template<typename U>
+    requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+    static FindSlotResult find_slot_insert(
+        const U& value,
+        const uint8_t* tags,
+        const maybe_uninit<T>* values,
+        isize_t capacity)
+    {
+        ASL_ASSERT(is_pow2(capacity));
+
+        FindSlotResult result{};
+
+        const isize_t capacity_mask = capacity - 1;
+        const uint64_t hash = KeyHasher::hash(value);
+        const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
+
+        result.tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
+        
+        // NOLINTBEGIN(*-pointer-arithmetic)
+
+        for (
+            isize_t i = starting_index;
+            i != starting_index || result.first_available_index < 0;
+            i = (i + 1) & capacity_mask)
+        {
+            uint8_t t = tags[i];
+            
+            if ((t & kHasValue) == 0 && result.first_available_index < 0)
+            {
+                result.first_available_index = i;
+            }
+
+            if (t == result.tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
+            {
+                ASL_ASSERT(result.already_present_index < 0);
+                result.already_present_index = i;
+                if (result.first_available_index < 0)
+                {
+                    result.first_available_index = i;
+                }
+                break;
+            }
+
+            if (t == kEmpty) { break; }
+        }
+        
+        // NOLINTEND(*-pointer-arithmetic)
+
+        return result;
+    }
+
+    template<typename U>
+    requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+    isize_t find_slot_lookup(const U& value) const
+    {
+        if (m_size <= 0) { return -1; };
+
+        ASL_ASSERT(is_pow2(m_capacity));
+
+        const isize_t capacity_mask = m_capacity - 1;
+        const uint64_t hash = KeyHasher::hash(value);
+        const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
+        const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
+
+        // NOLINTBEGIN(*-pointer-arithmetic)
+
+        isize_t i = starting_index;
+        do
+        {
+            const uint8_t t = m_tags[i];
+            
+            if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
+            if (t == kEmpty) { break; }
+            
+            i = (i + 1) & capacity_mask;
+        } while (i != starting_index);
+        
+        // NOLINTEND(*-pointer-arithmetic)
+
+        return -1;
+    }
+
+    template<typename U>
+    requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+    FindSlotResult find_slot_insert(const U& value)
+    {
+        return find_slot_insert(value, m_tags, m_values, m_capacity);
+    }
+
+    void maybe_grow_to_fit_one_more()
+    {
+        if (m_size >= max_size())
+        {
+            grow_and_rehash();
+        }
+    }
+    
 public:
     constexpr hash_set() requires default_constructible<Allocator>
         : m_allocator{}
@@ -351,10 +387,7 @@ public:
     void insert(Args&&... args)
         requires constructible_from<T, Args&&...>
     {
-        if (m_size >= max_size())
-        {
-            grow_and_rehash();
-        }
+        maybe_grow_to_fit_one_more();
         ASL_ASSERT(m_size < max_size());
         insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
     }
@@ -363,14 +396,14 @@ public:
     requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
     bool contains(const U& value) const
     {
-        return find_slot(value) >= 0;
+        return find_slot_lookup(value) >= 0;
     }
 
     template<typename U>
     requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
     bool remove(const U& value)
     {
-        isize_t slot = find_slot(value);
+        isize_t slot = find_slot_lookup(value);
         if (slot < 0) { return false; }
 
         m_values[slot].destroy_unsafe(); // NOLINT(*-pointer-arithmetic)
-- 
cgit