From 83b856b7d42deba868608f323a3cec4ae6a17d90 Mon Sep 17 00:00:00 2001 From: Steven Le Rouzic Date: Tue, 14 Jan 2025 23:30:53 +0100 Subject: Add custom hasher & comparator for hash_set keys --- asl/hash_set.hpp | 61 ++++++++++++++++++++++++++++++++++++-------- asl/tests/hash_set_tests.cpp | 36 +++++++++++++++++++------- 2 files changed, 77 insertions(+), 20 deletions(-) (limited to 'asl') diff --git a/asl/hash_set.hpp b/asl/hash_set.hpp index fa95a96..d48a865 100644 --- a/asl/hash_set.hpp +++ b/asl/hash_set.hpp @@ -11,8 +11,43 @@ namespace asl { -template -requires hashable && move_constructible && move_assignable && equality_comparable +template +concept key_hasher = requires (const T& value) +{ + { H::hash(value) } -> same_as; +}; + +template +struct default_key_hasher +{ + constexpr static uint64_t hash(const T& value) + { + return hash_value(value); + } +}; + +template +concept key_comparator = requires(const U& a, const V& b) +{ + { C::eq(a, b) } -> same_as; +}; + +template +struct default_key_comparator +{ + constexpr static bool eq(const T& a, const T& b) + { + return a == b; + } +}; + +template< + is_object T, + allocator Allocator = DefaultAllocator, + key_hasher KeyHasher = default_key_hasher, + key_comparator KeyComparator = default_key_comparator +> +requires move_constructible && move_assignable class hash_set { static constexpr uint8_t kHasValue = 0x80; @@ -49,7 +84,7 @@ class hash_set ASL_ASSERT(is_pow2(capacity)); const isize_t capacity_mask = capacity - 1; - const uint64_t hash = hash_value(value); + const uint64_t hash = KeyHasher::hash(value); const uint8_t tag = static_cast(hash & kHashMask) | kHasValue; const auto starting_index = static_cast(hash >> 7) & capacity_mask; @@ -70,7 +105,7 @@ class hash_set first_available_index = i; } - if (t == tag && values[i].as_init_unsafe() == value) + if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value)) { ASL_ASSERT(already_present_index < 0); already_present_index = i; @@ -112,14 +147,16 @@ class hash_set // NOLINTEND(*-pointer-arithmetic) } - isize_t find_slot(const T& value) const + template + requires key_hasher && key_comparator + isize_t find_slot(const U& value) const { if (m_size <= 0) { return -1; }; ASL_ASSERT(is_pow2(m_capacity)); const isize_t capacity_mask = m_capacity - 1; - const uint64_t hash = hash_value(value); + const uint64_t hash = KeyHasher::hash(value); const uint8_t tag = static_cast(hash & kHashMask) | kHasValue; const auto starting_index = static_cast(hash >> 7) & capacity_mask; @@ -130,7 +167,7 @@ class hash_set { const uint8_t t = m_tags[i]; - if (t == tag && m_values[i].as_init_unsafe() == value) { return i; } + if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; } if (t == kEmpty) { break; } i = (i + 1) & capacity_mask; @@ -246,14 +283,16 @@ public: insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size); } - bool contains(const T& value) const + template + requires key_hasher && key_comparator + bool contains(const U& value) const { return find_slot(value) >= 0; } - // @Todo Remove with something comparable, but not equal? How to hash? - // @Todo Same with contains - bool remove(const T& value) + template + requires key_hasher && key_comparator + bool remove(const U& value) { isize_t slot = find_slot(value); if (slot < 0) { return false; } diff --git a/asl/tests/hash_set_tests.cpp b/asl/tests/hash_set_tests.cpp index e6a020a..9df9463 100644 --- a/asl/tests/hash_set_tests.cpp +++ b/asl/tests/hash_set_tests.cpp @@ -32,7 +32,6 @@ ASL_TEST(a_bunch_of_strings) ASL_TEST_EXPECT(set.contains("Hello, world!"_sv)); ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv)); ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv)); - } ASL_TEST(a_bunch_of_ints) @@ -67,11 +66,31 @@ struct HashWithDestructor: public DestructorObserver { return x == other.x; } +}; + +struct CustomComparator +{ + static bool eq(const HashWithDestructor& a, const HashWithDestructor& b) + { + return a.x == b.x; + } + + static bool eq(const HashWithDestructor& a, int b) + { + return a.x == b; + } +}; - template - friend H AslHashValue(H h, const HashWithDestructor& value) +struct CustomHasher +{ + static uint64_t hash(const HashWithDestructor& b) + { + return asl::hash_value(b.x); + } + + static uint64_t hash(int x) { - return H::combine(ASL_MOVE(h), value.x); + return asl::hash_value(x); } }; @@ -81,7 +100,7 @@ ASL_TEST(destructor_and_remove) bool destroyed[kCount]{}; { - asl::hash_set set; + asl::hash_set set; for (int i = 0; i < kCount; ++i) { @@ -97,14 +116,13 @@ ASL_TEST(destructor_and_remove) for (int i = 0; i < kCount; i += 2) { - // @Todo Remove with something comparable - ASL_TEST_EXPECT(set.remove(HashWithDestructor{i, nullptr})); + ASL_TEST_EXPECT(set.remove(i)); } for (int i = 0; i < kCount; i += 2) { - ASL_TEST_EXPECT(!set.contains(HashWithDestructor{i, nullptr})); - ASL_TEST_EXPECT(set.contains(HashWithDestructor{i+1, nullptr})); + ASL_TEST_EXPECT(!set.contains(i)); + ASL_TEST_EXPECT(set.contains(i+1)); ASL_TEST_EXPECT(destroyed[i]); // NOLINT ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT } -- cgit