diff options
author | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-01-14 23:30:53 +0100 |
---|---|---|
committer | Steven Le Rouzic <steven.lerouzic@gmail.com> | 2025-01-14 23:30:53 +0100 |
commit | 83b856b7d42deba868608f323a3cec4ae6a17d90 (patch) | |
tree | 232e45525b31d8f2849c35110ce76df11bedccd7 /asl | |
parent | 5f21ebf42e670470b315a992b8a60f7c2e2bbbeb (diff) |
Add custom hasher & comparator for hash_set keys
Diffstat (limited to 'asl')
-rw-r--r-- | asl/hash_set.hpp | 61 | ||||
-rw-r--r-- | asl/tests/hash_set_tests.cpp | 36 |
2 files changed, 77 insertions, 20 deletions
diff --git a/asl/hash_set.hpp b/asl/hash_set.hpp index fa95a96..d48a865 100644 --- a/asl/hash_set.hpp +++ b/asl/hash_set.hpp @@ -11,8 +11,43 @@ namespace asl
{
-template<is_object T, allocator Allocator = DefaultAllocator>
-requires hashable<T> && move_constructible<T> && move_assignable<T> && equality_comparable<T>
+template<typename H, typename T>
+concept key_hasher = requires (const T& value)
+{
+ { H::hash(value) } -> same_as<uint64_t>;
+};
+
+template<hashable T>
+struct default_key_hasher
+{
+ constexpr static uint64_t hash(const T& value)
+ {
+ return hash_value(value);
+ }
+};
+
+template<typename C, typename U, typename V = U>
+concept key_comparator = requires(const U& a, const V& b)
+{
+ { C::eq(a, b) } -> same_as<bool>;
+};
+
+template<equality_comparable T>
+struct default_key_comparator
+{
+ constexpr static bool eq(const T& a, const T& b)
+ {
+ return a == b;
+ }
+};
+
+template<
+ is_object T,
+ allocator Allocator = DefaultAllocator,
+ key_hasher<T> KeyHasher = default_key_hasher<T>,
+ key_comparator<T> KeyComparator = default_key_comparator<T>
+>
+requires move_constructible<T> && move_assignable<T>
class hash_set
{
static constexpr uint8_t kHasValue = 0x80;
@@ -49,7 +84,7 @@ class hash_set ASL_ASSERT(is_pow2(capacity));
const isize_t capacity_mask = capacity - 1;
- const uint64_t hash = hash_value(value);
+ const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@@ -70,7 +105,7 @@ class hash_set first_available_index = i;
}
- if (t == tag && values[i].as_init_unsafe() == value)
+ if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
{
ASL_ASSERT(already_present_index < 0);
already_present_index = i;
@@ -112,14 +147,16 @@ class hash_set // NOLINTEND(*-pointer-arithmetic)
}
- isize_t find_slot(const T& value) const
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ isize_t find_slot(const U& value) const
{
if (m_size <= 0) { return -1; };
ASL_ASSERT(is_pow2(m_capacity));
const isize_t capacity_mask = m_capacity - 1;
- const uint64_t hash = hash_value(value);
+ const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@@ -130,7 +167,7 @@ class hash_set {
const uint8_t t = m_tags[i];
- if (t == tag && m_values[i].as_init_unsafe() == value) { return i; }
+ if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
if (t == kEmpty) { break; }
i = (i + 1) & capacity_mask;
@@ -246,14 +283,16 @@ public: insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
}
- bool contains(const T& value) const
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ bool contains(const U& value) const
{
return find_slot(value) >= 0;
}
- // @Todo Remove with something comparable, but not equal? How to hash?
- // @Todo Same with contains
- bool remove(const T& value)
+ template<typename U>
+ requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
+ bool remove(const U& value)
{
isize_t slot = find_slot(value);
if (slot < 0) { return false; }
diff --git a/asl/tests/hash_set_tests.cpp b/asl/tests/hash_set_tests.cpp index e6a020a..9df9463 100644 --- a/asl/tests/hash_set_tests.cpp +++ b/asl/tests/hash_set_tests.cpp @@ -32,7 +32,6 @@ ASL_TEST(a_bunch_of_strings) ASL_TEST_EXPECT(set.contains("Hello, world!"_sv));
ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv));
ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv));
-
}
ASL_TEST(a_bunch_of_ints)
@@ -67,11 +66,31 @@ struct HashWithDestructor: public DestructorObserver {
return x == other.x;
}
+};
+
+struct CustomComparator
+{
+ static bool eq(const HashWithDestructor& a, const HashWithDestructor& b)
+ {
+ return a.x == b.x;
+ }
+
+ static bool eq(const HashWithDestructor& a, int b)
+ {
+ return a.x == b;
+ }
+};
- template<typename H>
- friend H AslHashValue(H h, const HashWithDestructor& value)
+struct CustomHasher
+{
+ static uint64_t hash(const HashWithDestructor& b)
+ {
+ return asl::hash_value(b.x);
+ }
+
+ static uint64_t hash(int x)
{
- return H::combine(ASL_MOVE(h), value.x);
+ return asl::hash_value(x);
}
};
@@ -81,7 +100,7 @@ ASL_TEST(destructor_and_remove) bool destroyed[kCount]{};
{
- asl::hash_set<HashWithDestructor> set;
+ asl::hash_set<HashWithDestructor, asl::DefaultAllocator, CustomHasher, CustomComparator> set;
for (int i = 0; i < kCount; ++i)
{
@@ -97,14 +116,13 @@ ASL_TEST(destructor_and_remove) for (int i = 0; i < kCount; i += 2)
{
- // @Todo Remove with something comparable
- ASL_TEST_EXPECT(set.remove(HashWithDestructor{i, nullptr}));
+ ASL_TEST_EXPECT(set.remove(i));
}
for (int i = 0; i < kCount; i += 2)
{
- ASL_TEST_EXPECT(!set.contains(HashWithDestructor{i, nullptr}));
- ASL_TEST_EXPECT(set.contains(HashWithDestructor{i+1, nullptr}));
+ ASL_TEST_EXPECT(!set.contains(i));
+ ASL_TEST_EXPECT(set.contains(i+1));
ASL_TEST_EXPECT(destroyed[i]); // NOLINT
ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT
}
|