Add custom hasher & comparator for hash_set keys

This commit is contained in:
2025-01-14 23:30:53 +01:00
parent 5f21ebf42e
commit 83b856b7d4
2 changed files with 77 additions and 20 deletions

View File

@ -11,8 +11,43 @@
namespace asl
{
template<is_object T, allocator Allocator = DefaultAllocator>
requires hashable<T> && move_constructible<T> && move_assignable<T> && equality_comparable<T>
template<typename H, typename T>
concept key_hasher = requires (const T& value)
{
{ H::hash(value) } -> same_as<uint64_t>;
};
template<hashable T>
struct default_key_hasher
{
constexpr static uint64_t hash(const T& value)
{
return hash_value(value);
}
};
template<typename C, typename U, typename V = U>
concept key_comparator = requires(const U& a, const V& b)
{
{ C::eq(a, b) } -> same_as<bool>;
};
template<equality_comparable T>
struct default_key_comparator
{
constexpr static bool eq(const T& a, const T& b)
{
return a == b;
}
};
template<
is_object T,
allocator Allocator = DefaultAllocator,
key_hasher<T> KeyHasher = default_key_hasher<T>,
key_comparator<T> KeyComparator = default_key_comparator<T>
>
requires move_constructible<T> && move_assignable<T>
class hash_set
{
static constexpr uint8_t kHasValue = 0x80;
@ -49,7 +84,7 @@ class hash_set
ASL_ASSERT(is_pow2(capacity));
const isize_t capacity_mask = capacity - 1;
const uint64_t hash = hash_value(value);
const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@ -70,7 +105,7 @@ class hash_set
first_available_index = i;
}
if (t == tag && values[i].as_init_unsafe() == value)
if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
{
ASL_ASSERT(already_present_index < 0);
already_present_index = i;
@ -112,14 +147,16 @@ class hash_set
// NOLINTEND(*-pointer-arithmetic)
}
isize_t find_slot(const T& value) const
template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
isize_t find_slot(const U& value) const
{
if (m_size <= 0) { return -1; };
ASL_ASSERT(is_pow2(m_capacity));
const isize_t capacity_mask = m_capacity - 1;
const uint64_t hash = hash_value(value);
const uint64_t hash = KeyHasher::hash(value);
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
@ -130,7 +167,7 @@ class hash_set
{
const uint8_t t = m_tags[i];
if (t == tag && m_values[i].as_init_unsafe() == value) { return i; }
if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
if (t == kEmpty) { break; }
i = (i + 1) & capacity_mask;
@ -246,14 +283,16 @@ public:
insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
}
bool contains(const T& value) const
template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool contains(const U& value) const
{
return find_slot(value) >= 0;
}
// @Todo Remove with something comparable, but not equal? How to hash?
// @Todo Same with contains
bool remove(const T& value)
template<typename U>
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
bool remove(const U& value)
{
isize_t slot = find_slot(value);
if (slot < 0) { return false; }

View File

@ -32,7 +32,6 @@ ASL_TEST(a_bunch_of_strings)
ASL_TEST_EXPECT(set.contains("Hello, world!"_sv));
ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv));
ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv));
}
ASL_TEST(a_bunch_of_ints)
@ -67,11 +66,31 @@ struct HashWithDestructor: public DestructorObserver
{
return x == other.x;
}
};
template<typename H>
friend H AslHashValue(H h, const HashWithDestructor& value)
struct CustomComparator
{
static bool eq(const HashWithDestructor& a, const HashWithDestructor& b)
{
return H::combine(ASL_MOVE(h), value.x);
return a.x == b.x;
}
static bool eq(const HashWithDestructor& a, int b)
{
return a.x == b;
}
};
struct CustomHasher
{
static uint64_t hash(const HashWithDestructor& b)
{
return asl::hash_value(b.x);
}
static uint64_t hash(int x)
{
return asl::hash_value(x);
}
};
@ -81,7 +100,7 @@ ASL_TEST(destructor_and_remove)
bool destroyed[kCount]{};
{
asl::hash_set<HashWithDestructor> set;
asl::hash_set<HashWithDestructor, asl::DefaultAllocator, CustomHasher, CustomComparator> set;
for (int i = 0; i < kCount; ++i)
{
@ -97,14 +116,13 @@ ASL_TEST(destructor_and_remove)
for (int i = 0; i < kCount; i += 2)
{
// @Todo Remove with something comparable
ASL_TEST_EXPECT(set.remove(HashWithDestructor{i, nullptr}));
ASL_TEST_EXPECT(set.remove(i));
}
for (int i = 0; i < kCount; i += 2)
{
ASL_TEST_EXPECT(!set.contains(HashWithDestructor{i, nullptr}));
ASL_TEST_EXPECT(set.contains(HashWithDestructor{i+1, nullptr}));
ASL_TEST_EXPECT(!set.contains(i));
ASL_TEST_EXPECT(set.contains(i+1));
ASL_TEST_EXPECT(destroyed[i]); // NOLINT
ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT
}