optimize checksumming entire database
btoews committed Jan 8, 2025
1 parent cd75b15 commit 67dadff
Showing 3 changed files with 280 additions and 81 deletions.
179 changes: 179 additions & 0 deletions checksum.go
@@ -0,0 +1,179 @@
package ltx

import (
    "encoding/binary"
    "encoding/json"
    "fmt"
    "hash"
    "hash/crc64"
    "io"
    "os"
    "strconv"
    "sync"
)

// Checksum represents an LTX checksum.
type Checksum uint64

// ChecksumPages updates the provided checksums slice with the checksum of each
// page in the specified file. The first error encountered (in page-number
// order) is returned, along with the number of the last page checksummed
// successfully before that error. Checksums for pages after the error may
// still be updated, even when an error is returned.
//
// nWorkers specifies the amount of parallelism to use. A reasonable default
// will be used if nWorkers is 0.
func ChecksumPages(dbPath string, pageSize, nPages, nWorkers uint32, checksums []Checksum) (uint32, error) {
    // Based on experimentation on a fly.io machine with a slow SSD, 512MB is
    // where checksumming starts to take >1s. As the database size increases,
    // we get more benefit from an increasing number of workers. In benchmarks
    // on fly.io machines of different sizes with a 1GB database, 24 workers
    // was the sweet spot.
    if nWorkers == 0 && pageSize*nPages > 512*1024*1024 {
        nWorkers = 24
    }

    if nWorkers <= 1 {
        return checksumPagesSerial(dbPath, 1, nPages, int64(pageSize), checksums)
    }

    // Assign each worker a contiguous range of pages, rounding up so that
    // every page is covered.
    perWorker := nPages / nWorkers
    if nPages%nWorkers != 0 {
        perWorker++
    }

    var (
        wg   sync.WaitGroup
        rets = make([]uint32, nWorkers)
        errs = make([]error, nWorkers)
    )

    for w := uint32(0); w < nWorkers; w++ {
        w := w
        firstPage := w*perWorker + 1
        lastPage := firstPage + perWorker - 1
        if lastPage > nPages {
            lastPage = nPages
        }

        wg.Add(1)
        go func() {
            rets[w], errs[w] = checksumPagesSerial(dbPath, firstPage, lastPage, int64(pageSize), checksums)
            wg.Done()
        }()
    }

    wg.Wait()

    // Return the first (lowest page number) error encountered.
    for i, err := range errs {
        if err != nil {
            return rets[i], err
        }
    }

    return nPages, nil
}
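
A usage sketch follows (not part of this commit, written as if inside the ltx package; the database path and sizing values are hypothetical). A caller allocates one checksum slot per page and may pass nWorkers == 0 to let the heuristic above choose the worker count. For the partitioning math: with nPages = 10 and nWorkers = 4, perWorker rounds up to 3, so the workers cover pages 1-3, 4-6, 7-9, and 10.

    // Hypothetical caller; "my.db", pageSize, and nPages are assumptions.
    pageSize, nPages := uint32(4096), uint32(262144) // ~1GB database
    checksums := make([]Checksum, nPages)            // one slot per page
    lastPage, err := ChecksumPages("my.db", pageSize, nPages, 0, checksums)
    if err != nil {
        fmt.Printf("checksummed through page %d before failing: %v\n", lastPage, err)
    }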

func checksumPagesSerial(dbPath string, firstPage, lastPage uint32, pageSize int64, checksums []Checksum) (uint32, error) {
    f, err := os.Open(dbPath)
    if err != nil {
        return firstPage - 1, err
    }
    defer f.Close()

    if _, err := f.Seek(int64(firstPage-1)*pageSize, io.SeekStart); err != nil {
        return firstPage - 1, err
    }

    // The page number is hashed along with the page data, so reserve 4 bytes
    // for it at the front of the buffer.
    buf := make([]byte, pageSize+4)
    h := NewHasher()

    for pageNo := firstPage; pageNo <= lastPage; pageNo++ {
        binary.BigEndian.PutUint32(buf, pageNo)

        if _, err := io.ReadFull(f, buf[4:]); err != nil {
            return pageNo - 1, err
        }

        h.Reset()
        _, _ = h.Write(buf)
        checksums[pageNo-1] = ChecksumFlag | Checksum(h.Sum64())
    }

    return lastPage, nil
}

// ChecksumPage returns a CRC64 checksum that combines the page number & page data.
func ChecksumPage(pgno uint32, data []byte) Checksum {
    return ChecksumPageWithHasher(NewHasher(), pgno, data)
}

// ChecksumPageWithHasher returns a CRC64 checksum that combines the page number & page data.
func ChecksumPageWithHasher(h hash.Hash64, pgno uint32, data []byte) Checksum {
    h.Reset()
    _ = binary.Write(h, binary.BigEndian, pgno)
    _, _ = h.Write(data)
    return ChecksumFlag | Checksum(h.Sum64())
}

// ChecksumReader reads an entire database file from r and computes its rolling checksum.
func ChecksumReader(r io.Reader, pageSize int) (Checksum, error) {
    data := make([]byte, pageSize)

    var chksum Checksum
    for pgno := uint32(1); ; pgno++ {
        if _, err := io.ReadFull(r, data); err == io.EOF {
            break
        } else if err != nil {
            return chksum, err
        }
        chksum = ChecksumFlag | (chksum ^ ChecksumPage(pgno, data))
    }
    return chksum, nil
}
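
The rolling checksum is an XOR fold of the per-page checksums under ChecksumFlag, so the slice filled by ChecksumPages can be reduced to the same value ChecksumReader returns for the full file. A minimal sketch (not in this commit), assuming every page checksummed successfully:

    // Fold per-page checksums into the database-wide rolling checksum,
    // mirroring the loop in ChecksumReader.
    var dbChecksum Checksum
    for _, c := range checksums {
        dbChecksum = ChecksumFlag | (dbChecksum ^ c)
    }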

// ParseChecksum parses a 16-character hex string into a checksum.
func ParseChecksum(s string) (Checksum, error) {
    if len(s) != 16 {
        return 0, fmt.Errorf("invalid formatted checksum length: %q", s)
    }
    v, err := strconv.ParseUint(s, 16, 64)
    if err != nil {
        return 0, fmt.Errorf("invalid checksum format: %q", s)
    }
    return Checksum(v), nil
}

// String returns c formatted as a fixed-width hex number.
func (c Checksum) String() string {
    return fmt.Sprintf("%016x", uint64(c))
}

func (c Checksum) MarshalJSON() ([]byte, error) {
    return []byte(`"` + c.String() + `"`), nil
}

func (c *Checksum) UnmarshalJSON(data []byte) (err error) {
    var s *string
    if err := json.Unmarshal(data, &s); err != nil {
        return fmt.Errorf("cannot unmarshal checksum from JSON value")
    }

    // Set to zero if value is nil.
    if s == nil {
        *c = 0
        return nil
    }

    chksum, err := ParseChecksum(*s)
    if err != nil {
        return fmt.Errorf("cannot parse checksum from JSON string: %q", *s)
    }
    *c = chksum

    return nil
}

// NewHasher returns a new CRC64-ISO hasher.
func NewHasher() hash.Hash64 {
    return crc64.New(crc64.MakeTable(crc64.ISO))
}
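
Finally, a small sketch of the parse/format/JSON round-trip (the hex value here is hypothetical):

    c, err := ParseChecksum("83a1b2c3d4e5f607") // hypothetical 16-char hex value
    if err != nil {
        panic(err)
    }
    fmt.Println(c)           // 83a1b2c3d4e5f607
    out, _ := json.Marshal(c)
    fmt.Println(string(out)) // "83a1b2c3d4e5f607"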
101 changes: 101 additions & 0 deletions checksum_test.go
@@ -0,0 +1,101 @@
package ltx

import (
    "crypto/rand"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "testing"
)

func TestChecksumPages(t *testing.T) {
    // files divisible into pages
    testChecksumPages(t, 1024*4, 4, 1024, 1)
    testChecksumPages(t, 1024*4, 4, 1024, 2)
    testChecksumPages(t, 1024*4, 4, 1024, 3)
    testChecksumPages(t, 1024*4, 4, 1024, 4)

    // short pages
    testChecksumPages(t, 1024*3+100, 4, 1024, 1)
    testChecksumPages(t, 1024*3+100, 4, 1024, 2)
    testChecksumPages(t, 1024*3+100, 4, 1024, 3)
    testChecksumPages(t, 1024*3+100, 4, 1024, 4)

    // empty files
    testChecksumPages(t, 0, 4, 1024, 1)
    testChecksumPages(t, 0, 4, 1024, 2)
    testChecksumPages(t, 0, 4, 1024, 3)
    testChecksumPages(t, 0, 4, 1024, 4)
}

func testChecksumPages(t *testing.T, fileSize, nPages, pageSize, nWorkers uint32) {
    t.Run(fmt.Sprintf("fileSize=%d,nPages=%d,pageSize=%d,nWorkers=%d", fileSize, nPages, pageSize, nWorkers), func(t *testing.T) {
        path := filepath.Join(t.TempDir(), "test.db")
        f, err := os.Create(path)
        if err != nil {
            t.Fatal(err)
        }
        defer f.Close()
        if _, err := io.CopyN(f, rand.Reader, int64(fileSize)); err != nil {
            t.Fatal(err)
        }

        legacyCS := make([]Checksum, nPages)
        legacyLastPage, legacyErr := legacyChecksumPages(path, pageSize, nPages, legacyCS)
        newCS := make([]Checksum, nPages)
        newLastPage, newErr := ChecksumPages(path, pageSize, nPages, nWorkers, newCS)

        if legacyErr != newErr {
            t.Fatalf("legacy error: %v, new error: %v", legacyErr, newErr)
        }
        if legacyLastPage != newLastPage {
            t.Fatalf("legacy last page: %d, new last page: %d", legacyLastPage, newLastPage)
        }
        if len(legacyCS) != len(newCS) {
            t.Fatalf("legacy checksums: %d, new checksums: %d", len(legacyCS), len(newCS))
        }
        for i := range legacyCS {
            if legacyCS[i] != newCS[i] {
                t.Fatalf("mismatch at index %d: legacy: %v, new: %v", i, legacyCS[i], newCS[i])
            }
        }
    })
}

// logic copied from litefs repo
func legacyChecksumPages(dbPath string, pageSize, nPages uint32, checksums []Checksum) (uint32, error) {
    f, err := os.Open(dbPath)
    if err != nil {
        return 0, err
    }
    defer f.Close()

    buf := make([]byte, pageSize)

    for pgno := uint32(1); pgno <= nPages; pgno++ {
        offset := int64(pgno-1) * int64(pageSize)
        if _, err := readFullAt(f, buf, offset); err != nil {
            return pgno - 1, err
        }

        checksums[pgno-1] = ChecksumPage(pgno, buf)
    }

    return nPages, nil
}

// copied from litefs/internal
func readFullAt(r io.ReaderAt, buf []byte, off int64) (n int, err error) {
    for n < len(buf) && err == nil {
        var nn int
        nn, err = r.ReadAt(buf[n:], off+int64(n))
        n += nn
    }
    if n >= len(buf) {
        return n, nil
    } else if n > 0 && err == io.EOF {
        return n, io.ErrUnexpectedEOF
    }
    return n, err
}
81 changes: 0 additions & 81 deletions ltx.go
@@ -7,8 +7,6 @@ import (
    "encoding/json"
    "errors"
    "fmt"
    "hash"
    "hash/crc64"
    "io"
    "regexp"
    "strconv"
@@ -168,51 +166,6 @@ func (t *TXID) UnmarshalJSON(data []byte) (err error) {
    return nil
}

// Checksum represents an LTX checksum.
type Checksum uint64

// ParseChecksum parses a 16-character hex string into a checksum.
func ParseChecksum(s string) (Checksum, error) {
    if len(s) != 16 {
        return 0, fmt.Errorf("invalid formatted checksum length: %q", s)
    }
    v, err := strconv.ParseUint(s, 16, 64)
    if err != nil {
        return 0, fmt.Errorf("invalid checksum format: %q", s)
    }
    return Checksum(v), nil
}

// String returns c formatted as a fixed-width hex number.
func (c Checksum) String() string {
    return fmt.Sprintf("%016x", uint64(c))
}

func (c Checksum) MarshalJSON() ([]byte, error) {
    return []byte(`"` + c.String() + `"`), nil
}

func (c *Checksum) UnmarshalJSON(data []byte) (err error) {
    var s *string
    if err := json.Unmarshal(data, &s); err != nil {
        return fmt.Errorf("cannot unmarshal checksum from JSON value")
    }

    // Set to zero if value is nil.
    if s == nil {
        *c = 0
        return nil
    }

    chksum, err := ParseChecksum(*s)
    if err != nil {
        return fmt.Errorf("cannot parse checksum from JSON string: %q", *s)
    }
    *c = Checksum(chksum)

    return nil
}

// Header flags.
const (
    HeaderFlagMask = uint32(0x00000001)
@@ -474,40 +427,6 @@ func (h *PageHeader) UnmarshalBinary(b []byte) error {
    return nil
}

// NewHasher returns a new CRC64-ISO hasher.
func NewHasher() hash.Hash64 {
    return crc64.New(crc64.MakeTable(crc64.ISO))
}

// ChecksumPage returns a CRC64 checksum that combines the page number & page data.
func ChecksumPage(pgno uint32, data []byte) Checksum {
    return ChecksumPageWithHasher(NewHasher(), pgno, data)
}

// ChecksumPageWithHasher returns a CRC64 checksum that combines the page number & page data.
func ChecksumPageWithHasher(h hash.Hash64, pgno uint32, data []byte) Checksum {
    h.Reset()
    _ = binary.Write(h, binary.BigEndian, pgno)
    _, _ = h.Write(data)
    return ChecksumFlag | Checksum(h.Sum64())
}

// ChecksumReader reads an entire database file from r and computes its rolling checksum.
func ChecksumReader(r io.Reader, pageSize int) (Checksum, error) {
    data := make([]byte, pageSize)

    var chksum Checksum
    for pgno := uint32(1); ; pgno++ {
        if _, err := io.ReadFull(r, data); err == io.EOF {
            break
        } else if err != nil {
            return chksum, err
        }
        chksum = ChecksumFlag | (chksum ^ ChecksumPage(pgno, data))
    }
    return chksum, nil
}

// ParseFilename parses a transaction range from an LTX file.
func ParseFilename(name string) (minTXID, maxTXID TXID, err error) {
    a := filenameRegex.FindStringSubmatch(name)