Add custom hasher & comparator for hash_set keys
This commit is contained in:
@ -11,8 +11,43 @@
|
|||||||
namespace asl
|
namespace asl
|
||||||
{
|
{
|
||||||
|
|
||||||
template<is_object T, allocator Allocator = DefaultAllocator>
|
template<typename H, typename T>
|
||||||
requires hashable<T> && move_constructible<T> && move_assignable<T> && equality_comparable<T>
|
concept key_hasher = requires (const T& value)
|
||||||
|
{
|
||||||
|
{ H::hash(value) } -> same_as<uint64_t>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<hashable T>
|
||||||
|
struct default_key_hasher
|
||||||
|
{
|
||||||
|
constexpr static uint64_t hash(const T& value)
|
||||||
|
{
|
||||||
|
return hash_value(value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename C, typename U, typename V = U>
|
||||||
|
concept key_comparator = requires(const U& a, const V& b)
|
||||||
|
{
|
||||||
|
{ C::eq(a, b) } -> same_as<bool>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<equality_comparable T>
|
||||||
|
struct default_key_comparator
|
||||||
|
{
|
||||||
|
constexpr static bool eq(const T& a, const T& b)
|
||||||
|
{
|
||||||
|
return a == b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<
|
||||||
|
is_object T,
|
||||||
|
allocator Allocator = DefaultAllocator,
|
||||||
|
key_hasher<T> KeyHasher = default_key_hasher<T>,
|
||||||
|
key_comparator<T> KeyComparator = default_key_comparator<T>
|
||||||
|
>
|
||||||
|
requires move_constructible<T> && move_assignable<T>
|
||||||
class hash_set
|
class hash_set
|
||||||
{
|
{
|
||||||
static constexpr uint8_t kHasValue = 0x80;
|
static constexpr uint8_t kHasValue = 0x80;
|
||||||
@ -49,7 +84,7 @@ class hash_set
|
|||||||
ASL_ASSERT(is_pow2(capacity));
|
ASL_ASSERT(is_pow2(capacity));
|
||||||
|
|
||||||
const isize_t capacity_mask = capacity - 1;
|
const isize_t capacity_mask = capacity - 1;
|
||||||
const uint64_t hash = hash_value(value);
|
const uint64_t hash = KeyHasher::hash(value);
|
||||||
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
|
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
|
||||||
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
|
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
|
||||||
|
|
||||||
@ -70,7 +105,7 @@ class hash_set
|
|||||||
first_available_index = i;
|
first_available_index = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (t == tag && values[i].as_init_unsafe() == value)
|
if (t == tag && KeyComparator::eq(values[i].as_init_unsafe(), value))
|
||||||
{
|
{
|
||||||
ASL_ASSERT(already_present_index < 0);
|
ASL_ASSERT(already_present_index < 0);
|
||||||
already_present_index = i;
|
already_present_index = i;
|
||||||
@ -112,14 +147,16 @@ class hash_set
|
|||||||
// NOLINTEND(*-pointer-arithmetic)
|
// NOLINTEND(*-pointer-arithmetic)
|
||||||
}
|
}
|
||||||
|
|
||||||
isize_t find_slot(const T& value) const
|
template<typename U>
|
||||||
|
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
|
||||||
|
isize_t find_slot(const U& value) const
|
||||||
{
|
{
|
||||||
if (m_size <= 0) { return -1; };
|
if (m_size <= 0) { return -1; };
|
||||||
|
|
||||||
ASL_ASSERT(is_pow2(m_capacity));
|
ASL_ASSERT(is_pow2(m_capacity));
|
||||||
|
|
||||||
const isize_t capacity_mask = m_capacity - 1;
|
const isize_t capacity_mask = m_capacity - 1;
|
||||||
const uint64_t hash = hash_value(value);
|
const uint64_t hash = KeyHasher::hash(value);
|
||||||
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
|
const uint8_t tag = static_cast<uint8_t>(hash & kHashMask) | kHasValue;
|
||||||
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
|
const auto starting_index = static_cast<isize_t>(hash >> 7) & capacity_mask;
|
||||||
|
|
||||||
@ -130,7 +167,7 @@ class hash_set
|
|||||||
{
|
{
|
||||||
const uint8_t t = m_tags[i];
|
const uint8_t t = m_tags[i];
|
||||||
|
|
||||||
if (t == tag && m_values[i].as_init_unsafe() == value) { return i; }
|
if (t == tag && KeyComparator::eq(m_values[i].as_init_unsafe(), value)) { return i; }
|
||||||
if (t == kEmpty) { break; }
|
if (t == kEmpty) { break; }
|
||||||
|
|
||||||
i = (i + 1) & capacity_mask;
|
i = (i + 1) & capacity_mask;
|
||||||
@ -246,14 +283,16 @@ public:
|
|||||||
insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
|
insert_inner(ASL_MOVE(T{ASL_FWD(args)...}), m_tags, m_values, m_capacity, &m_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool contains(const T& value) const
|
template<typename U>
|
||||||
|
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
|
||||||
|
bool contains(const U& value) const
|
||||||
{
|
{
|
||||||
return find_slot(value) >= 0;
|
return find_slot(value) >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Todo Remove with something comparable, but not equal? How to hash?
|
template<typename U>
|
||||||
// @Todo Same with contains
|
requires key_hasher<KeyHasher, U> && key_comparator<KeyComparator, T, U>
|
||||||
bool remove(const T& value)
|
bool remove(const U& value)
|
||||||
{
|
{
|
||||||
isize_t slot = find_slot(value);
|
isize_t slot = find_slot(value);
|
||||||
if (slot < 0) { return false; }
|
if (slot < 0) { return false; }
|
||||||
|
@ -32,7 +32,6 @@ ASL_TEST(a_bunch_of_strings)
|
|||||||
ASL_TEST_EXPECT(set.contains("Hello, world!"_sv));
|
ASL_TEST_EXPECT(set.contains("Hello, world!"_sv));
|
||||||
ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv));
|
ASL_TEST_EXPECT(set.contains("Hello, puppy!"_sv));
|
||||||
ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv));
|
ASL_TEST_EXPECT(!set.contains("Hello, Steven!"_sv));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ASL_TEST(a_bunch_of_ints)
|
ASL_TEST(a_bunch_of_ints)
|
||||||
@ -67,11 +66,31 @@ struct HashWithDestructor: public DestructorObserver
|
|||||||
{
|
{
|
||||||
return x == other.x;
|
return x == other.x;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename H>
|
struct CustomComparator
|
||||||
friend H AslHashValue(H h, const HashWithDestructor& value)
|
{
|
||||||
|
static bool eq(const HashWithDestructor& a, const HashWithDestructor& b)
|
||||||
{
|
{
|
||||||
return H::combine(ASL_MOVE(h), value.x);
|
return a.x == b.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool eq(const HashWithDestructor& a, int b)
|
||||||
|
{
|
||||||
|
return a.x == b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CustomHasher
|
||||||
|
{
|
||||||
|
static uint64_t hash(const HashWithDestructor& b)
|
||||||
|
{
|
||||||
|
return asl::hash_value(b.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint64_t hash(int x)
|
||||||
|
{
|
||||||
|
return asl::hash_value(x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -81,7 +100,7 @@ ASL_TEST(destructor_and_remove)
|
|||||||
bool destroyed[kCount]{};
|
bool destroyed[kCount]{};
|
||||||
|
|
||||||
{
|
{
|
||||||
asl::hash_set<HashWithDestructor> set;
|
asl::hash_set<HashWithDestructor, asl::DefaultAllocator, CustomHasher, CustomComparator> set;
|
||||||
|
|
||||||
for (int i = 0; i < kCount; ++i)
|
for (int i = 0; i < kCount; ++i)
|
||||||
{
|
{
|
||||||
@ -97,14 +116,13 @@ ASL_TEST(destructor_and_remove)
|
|||||||
|
|
||||||
for (int i = 0; i < kCount; i += 2)
|
for (int i = 0; i < kCount; i += 2)
|
||||||
{
|
{
|
||||||
// @Todo Remove with something comparable
|
ASL_TEST_EXPECT(set.remove(i));
|
||||||
ASL_TEST_EXPECT(set.remove(HashWithDestructor{i, nullptr}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < kCount; i += 2)
|
for (int i = 0; i < kCount; i += 2)
|
||||||
{
|
{
|
||||||
ASL_TEST_EXPECT(!set.contains(HashWithDestructor{i, nullptr}));
|
ASL_TEST_EXPECT(!set.contains(i));
|
||||||
ASL_TEST_EXPECT(set.contains(HashWithDestructor{i+1, nullptr}));
|
ASL_TEST_EXPECT(set.contains(i+1));
|
||||||
ASL_TEST_EXPECT(destroyed[i]); // NOLINT
|
ASL_TEST_EXPECT(destroyed[i]); // NOLINT
|
||||||
ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT
|
ASL_TEST_EXPECT(!destroyed[i + 1]); // NOLINT
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user