From 4e191f2953e42380726cddd6f999c5bbeb93083d Mon Sep 17 00:00:00 2001 From: Rain Valentine Date: Fri, 28 Feb 2025 20:27:22 +0000 Subject: [PATCH] PR feedback, more focused fixes/refactors Signed-off-by: Rain Valentine --- src/defrag.c | 8 +++++- src/rdb.c | 10 ++++---- src/server.h | 2 +- src/t_hash.c | 72 ++++++++++++++++++++++++++++++++-------------------- src/t_zset.c | 4 +-- 5 files changed, 60 insertions(+), 36 deletions(-) diff --git a/src/defrag.c b/src/defrag.c index 310ef8bc0e..8c42d72f7c 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -38,6 +38,7 @@ #include "eval.h" #include "script.h" #include "module.h" +#include #include #ifdef HAVE_DEFRAG @@ -346,7 +347,7 @@ static void activeDefragSdsDict(dict *d, int val_type) { } while (cursor != 0); } -void activeDefragSdsHashtableCallback(void *privdata, void *entry_ref) { +static void activeDefragSdsHashtableCallback(void *privdata, void *entry_ref) { UNUSED(privdata); sds *sds_ref = (sds *)entry_ref; sds new_sds = activeDefragSds(*sds_ref); @@ -398,6 +399,7 @@ static long scanLaterList(robj *ob, unsigned long *cursor, monotime endtime) { quicklistNode *node; long iterations = 0; int bookmark_failed = 0; + serverAssert(ob->type == OBJ_LIST && ob->encoding == OBJ_ENCODING_QUICKLIST); if (*cursor == 0) { /* if cursor is 0, we start new iteration */ @@ -440,6 +442,7 @@ static void scanLaterZsetCallback(void *privdata, void *element_ref) { } static void scanLaterZset(robj *ob, unsigned long *cursor) { + serverAssert(ob->type == OBJ_ZSET && ob->encoding == OBJ_ENCODING_SKIPLIST); zset *zs = (zset *)ob->ptr; *cursor = hashtableScanDefrag(zs->ht, *cursor, scanLaterZsetCallback, zs->zsl, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } @@ -453,6 +456,7 @@ static void scanHashtableCallbackCountScanned(void *privdata, void *elemref) { } static void scanLaterSet(robj *ob, unsigned long *cursor) { + serverAssert(ob->type == OBJ_SET && ob->encoding == OBJ_ENCODING_HASHTABLE); hashtable *ht = ob->ptr; *cursor = hashtableScanDefrag(ht, *cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } @@ -467,6 +471,7 @@ static void activeDefragHashTypeEntry(void *privdata, void *element_ref) { } static void scanLaterHash(robj *ob, unsigned long *cursor) { + serverAssert(ob->type == OBJ_HASH && ob->encoding == OBJ_ENCODING_HASHTABLE); hashtable *ht = ob->ptr; *cursor = hashtableScanDefrag(ht, *cursor, activeDefragHashTypeEntry, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } @@ -553,6 +558,7 @@ static int scanLaterStreamListpacks(robj *ob, unsigned long *cursor, monotime en static unsigned char last[sizeof(streamID)]; raxIterator ri; long iterations = 0; + serverAssert(ob->type == OBJ_STREAM && ob->encoding == OBJ_ENCODING_STREAM); stream *s = ob->ptr; raxStart(&ri, s->rax); diff --git a/src/rdb.c b/src/rdb.c index 25b0fcbb33..4a62883ede 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -2077,7 +2077,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { /* Too many entries? Use a hash table right from the start. */ if (len > server.hash_max_listpack_entries) - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); else if (deep_integrity_validation) { /* In this mode, we need to guarantee that the server won't crash * later when the ziplist is converted to a hashtable. @@ -2119,7 +2119,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { /* Convert to hash table if size threshold is exceeded */ if (sdslen(field) > server.hash_max_listpack_value || sdslen(value) > server.hash_max_listpack_value || !lpSafeToAdd(o->ptr, sdslen(field) + sdslen(value))) { - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); hashTypeEntry *entry = hashTypeCreateEntry(field, value); sdsfree(field); if (!hashtableAdd((hashtable *)o->ptr, entry)) { @@ -2319,7 +2319,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { o->encoding = OBJ_ENCODING_LISTPACK; if (hashTypeLength(o) > server.hash_max_listpack_entries || maxlen > server.hash_max_listpack_value) { - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } } break; @@ -2447,7 +2447,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { } if (hashTypeLength(o) > server.hash_max_listpack_entries) - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); else o->ptr = lpShrinkToFit(o->ptr); break; @@ -2468,7 +2468,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { goto emptykey; } - if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeEnsureHashtableEncoded(o); + if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); break; default: /* totally unreachable */ diff --git a/src/server.h b/src/server.h index a22889cf5c..37543c193c 100644 --- a/src/server.h +++ b/src/server.h @@ -3257,7 +3257,7 @@ hashTypeEntry *hashTypeEntryDefrag(hashTypeEntry *entry, void *(*defragfn)(void void dismissHashTypeEntry(hashTypeEntry *entry); void freeHashTypeEntry(hashTypeEntry *entry); -void hashTypeEnsureHashtableEncoded(robj *o); +void hashTypeConvert(robj *o, int enc); void hashTypeTryConversion(robj *subject, robj **argv, int start, int end); int hashTypeExists(robj *o, sds key); int hashTypeDelete(robj *o, sds key); diff --git a/src/t_hash.c b/src/t_hash.c index 4384646333..547aea9e16 100644 --- a/src/t_hash.c +++ b/src/t_hash.c @@ -267,7 +267,7 @@ void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { * might over allocate memory if there are duplicates. */ size_t new_fields = (end - start + 1) / 2; if (new_fields > server.hash_max_listpack_entries) { - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); hashtableExpand(o->ptr, new_fields); return; } @@ -276,12 +276,12 @@ void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { if (!sdsEncodedObject(argv[i])) continue; size_t len = sdslen(argv[i]->ptr); if (len > server.hash_max_listpack_value) { - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); return; } sum += len; } - if (!lpSafeToAdd(o->ptr, sum)) hashTypeEnsureHashtableEncoded(o); + if (!lpSafeToAdd(o->ptr, sum)) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } /* Get the value from a listpack encoded hash, identified by field. @@ -415,7 +415,7 @@ int hashTypeSet(robj *o, sds field, sds value, int flags) { * hashTypeTryConversion, so this check will be a NOP. */ if (o->encoding == OBJ_ENCODING_LISTPACK) { if (sdslen(field) > server.hash_max_listpack_value || sdslen(value) > server.hash_max_listpack_value) - hashTypeEnsureHashtableEncoded(o); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } if (o->encoding == OBJ_ENCODING_LISTPACK) { @@ -444,7 +444,7 @@ int hashTypeSet(robj *o, sds field, sds value, int flags) { o->ptr = zl; /* Check if the listpack needs to be converted to a hash table */ - if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeEnsureHashtableEncoded(o); + if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { hashtable *ht = o->ptr; @@ -654,33 +654,51 @@ robj *hashTypeLookupWriteOrCreate(client *c, robj *key) { return o; } -void hashTypeEnsureHashtableEncoded(robj *o) { - serverAssert(o->type == OBJ_HASH); - if (o->encoding == OBJ_ENCODING_HASHTABLE) return; - hashtable *ht = hashtableCreate(&hashHashtableType); +void hashTypeConvertListpack(robj *o, int enc) { + serverAssert(o->encoding == OBJ_ENCODING_LISTPACK); - /* Presize the hashtable to avoid rehashing */ - hashtableExpand(ht, hashTypeLength(o)); + if (enc == OBJ_ENCODING_LISTPACK) { + /* Nothing to do... */ - hashTypeIterator hi; - hashTypeInitIterator(o, &hi); - while (hashTypeNext(&hi) != C_ERR) { - sds field = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_FIELD); - sds value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); - hashTypeEntry *entry = hashTypeCreateEntry(field, value); - sdsfree(field); - if (!hashtableAdd(ht, entry)) { - freeHashTypeEntry(entry); - hashTypeResetIterator(&hi); /* Needed for gcc ASAN */ - serverLogHexDump(LL_WARNING, "listpack with dup elements dump", o->ptr, lpBytes(o->ptr)); - serverPanic("Listpack corruption detected"); + } else if (enc == OBJ_ENCODING_HASHTABLE) { + hashTypeIterator hi; + + hashtable *ht = hashtableCreate(&hashHashtableType); + + /* Presize the hashtable to avoid rehashing */ + hashtableExpand(ht, hashTypeLength(o)); + + hashTypeInitIterator(o, &hi); + while (hashTypeNext(&hi) != C_ERR) { + sds field = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_FIELD); + sds value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); + hashTypeEntry *entry = hashTypeCreateEntry(field, value); + sdsfree(field); + if (!hashtableAdd(ht, entry)) { + freeHashTypeEntry(entry); + hashTypeResetIterator(&hi); /* Needed for gcc ASAN */ + serverLogHexDump(LL_WARNING, "listpack with dup elements dump", o->ptr, lpBytes(o->ptr)); + serverPanic("Listpack corruption detected"); + } } + hashTypeResetIterator(&hi); + zfree(o->ptr); + o->encoding = OBJ_ENCODING_HASHTABLE; + o->ptr = ht; + } else { + serverPanic("Unknown hash encoding"); + } +} + +void hashTypeConvert(robj *o, int enc) { + if (o->encoding == OBJ_ENCODING_LISTPACK) { + hashTypeConvertListpack(o, enc); + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + serverPanic("Not implemented"); + } else { + serverPanic("Unknown hash encoding"); } - hashTypeResetIterator(&hi); - zfree(o->ptr); - o->encoding = OBJ_ENCODING_HASHTABLE; - o->ptr = ht; } /* This is a helper function for the COPY command. diff --git a/src/t_zset.c b/src/t_zset.c index 77bfb06f91..2444f3ecd0 100644 --- a/src/t_zset.c +++ b/src/t_zset.c @@ -68,7 +68,7 @@ int zslLexValueGteMin(sds value, zlexrangespec *spec); int zslLexValueLteMax(sds value, zlexrangespec *spec); -static void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap); +void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap); static zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_level, unsigned long rank); zskiplistNode *zslGetElementByRank(zskiplist *zsl, unsigned long rank); @@ -1269,7 +1269,7 @@ void zsetConvert(robj *zobj, int encoding) { } /* Converts a zset to the specified encoding, pre-sizing it for 'cap' elements. */ -static void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { +void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { zset *zs; zskiplistNode *node, *next; sds ele;