diff --git a/.changeset/added_a_test_for_migrations_to_ensure_consistency_between_database_schemas.md b/.changeset/added_a_test_for_migrations_to_ensure_consistency_between_database_schemas.md new file mode 100644 index 0000000..87d2e57 --- /dev/null +++ b/.changeset/added_a_test_for_migrations_to_ensure_consistency_between_database_schemas.md @@ -0,0 +1,5 @@ +--- +default: patch +--- + +# Added a test for migrations to ensure consistency between database schemas diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index 2949588..95f39a5 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -34,31 +34,28 @@ func (s *Store) initNewDatabase(target int64) error { } func (s *Store) upgradeDatabase(current, target int64) error { - log := s.log.Named("migrations") - log.Info("migrating database", zap.Int64("current", current), zap.Int64("target", target)) - - return s.transaction(func(tx *txn) error { - // defer foreign key constraints until commit - if _, err := tx.Exec("PRAGMA defer_foreign_keys=ON"); err != nil { - return fmt.Errorf("failed to enable foreign key deferral: %w", err) - } - - for _, fn := range migrations[current-1:] { - current++ - start := time.Now() - if err := fn(tx, log.With(zap.Int64("version", current))); err != nil { - return fmt.Errorf("failed to migrate database to version %v: %w", current, err) - } - // check that no foreign key constraints were violated - if err := tx.QueryRow("PRAGMA foreign_key_check").Scan(); !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("foreign key constraints are not satisfied") + log := s.log.Named("migrations").With(zap.Int64("target", target)) + for ; current < target; current++ { + version := current + 1 // initial schema is version 1, migration 0 is version 2, etc. + log := log.With(zap.Int64("version", version)) + start := time.Now() + fn := migrations[current-1] + err := s.transaction(func(tx *txn) error { + if _, err := tx.Exec("PRAGMA defer_foreign_keys=ON"); err != nil { + return fmt.Errorf("failed to enable foreign key deferral: %w", err) + } else if err := fn(tx, log); err != nil { + return err + } else if err := foreignKeyCheck(tx, log); err != nil { + return fmt.Errorf("failed foreign key check: %w", err) } - log.Debug("migration complete", zap.Int64("current", current), zap.Int64("target", target), zap.Duration("elapsed", time.Since(start))) + return setDBVersion(tx, version) + }) + if err != nil { + return fmt.Errorf("migration %d failed: %w", version, err) } - - // set the final database version - return setDBVersion(tx, target) - }) + log.Info("migration complete", zap.Duration("elapsed", time.Since(start))) + } + return nil } func (s *Store) init() error { @@ -77,3 +74,30 @@ func (s *Store) init() error { // nothing to do return nil } + +func foreignKeyCheck(txn *txn, log *zap.Logger) error { + rows, err := txn.Query("PRAGMA foreign_key_check") + if err != nil { + return fmt.Errorf("failed to run foreign key check: %w", err) + } + defer rows.Close() + var hasErrors bool + for rows.Next() { + var table string + var rowid sql.NullInt64 + var fkTable string + var fkRowid sql.NullInt64 + + if err := rows.Scan(&table, &rowid, &fkTable, &fkRowid); err != nil { + return fmt.Errorf("failed to scan foreign key check result: %w", err) + } + hasErrors = true + log.Error("foreign key constraint violated", zap.String("table", table), zap.Int64("rowid", rowid.Int64), zap.String("fkTable", fkTable), zap.Int64("fkRowid", fkRowid.Int64)) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to iterate foreign key check results: %w", err) + } else if hasErrors { + return errors.New("foreign key constraint violated") + } + return nil +} diff --git a/persist/sqlite/migrations_test.go b/persist/sqlite/migrations_test.go new file mode 100644 index 0000000..91ef642 --- /dev/null +++ b/persist/sqlite/migrations_test.go @@ -0,0 +1,323 @@ +package sqlite + +import ( + "database/sql" + "fmt" + "path/filepath" + "testing" + + "go.sia.tech/core/types" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +// nolint:misspell +const initialSchema = `CREATE TABLE chain_indices ( + id INTEGER PRIMARY KEY, + block_id BLOB UNIQUE NOT NULL, + height INTEGER UNIQUE NOT NULL +); +CREATE INDEX chain_indices_height ON chain_indices (block_id, height); + +CREATE TABLE sia_addresses ( + id INTEGER PRIMARY KEY, + sia_address BLOB UNIQUE NOT NULL, + siacoin_balance BLOB NOT NULL, + immature_siacoin_balance BLOB NOT NULL, + siafund_balance INTEGER NOT NULL +); + +CREATE TABLE siacoin_elements ( + id BLOB PRIMARY KEY, + siacoin_value BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + maturity_height INTEGER NOT NULL, /* stored as int64 for easier querying */ + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + matured BOOLEAN NOT NULL, /* tracks whether the value has been added to the address balance */ + chain_index_id INTEGER NOT NULL REFERENCES chain_indices (id), + spent_index_id INTEGER REFERENCES chain_indices (id) /* soft delete */ +); +CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id); +CREATE INDEX siacoin_elements_maturity_height_matured ON siacoin_elements (maturity_height, matured); +CREATE INDEX siacoin_elements_chain_index_id ON siacoin_elements (chain_index_id); +CREATE INDEX siacoin_elements_spent_index_id ON siacoin_elements (spent_index_id); +CREATE INDEX siacoin_elements_address_id_spent_index_id ON siacoin_elements(address_id, spent_index_id); + +CREATE TABLE siafund_elements ( + id BLOB PRIMARY KEY, + claim_start BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + siafund_value INTEGER NOT NULL, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + chain_index_id INTEGER NOT NULL REFERENCES chain_indices (id), + spent_index_id INTEGER REFERENCES chain_indices (id) /* soft delete */ +); +CREATE INDEX siafund_elements_address_id ON siafund_elements (address_id); +CREATE INDEX siafund_elements_chain_index_id ON siafund_elements (chain_index_id); +CREATE INDEX siafund_elements_spent_index_id ON siafund_elements (spent_index_id); +CREATE INDEX siafund_elements_address_id_spent_index_id ON siafund_elements(address_id, spent_index_id); + +CREATE TABLE state_tree ( + row INTEGER, + column INTEGER, + value BLOB NOT NULL, + PRIMARY KEY (row, column) +); + +CREATE TABLE events ( + id INTEGER PRIMARY KEY, + chain_index_id INTEGER NOT NULL REFERENCES chain_indices (id), + event_id BLOB UNIQUE NOT NULL, + maturity_height INTEGER NOT NULL, + date_created INTEGER NOT NULL, + event_type TEXT NOT NULL, + event_data BLOB NOT NULL +); +CREATE INDEX events_chain_index_id ON events (chain_index_id); + +CREATE TABLE event_addresses ( + event_id INTEGER NOT NULL REFERENCES events (id) ON DELETE CASCADE, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + PRIMARY KEY (event_id, address_id) +); +CREATE INDEX event_addresses_event_id_idx ON event_addresses (event_id); +CREATE INDEX event_addresses_address_id_idx ON event_addresses (address_id); + +CREATE TABLE wallets ( + id INTEGER PRIMARY KEY, + friendly_name TEXT NOT NULL, + description TEXT NOT NULL, + date_created INTEGER NOT NULL, + last_updated INTEGER NOT NULL, + extra_data BLOB +); + +CREATE TABLE wallet_addresses ( + wallet_id INTEGER NOT NULL REFERENCES wallets (id), + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + description TEXT NOT NULL, + spend_policy BLOB, + extra_data BLOB, + UNIQUE (wallet_id, address_id) +); +CREATE INDEX wallet_addresses_wallet_id ON wallet_addresses (wallet_id); +CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); + +CREATE TABLE syncer_peers ( + peer_address TEXT PRIMARY KEY NOT NULL, + first_seen INTEGER NOT NULL +); + +CREATE TABLE syncer_bans ( + net_cidr TEXT PRIMARY KEY NOT NULL, + expiration INTEGER NOT NULL, + reason TEXT NOT NULL +); +CREATE INDEX syncer_bans_expiration_index ON syncer_bans (expiration); + +CREATE TABLE global_settings ( + id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row + db_version INTEGER NOT NULL, -- used for migrations + index_mode INTEGER, -- the mode of the data store + last_indexed_tip BLOB NOT NULL, -- the last chain index that was processed + element_num_leaves INTEGER NOT NULL -- the number of leaves in the state tree +);` + +func TestMigrationConsistency(t *testing.T) { + fp := filepath.Join(t.TempDir(), "hostd.sqlite3") + db, err := sql.Open("sqlite3", sqliteFilepath(fp)) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec(initialSchema); err != nil { + t.Fatal(err) + } + + // initialize the settings table + _, err = db.Exec(`INSERT INTO global_settings (id, db_version, index_mode, element_num_leaves, last_indexed_tip) VALUES (0, 1, 0, 0, ?)`, encode(types.ChainIndex{})) + if err != nil { + t.Fatal(err) + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } + + expectedVersion := int64(len(migrations) + 1) + log := zaptest.NewLogger(t) + store, err := OpenDatabase(fp, log) + if err != nil { + t.Fatal(err) + } + defer store.Close() + v := getDBVersion(store.db) + if v != expectedVersion { + t.Fatalf("expected version %d, got %d", expectedVersion, v) + } else if err := store.Close(); err != nil { + t.Fatal(err) + } + + // ensure the database does not change version when opened again + store, err = OpenDatabase(fp, log) + if err != nil { + t.Fatal(err) + } + defer store.Close() + v = getDBVersion(store.db) + if v != expectedVersion { + t.Fatalf("expected version %d, got %d", expectedVersion, v) + } + + fp2 := filepath.Join(t.TempDir(), "hostd.sqlite3") + baseline, err := OpenDatabase(fp2, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer baseline.Close() + + getTableIndices := func(db *sql.DB) (map[string]bool, error) { + const query = `SELECT name, tbl_name, sql FROM sqlite_schema WHERE type='index'` + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + indices := make(map[string]bool) + for rows.Next() { + var name, table string + var sqlStr sql.NullString // auto indices have no sql + if err := rows.Scan(&name, &table, &sqlStr); err != nil { + return nil, err + } + indices[fmt.Sprintf("%s.%s.%s", name, table, sqlStr.String)] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return indices, nil + } + + // ensure the migrated database has the same indices as the baseline + baselineIndices, err := getTableIndices(baseline.db) + if err != nil { + t.Fatal(err) + } + + migratedIndices, err := getTableIndices(store.db) + if err != nil { + t.Fatal(err) + } + + for k := range baselineIndices { + if !migratedIndices[k] { + t.Errorf("missing index %s", k) + } + } + + for k := range migratedIndices { + if !baselineIndices[k] { + t.Errorf("unexpected index %s", k) + } + } + + getTables := func(db *sql.DB) (map[string]bool, error) { + const query = `SELECT name FROM sqlite_schema WHERE type='table'` + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + tables := make(map[string]bool) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + tables[name] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return tables, nil + } + + // ensure the migrated database has the same tables as the baseline + baselineTables, err := getTables(baseline.db) + if err != nil { + t.Fatal(err) + } + + migratedTables, err := getTables(store.db) + if err != nil { + t.Fatal(err) + } + + for k := range baselineTables { + if !migratedTables[k] { + t.Errorf("missing table %s", k) + } + } + for k := range migratedTables { + if !baselineTables[k] { + t.Errorf("unexpected table %s", k) + } + } + + // ensure each table has the same columns as the baseline + getTableColumns := func(db *sql.DB, table string) (map[string]bool, error) { + query := fmt.Sprintf(`PRAGMA table_info(%s)`, table) // cannot use parameterized query for PRAGMA statements + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + columns := make(map[string]bool) + for rows.Next() { + var cid int + var name, colType string + var defaultValue sql.NullString + var notNull bool + var primaryKey int // composite keys are indices + if err := rows.Scan(&cid, &name, &colType, ¬Null, &defaultValue, &primaryKey); err != nil { + return nil, err + } + // column ID is ignored since it may not match between the baseline and migrated databases + key := fmt.Sprintf("%s.%s.%s.%t.%d", name, colType, defaultValue.String, notNull, primaryKey) + columns[key] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return columns, nil + } + + for k := range baselineTables { + baselineColumns, err := getTableColumns(baseline.db, k) + if err != nil { + t.Fatal(err) + } + migratedColumns, err := getTableColumns(store.db, k) + if err != nil { + t.Fatal(err) + } + + for c := range baselineColumns { + if !migratedColumns[c] { + t.Errorf("missing column %s.%s", k, c) + } + } + + for c := range migratedColumns { + if !baselineColumns[c] { + t.Errorf("unexpected column %s.%s", k, c) + } + } + } +} diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 1237413..a9dedf6 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/mattn/go-sqlite3" "go.sia.tech/walletd/wallet" "go.uber.org/zap" "lukechampine.com/frand" @@ -105,19 +106,47 @@ func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { return nil } +func integrityCheck(db *sql.DB, log *zap.Logger) error { + rows, err := db.Query("PRAGMA integrity_check") + if err != nil { + return fmt.Errorf("failed to run integrity check: %w", err) + } + defer rows.Close() + var hasErrors bool + for rows.Next() { + var result string + if err := rows.Scan(&result); err != nil { + return fmt.Errorf("failed to scan integrity check result: %w", err) + } else if result != "ok" { + log.Error("integrity check failed", zap.String("result", result)) + hasErrors = true + } + } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to iterate integrity check results: %w", err) + } else if hasErrors { + return errors.New("integrity check failed") + } + return nil +} + // OpenDatabase creates a new SQLite store and initializes the database. If the // database does not exist, it is created. func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { db, err := sql.Open("sqlite3", sqliteFilepath(fp)) if err != nil { return nil, err + } else if err := integrityCheck(db, log.Named("integrity")); err != nil { + return nil, fmt.Errorf("integrity check failed: %w", err) } store := &Store{ db: db, log: log, } if err := store.init(); err != nil { - return nil, fmt.Errorf("failed to initialize database: %w", err) + return nil, err } + sqliteVersion, _, _ := sqlite3.Version() + log.Debug("database initialized", zap.String("sqliteVersion", sqliteVersion), zap.Int("schemaVersion", len(migrations)+1), zap.String("path", fp)) return store, nil }