diff --git a/hare3/hare_test.go b/hare3/hare_test.go index dd8139ab12..0df5b7f726 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -148,8 +148,8 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { return n } -func (n *node) withDb() *node { - n.db = sql.InMemory() +func (n *node) withDb(tb testing.TB) *node { + n.db = sql.InMemoryTest(tb) n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -342,7 +342,7 @@ func (cl *lockstepCluster) addActive(n int) *lockstepCluster { for i := last; i < last+n; i++ { cl.addNode((&node{t: cl.t, i: i}). withController().withSyncer().withPublisher(). - withClock().withDb().withSigner().withAtx(cl.units.min, cl.units.max). + withClock().withDb(cl.t).withSigner().withAtx(cl.units.min, cl.units.max). withOracle().withHare()) } return cl @@ -353,7 +353,7 @@ func (cl *lockstepCluster) addInactive(n int) *lockstepCluster { for i := last; i < last+n; i++ { cl.addNode((&node{t: cl.t, i: i}). withController().withSyncer().withPublisher(). - withClock().withDb().withSigner(). + withClock().withDb(cl.t).withSigner(). withOracle().withHare()) } return cl @@ -366,7 +366,7 @@ func (cl *lockstepCluster) addEquivocators(n int) *lockstepCluster { cl.addNode((&node{t: cl.t, i: i}). reuseSigner(cl.nodes[i-last].signer). withController().withSyncer().withPublisher(). - withClock().withDb().withAtx(cl.units.min, cl.units.max). + withClock().withDb(cl.t).withAtx(cl.units.min, cl.units.max). withOracle().withHare()) } return cl diff --git a/hare3/malfeasance_test.go b/hare3/malfeasance_test.go index 0f2a0f1491..7ff1cf52b1 100644 --- a/hare3/malfeasance_test.go +++ b/hare3/malfeasance_test.go @@ -26,7 +26,7 @@ type testMalfeasanceHandler struct { } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := sql.InMemoryTest(tb) observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/hare4/eligibility/oracle_test.go b/hare4/eligibility/oracle_test.go index d6ae2e4a58..5dfdf877b6 100644 --- a/hare4/eligibility/oracle_test.go +++ b/hare4/eligibility/oracle_test.go @@ -53,7 +53,7 @@ type testOracle struct { } func defaultOracle(tb testing.TB) *testOracle { - db := sql.InMemory() + db := sql.InMemoryTest(tb) atxsdata := atxsdata.New() ctrl := gomock.NewController(tb) diff --git a/hare4/hare_test.go b/hare4/hare_test.go index 70f363a8f6..498a80d23f 100644 --- a/hare4/hare_test.go +++ b/hare4/hare_test.go @@ -159,8 +159,8 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { return n } -func (n *node) withDb() *node { - n.db = sql.InMemory() +func (n *node) withDb(tb testing.TB) *node { + n.db = sql.InMemoryTest(tb) n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -391,7 +391,7 @@ func (cl *lockstepCluster) addActive(n int) *lockstepCluster { for i := last; i < last+n; i++ { nn := (&node{t: cl.t, i: i}). withController().withSyncer().withPublisher(). - withClock().withDb().withSigner().withAtx(cl.units.min, cl.units.max). + withClock().withDb(cl.t).withSigner().withAtx(cl.units.min, cl.units.max). withStreamRequester().withOracle().withHare() if cl.mockVerify { nn = nn.withVerifier() @@ -406,7 +406,7 @@ func (cl *lockstepCluster) addInactive(n int) *lockstepCluster { for i := last; i < last+n; i++ { cl.addNode((&node{t: cl.t, i: i}). withController().withSyncer().withPublisher(). - withClock().withDb().withSigner(). + withClock().withDb(cl.t).withSigner(). withStreamRequester().withOracle().withHare()) } return cl @@ -419,7 +419,7 @@ func (cl *lockstepCluster) addEquivocators(n int) *lockstepCluster { cl.addNode((&node{t: cl.t, i: i}). reuseSigner(cl.nodes[i-last].signer). withController().withSyncer().withPublisher(). - withClock().withDb().withAtx(cl.units.min, cl.units.max). + withClock().withDb(cl.t).withAtx(cl.units.min, cl.units.max). withStreamRequester().withOracle().withHare()) } return cl diff --git a/sql/database.go b/sql/database.go index 5f86f0bda5..e20a393b63 100644 --- a/sql/database.go +++ b/sql/database.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "sync/atomic" + "testing" "time" sqlite "github.com/go-llsqlite/crawshaw" @@ -176,6 +177,7 @@ func WithQueryCacheSizes(sizes map[QueryCacheKind]int) Opt { type Opt func(c *conf) // InMemory database for testing. +// Please use InMemoryTest for automatic closing of the returned db during `tb.Cleanup`. func InMemory(opts ...Opt) *Database { opts = append(opts, WithConnections(1)) db, err := Open("file::memory:?mode=memory", opts...) @@ -185,6 +187,17 @@ func InMemory(opts ...Opt) *Database { return db } +// InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. +func InMemoryTest(tb testing.TB, opts ...Opt) *Database { + opts = append(opts, WithConnections(1)) + db, err := Open("file::memory:?mode=memory", opts...) + if err != nil { + panic(err) + } + tb.Cleanup(func() { db.Close() }) + return db +} + // Open database with options. // // Database is opened in WAL mode and pragma synchronous=normal.