Skip to content

Commit

Permalink
Problem: memiavl background snapshot rewriting panic when shutdown
Browse files Browse the repository at this point in the history
Solution:
- gracefully cancel the task when shutdown
  • Loading branch information
yihuang committed Jan 12, 2024
1 parent ea65bfa commit d4a8c9a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## UNRELEASED

- [#]() memiavl cancel background snapshot rewriting when graceful shutdown.

*January 5, 2024*

## v1.1.0-rc2
Expand Down
25 changes: 21 additions & 4 deletions memiavl/db.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memiavl

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -50,6 +51,9 @@ type DB struct {

// result channel of snapshot rewrite goroutine
snapshotRewriteChan chan snapshotResult
// context cancel function to cancel the snapshot rewrite goroutine
snapshotRewriteCancel context.CancelFunc

// the number of old snapshots to keep (excluding the latest one)
snapshotKeepRecent uint32
// block interval to take a new snapshot
Expand Down Expand Up @@ -414,6 +418,7 @@ func (db *DB) checkBackgroundSnapshotRewrite() error {
select {
case result := <-db.snapshotRewriteChan:
db.snapshotRewriteChan = nil
db.snapshotRewriteCancel = nil

if result.mtree == nil {
// background snapshot rewrite failed
Expand Down Expand Up @@ -628,7 +633,7 @@ func (db *DB) copy(cacheSize int) *DB {
}

// RewriteSnapshot writes the current version of memiavl into a snapshot, and update the `current` symlink.
func (db *DB) RewriteSnapshot() error {
func (db *DB) RewriteSnapshot(ctx context.Context) error {
db.mtx.Lock()
defer db.mtx.Unlock()

Expand All @@ -639,7 +644,7 @@ func (db *DB) RewriteSnapshot() error {
snapshotDir := snapshotName(db.lastCommitInfo.Version)
tmpDir := snapshotDir + TmpSuffix
path := filepath.Join(db.dir, tmpDir)
if err := db.MultiTree.WriteSnapshot(path, db.snapshotWriterPool); err != nil {
if err := db.MultiTree.WriteSnapshot(ctx, path, db.snapshotWriterPool); err != nil {
return errors.Join(err, os.RemoveAll(path))
}
if err := os.Rename(path, filepath.Join(db.dir, snapshotDir)); err != nil {
Expand Down Expand Up @@ -707,16 +712,19 @@ func (db *DB) rewriteSnapshotBackground() error {
return errors.New("there's another ongoing snapshot rewriting process")
}

ctx, cancel := context.WithCancel(context.Background())

ch := make(chan snapshotResult)
db.snapshotRewriteChan = ch
db.snapshotRewriteCancel = cancel

cloned := db.copy(0)
wal := db.wal
go func() {
defer close(ch)

cloned.logger.Info("start rewriting snapshot", "version", cloned.Version())
if err := cloned.RewriteSnapshot(); err != nil {
if err := cloned.RewriteSnapshot(ctx); err != nil {
ch <- snapshotResult{err: err}
return
}
Expand Down Expand Up @@ -746,7 +754,9 @@ func (db *DB) Close() error {
defer db.mtx.Unlock()

errs := []error{
db.waitAsyncCommit(), db.MultiTree.Close(), db.wal.Close(),
db.waitAsyncCommit(),
db.MultiTree.Close(),
db.wal.Close(),
}
db.wal = nil

Expand All @@ -755,6 +765,13 @@ func (db *DB) Close() error {
db.fileLock = nil
}

if db.snapshotRewriteChan != nil {
db.snapshotRewriteCancel()
<-db.snapshotRewriteChan
db.snapshotRewriteChan = nil
db.snapshotRewriteCancel = nil
}

return errors.Join(errs...)
}

Expand Down
3 changes: 2 additions & 1 deletion memiavl/import.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memiavl

import (
"context"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -133,7 +134,7 @@ func doImport(dir string, version int64, nodes <-chan *ExportNode) (returnErr er
return errors.New("version overflows uint32")
}

return writeSnapshot(dir, uint32(version), func(w *snapshotWriter) (uint32, error) {
return writeSnapshot(context.Background(), dir, uint32(version), func(w *snapshotWriter) (uint32, error) {

Check failure

Code scanning / gosec

Potential integer overflow by integer type conversion Error

Potential integer overflow by integer type conversion
i := &importer{
snapshotWriter: *w,
}
Expand Down
4 changes: 2 additions & 2 deletions memiavl/multitree.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func (t *MultiTree) CatchupWAL(wal *wal.Log, endVersion int64) error {
return nil
}

func (t *MultiTree) WriteSnapshot(dir string, wp *pond.WorkerPool) error {
func (t *MultiTree) WriteSnapshot(ctx context.Context, dir string, wp *pond.WorkerPool) error {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}
Expand All @@ -368,7 +368,7 @@ func (t *MultiTree) WriteSnapshot(dir string, wp *pond.WorkerPool) error {
for _, entry := range t.trees {
tree, name := entry.Tree, entry.Name
group.Submit(func() error {
return tree.WriteSnapshot(filepath.Join(dir, name))
return tree.WriteSnapshot(ctx, filepath.Join(dir, name))
})
}

Expand Down
25 changes: 21 additions & 4 deletions memiavl/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package memiavl

import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -24,6 +25,9 @@ const (
FileNameLeaves = "leaves"
FileNameKVs = "kvs"
FileNameMetadata = "metadata"

// check for cancel every 1000 leaves
CancelCheckInterval = 1000
)

// Snapshot manage the lifecycle of mmap-ed files for the snapshot,
Expand Down Expand Up @@ -348,8 +352,8 @@ func (snapshot *Snapshot) export(callback func(*ExportNode) bool) {
}

// WriteSnapshot save the IAVL tree to a new snapshot directory.
func (t *Tree) WriteSnapshot(snapshotDir string) error {
return writeSnapshot(snapshotDir, t.version, func(w *snapshotWriter) (uint32, error) {
func (t *Tree) WriteSnapshot(ctx context.Context, snapshotDir string) error {
return writeSnapshot(ctx, snapshotDir, t.version, func(w *snapshotWriter) (uint32, error) {
if t.root == nil {
return 0, nil
} else {
Expand All @@ -362,6 +366,7 @@ func (t *Tree) WriteSnapshot(snapshotDir string) error {
}

func writeSnapshot(
ctx context.Context,
dir string, version uint32,
doWrite func(*snapshotWriter) (uint32, error),
) (returnErr error) {
Expand Down Expand Up @@ -407,7 +412,7 @@ func writeSnapshot(
leavesWriter := bufio.NewWriter(fpLeaves)
kvsWriter := bufio.NewWriter(fpKVs)

w := newSnapshotWriter(nodesWriter, leavesWriter, kvsWriter)
w := newSnapshotWriter(ctx, nodesWriter, leavesWriter, kvsWriter)
leaves, err := doWrite(w)
if err != nil {
return err
Expand Down Expand Up @@ -460,6 +465,9 @@ func writeSnapshot(
}

type snapshotWriter struct {
// context for cancel the writing process
ctx context.Context

nodesWriter, leavesWriter, kvWriter io.Writer

// count how many nodes have been written
Expand All @@ -469,8 +477,9 @@ type snapshotWriter struct {
kvsOffset uint64
}

func newSnapshotWriter(nodesWriter, leavesWriter, kvsWriter io.Writer) *snapshotWriter {
func newSnapshotWriter(ctx context.Context, nodesWriter, leavesWriter, kvsWriter io.Writer) *snapshotWriter {
return &snapshotWriter{
ctx: ctx,
nodesWriter: nodesWriter,
leavesWriter: leavesWriter,
kvWriter: kvsWriter,
Expand Down Expand Up @@ -502,6 +511,14 @@ func (w *snapshotWriter) writeKeyValue(key, value []byte) error {
}

func (w *snapshotWriter) writeLeaf(version uint32, key, value, hash []byte) error {
if w.leafCounter%CancelCheckInterval == 0 {
select {
case <-w.ctx.Done():
return w.ctx.Err()
default:
}
}

var buf [SizeLeafWithoutHash]byte
binary.LittleEndian.PutUint32(buf[OffsetLeafVersion:], version)
binary.LittleEndian.PutUint32(buf[OffsetLeafKeyLen:], uint32(len(key)))
Expand Down

0 comments on commit d4a8c9a

Please sign in to comment.