Add custom hasher & comparator for hash_set keys

This commit is contained in:
2025-01-14 23:30:53 +01:00
parent 5f21ebf42e
commit 83b856b7d4
2 changed files with 77 additions and 20 deletions

View File

@ -11,8 +11,43 @@
namespace asl namespace asl
{ {
template<is_object T, allocator Allocator = DefaultAllocator> template<typename H, typename T>
requires hashable<T> && move_constructible<T> && move_assignable<T> && equality_comparable<T> concept key_hasher = requires (const T& value)
{
{ H::hash(value) } -> same_as<uint64_t>;
};
template<hashable T>
struct default_key_hasher
{
constexpr static uint64_t hash(const T& value)
{
return hash_value(value);
}
};
template<typename C, typename U, typename V = U>
concept key_comparator = requires(const U& a, const V& b)
{
{ C::eq(a, b) } -> same_as<bool>;
};
template<equality_comparable T>
struct default_key_comparator
{
constexpr static bool eq(const T& a, const T& b)
{
return a == b;
}
};
template<
is_object T,
allocator Allocator = DefaultAllocator,
key_hasher<T> KeyHasher = default_key_hasher<T>,
key_comparator<T> KeyComparator = default_key_comparator<T>
>
requires move_constructible<T> && move_assignable<T>
class hash_set class hash_set
{ {
static constexpr uint8_t kHasValue = 0x80; static constexpr uint8_t kHasValue = 0x80;
@ -49,7 +84,7 @@ class hash_set
ASL_ASSERT(is_pow2(capacity)); ASL_ASSERT(is_pow2(capacity));
const isize_t capacity_mask = capacity - 1; const isize_t capacity_mask = capacity - 1;
const uint64_t hash = hash_value(value); const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue; const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask; const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@ -70,7 +105,7 @@ class hash_set
first_available_index = i; first_available_index = i;
} }
if (t == tag && values[i].as_init_unsafe() == value) if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
{ {
ASL_ASSERT(already_present_index < 0); ASL_ASSERT(already_present_index < 0);
already_present_index = i; already_present_index = i;
@ -112,14 +147,16 @@ class hash_set
// NOLINTEND(*-pointer-arithmetic) // NOLINTEND(*-pointer-arithmetic)
} }
isize_t find_slot(const T& value) const template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
isize_t find_slot(const U& value) const
{ {
if (m_size <= 0) { return -1; }; if (m_size <= 0) { return -1; };
ASL_ASSERT(is_pow2(m_capacity)); ASL_ASSERT(is_pow2(m_capacity));
const isize_t capacity_mask = m_capacity - 1; const isize_t capacity_mask = m_capacity - 1;
const uint64_t hash = hash_value(value); const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue; const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask; const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@ -130,7 +167,7 @@ class hash_set
{ {
const uint8_t t = m_tags[i]; const uint8_t t = m_tags[i];
if (t == tag && m_values[i].as_init_unsafe() == value) { return i; } if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
if (t == kEmpty) { break; } if (t == kEmpty) { break; }
i = (i + 1) & capacity_mask; i = (i + 1) & capacity_mask;
@ -246,14 +283,16 @@ public:
insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size); insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
} }
bool contains(const T& value) const template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool contains(const U& value) const
{ {
return find_slot(value) >= 0; return find_slot(value) >= 0;
} }
// @Todo Remove with something comparable, but not equal? How to hash? template<typename U>
// @Todo Same with contains requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool remove(const T& value) bool remove(const U& value)
{ {
isize_t slot = find_slot(value); isize_t slot = find_slot(value);
if (slot < 0) { return false; } if (slot < 0) { return false; }

View File

@ -32,7 +32,6 @@ ASL_TEST(a_bunch_of_strings)
ASL_TEST_EXPECT(set.contains("Hello, world!"_sv)); ASL_TEST_EXPECT(set.contains("Hello, world!"_sv));
ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv)); ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv));
ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv)); ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv));
} }
ASL_TEST(a_bunch_of_ints) ASL_TEST(a_bunch_of_ints)
@ -67,11 +66,31 @@ struct HashWithDestructor: public DestructorObserver
{ {
return x == other.x; return x == other.x;
} }
};
template<typename H> struct CustomComparator
friend H AslHashValue(H h, const HashWithDestructor& value) {
static bool eq(const HashWithDestructor& a, const HashWithDestructor& b)
{ {
return H::combine(ASL_MOVE(h), value.x); return a.x == b.x;
}
static bool eq(const HashWithDestructor& a, int b)
{
return a.x == b;
}
};
struct CustomHasher
{
static uint64_t hash(const HashWithDestructor& b)
{
return asl::hash_value(b.x);
}
static uint64_t hash(int x)
{
return asl::hash_value(x);
} }
}; };
@ -81,7 +100,7 @@ ASL_TEST(destructor_and_remove)
bool destroyed[kCount]{}; bool destroyed[kCount]{};
{ {
asl::hash_set<HashWithDestructor> set; asl::hash_set<HashWithDestructor, asl::DefaultAllocator, CustomHasher, CustomComparator> set;
for (int i = 0; i < kCount; ++i) for (int i = 0; i < kCount; ++i)
{ {
@ -97,14 +116,13 @@ ASL_TEST(destructor_and_remove)
for (int i = 0; i < kCount; i += 2) for (int i = 0; i < kCount; i += 2)
{ {
// @Todo Remove with something comparable ASL_TEST_EXPECT(set.remove(i));
ASL_TEST_EXPECT(set.remove(HashWithDestructor{i, nullptr}));
} }
for (int i = 0; i < kCount; i += 2) for (int i = 0; i < kCount; i += 2)
{ {
ASL_TEST_EXPECT(!set.contains(HashWithDestructor{i, nullptr})); ASL_TEST_EXPECT(!set.contains(i));
ASL_TEST_EXPECT(set.contains(HashWithDestructor{i+1, nullptr})); ASL_TEST_EXPECT(set.contains(i+1));
ASL_TEST_EXPECT(destroyed[i]); // NOLINT ASL_TEST_EXPECT(destroyed[i]); // NOLINT
ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT
} }