diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 273edb0a07..647cd7cdb5 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -240,7 +240,7 @@ func (h *HandlerV1) syntacticallyValidateDeps( if err := identities.SetMalicious(h.cdb, atx.SmesherID, encodedProof, time.Now()); err != nil { return 0, 0, nil, fmt.Errorf("adding malfeasance proof: %w", err) } - h.cdb.CacheMalfeasanceProof(atx.SmesherID, proof) + h.cdb.CacheMalfeasanceProof(atx.SmesherID, encodedProof) h.tortoise.OnMalfeasance(atx.SmesherID) return 0, 0, proof, nil } @@ -495,7 +495,7 @@ func (h *HandlerV1) storeAtx( atxs.AtxAdded(h.cdb, atx) if proof != nil { - h.cdb.CacheMalfeasanceProof(atx.SmesherID, proof) + h.cdb.CacheMalfeasanceProof(atx.SmesherID, codec.MustEncode(proof)) h.tortoise.OnMalfeasance(atx.SmesherID) } diff --git a/activation/malfeasance.go b/activation/malfeasance.go index 17cae36997..2874a96163 100644 --- a/activation/malfeasance.go +++ b/activation/malfeasance.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/prometheus/client_golang/prometheus" "github.com/spacemeshos/post/shared" @@ -44,6 +45,19 @@ func NewMalfeasanceHandler( } } +func (mh *MalfeasanceHandler) Info(data wire.ProofData) (map[string]string, error) { + ap, ok := data.(*wire.AtxProof) + if !ok { + return nil, errors.New("wrong message type for multiple ATXs") + } + return map[string]string{ + "atx1": ap.Messages[0].InnerMsg.MsgHash.String(), + "atx2": ap.Messages[1].InnerMsg.MsgHash.String(), + "publish_epoch": strconv.FormatUint(uint64(ap.Messages[0].InnerMsg.PublishEpoch), 10), + "smesher_id": ap.Messages[0].SmesherID.String(), + }, nil +} + func (mh *MalfeasanceHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { ap, ok := data.(*wire.AtxProof) if !ok { @@ -109,6 +123,18 @@ func NewInvalidPostIndexHandler( } } +func (mh *InvalidPostIndexHandler) Info(data wire.ProofData) (map[string]string, error) { + pp, ok := data.(*wire.InvalidPostIndexProof) + if !ok { + return nil, errors.New("wrong message type for invalid post index") + } + return map[string]string{ + "atx": pp.Atx.ID().String(), + "index": strconv.FormatUint(uint64(pp.InvalidIdx), 10), + "smesher_id": pp.Atx.SmesherID.String(), + }, nil +} + func (mh *InvalidPostIndexHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { proof, ok := data.(*wire.InvalidPostIndexProof) if !ok { @@ -174,6 +200,19 @@ func NewInvalidPrevATXHandler( } } +func (mh *InvalidPrevATXHandler) Info(data wire.ProofData) (map[string]string, error) { + pp, ok := data.(*wire.InvalidPrevATXProof) + if !ok { + return nil, errors.New("wrong message type for invalid previous ATX") + } + return map[string]string{ + "atx1": pp.Atx1.ID().String(), + "atx2": pp.Atx2.ID().String(), + "prev_atx": pp.Atx1.PrevATXID.String(), + "smesher_id": pp.Atx1.SmesherID.String(), + }, nil +} + func (mh *InvalidPrevATXHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { proof, ok := data.(*wire.InvalidPrevATXProof) if !ok { diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index 756501ac85..f35f48dd69 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/spacemeshos/go-spacemesh/api/grpcserver" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/sql" @@ -169,7 +170,7 @@ func TestGet_IdentityCanceled(t *testing.T) { } atx.SetID(id) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) - atxProvider.EXPECT().MalfeasanceProof(smesher).Return(proof, nil) + atxProvider.EXPECT().MalfeasanceProof(smesher).Return(codec.MustEncode(proof), nil) atxProvider.EXPECT().Previous(id).Return([]types.ATXID{previous}, nil) response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) @@ -184,5 +185,5 @@ func TestGet_IdentityCanceled(t *testing.T) { require.Equal(t, previous.Bytes(), response.Atx.PreviousAtxs[0].Id) require.Equal(t, atx.NumUnits, response.Atx.NumUnits) require.Equal(t, atx.Sequence, response.Atx.Sequence) - require.Equal(t, events.ToMalfeasancePB(smesher, proof, false), response.MalfeasanceProof) + require.Equal(t, events.ToMalfeasancePB(smesher, codec.MustEncode(proof), false), response.MalfeasanceProof) } diff --git a/api/grpcserver/config.go b/api/grpcserver/config.go index 564d9184f4..801f07a980 100644 --- a/api/grpcserver/config.go +++ b/api/grpcserver/config.go @@ -49,6 +49,8 @@ const ( TransactionV2Alpha1 Service = "transaction_v2alpha1" TransactionStreamV2Alpha1 Service = "transaction_stream_v2alpha1" AccountV2Alpha1 Service = "account_v2alpha1" + MalfeasanceV2Alpha1 Service = "malfeasance_v2alpha1" + MalfeasanceStreamV2Alpha1 Service = "malfeasance_stream_v2alpha1" ) // DefaultConfig defines the default configuration options for api. @@ -57,12 +59,13 @@ func DefaultConfig() Config { PublicServices: []Service{ GlobalState, Mesh, Transaction, Node, Activation, ActivationV2Alpha1, RewardV2Alpha1, NetworkV2Alpha1, NodeV2Alpha1, LayerV2Alpha1, TransactionV2Alpha1, - AccountV2Alpha1, + AccountV2Alpha1, MalfeasanceV2Alpha1, }, PublicListener: "0.0.0.0:9092", PrivateServices: []Service{ Admin, Smesher, Debug, ActivationStreamV2Alpha1, RewardStreamV2Alpha1, LayerStreamV2Alpha1, TransactionStreamV2Alpha1, + MalfeasanceStreamV2Alpha1, }, PrivateListener: "127.0.0.1:9093", PostServices: []Service{Post, PostInfo}, diff --git a/api/grpcserver/globalstate_service_test.go b/api/grpcserver/globalstate_service_test.go index 8df8fc6425..9469468ad3 100644 --- a/api/grpcserver/globalstate_service_test.go +++ b/api/grpcserver/globalstate_service_test.go @@ -79,7 +79,7 @@ func TestGlobalStateService(t *testing.T) { _, err := c.AccountDataQuery(ctx, &pb.AccountDataQueryRequest{}) require.Error(t, err) - require.Contains(t, err.Error(), "`Filter` must be provided") + require.ErrorContains(t, err, "`Filter` must be provided") }) t.Run("AccountDataQuery_MissingFlags", func(t *testing.T) { t.Parallel() @@ -91,7 +91,7 @@ func TestGlobalStateService(t *testing.T) { }, }) require.Error(t, err) - require.Contains(t, err.Error(), "`Filter.AccountMeshDataFlags` must set at least one") + require.ErrorContains(t, err, "`Filter.AccountMeshDataFlags` must set at least one") }) t.Run("AccountDataQuery_BadOffset", func(t *testing.T) { t.Parallel() diff --git a/api/grpcserver/interface.go b/api/grpcserver/interface.go index d8a4b5f5bb..7b513b9c7e 100644 --- a/api/grpcserver/interface.go +++ b/api/grpcserver/interface.go @@ -9,7 +9,6 @@ import ( "github.com/spacemeshos/go-spacemesh/activation" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/peerinfo" "github.com/spacemeshos/go-spacemesh/signing" @@ -58,7 +57,7 @@ type atxProvider interface { GetAtx(id types.ATXID) (*types.ActivationTx, error) Previous(id types.ATXID) ([]types.ATXID, error) MaxHeightAtx() (types.ATXID, error) - MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) + MalfeasanceProof(id types.NodeID) ([]byte, error) } type postState interface { diff --git a/api/grpcserver/mesh_service.go b/api/grpcserver/mesh_service.go index 7d08ff4037..a256cf1f19 100644 --- a/api/grpcserver/mesh_service.go +++ b/api/grpcserver/mesh_service.go @@ -19,7 +19,6 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" ) @@ -634,13 +633,13 @@ func (s *MeshService) MalfeasanceStream( } // first serve those already existed locally. - if err := s.cdb.IterateMalfeasanceProofs(func(id types.NodeID, mp *wire.MalfeasanceProof) error { + if err := s.cdb.IterateMalfeasanceProofs(func(id types.NodeID, proof []byte) error { select { case <-stream.Context().Done(): return nil default: res := &pb.MalfeasanceStreamResponse{ - Proof: events.ToMalfeasancePB(id, mp, req.IncludeProof), + Proof: events.ToMalfeasancePB(id, proof, req.IncludeProof), } return stream.Send(res) } diff --git a/api/grpcserver/mesh_service_test.go b/api/grpcserver/mesh_service_test.go index 57c1fc0241..022b003e7d 100644 --- a/api/grpcserver/mesh_service_test.go +++ b/api/grpcserver/mesh_service_test.go @@ -96,9 +96,7 @@ func BallotMalfeasance(tb testing.TB, db sql.Executor) (types.NodeID, *wire.Malf Data: &bp, }, } - data, err := codec.Encode(mp) - require.NoError(tb, err) - require.NoError(tb, identities.SetMalicious(db, sig.NodeID(), data, time.Now())) + require.NoError(tb, identities.SetMalicious(db, sig.NodeID(), codec.MustEncode(mp), time.Now())) return sig.NodeID(), mp } @@ -176,7 +174,7 @@ func TestMeshService_MalfeasanceQuery(t *testing.T) { require.Equal(t, nodeID, types.BytesToNodeID(resp.Proof.SmesherId.Id)) require.EqualValues(t, layer, resp.Proof.Layer.Number) require.Equal(t, pb.MalfeasanceProof_MALFEASANCE_BALLOT, resp.Proof.Kind) - require.Equal(t, events.ToMalfeasancePB(nodeID, proof, true), resp.Proof) + require.Equal(t, events.ToMalfeasancePB(nodeID, codec.MustEncode(proof), true), resp.Proof) require.NotEmpty(t, resp.Proof.Proof) var got wire.MalfeasanceProof require.NoError(t, codec.Decode(resp.Proof.Proof, &got)) @@ -247,15 +245,17 @@ func TestMeshService_MalfeasanceStream(t *testing.T) { require.Equal(t, 10, hare) id, proof := AtxMalfeasance(t, db) - events.ReportMalfeasance(id, proof) + proofBytes := codec.MustEncode(proof) + events.ReportMalfeasance(id, proofBytes) resp, err := stream.Recv() require.NoError(t, err) - require.Equal(t, events.ToMalfeasancePB(id, proof, false), resp.Proof) + require.Equal(t, events.ToMalfeasancePB(id, proofBytes, false), resp.Proof) id, proof = BallotMalfeasance(t, db) - events.ReportMalfeasance(id, proof) + proofBytes = codec.MustEncode(proof) + events.ReportMalfeasance(id, proofBytes) resp, err = stream.Recv() require.NoError(t, err) - require.Equal(t, events.ToMalfeasancePB(id, proof, false), resp.Proof) + require.Equal(t, events.ToMalfeasancePB(id, proofBytes, false), resp.Proof) } type MeshAPIMockInstrumented struct { diff --git a/api/grpcserver/mocks.go b/api/grpcserver/mocks.go index a9eab2e4d6..494f04f1a4 100644 --- a/api/grpcserver/mocks.go +++ b/api/grpcserver/mocks.go @@ -18,7 +18,6 @@ import ( multiaddr "github.com/multiformats/go-multiaddr" activation "github.com/spacemeshos/go-spacemesh/activation" types "github.com/spacemeshos/go-spacemesh/common/types" - wire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" p2p "github.com/spacemeshos/go-spacemesh/p2p" peerinfo "github.com/spacemeshos/go-spacemesh/p2p/peerinfo" signing "github.com/spacemeshos/go-spacemesh/signing" @@ -913,10 +912,10 @@ func (c *MockatxProviderGetAtxCall) DoAndReturn(f func(types.ATXID) (*types.Acti } // MalfeasanceProof mocks base method. -func (m *MockatxProvider) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { +func (m *MockatxProvider) MalfeasanceProof(id types.NodeID) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MalfeasanceProof", id) - ret0, _ := ret[0].(*wire.MalfeasanceProof) + ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -934,19 +933,19 @@ type MockatxProviderMalfeasanceProofCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockatxProviderMalfeasanceProofCall) Return(arg0 *wire.MalfeasanceProof, arg1 error) *MockatxProviderMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) Return(arg0 []byte, arg1 error) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockatxProviderMalfeasanceProofCall) Do(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) Do(f func(types.NodeID) ([]byte, error)) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockatxProviderMalfeasanceProofCall) DoAndReturn(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) DoAndReturn(f func(types.NodeID) ([]byte, error)) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/api/grpcserver/v2alpha1/account_test.go b/api/grpcserver/v2alpha1/account_test.go index 3b80a98969..29037ba138 100644 --- a/api/grpcserver/v2alpha1/account_test.go +++ b/api/grpcserver/v2alpha1/account_test.go @@ -27,7 +27,7 @@ type testAccount struct { } func TestAccountService_List(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctrl, ctx := gomock.WithContext(context.Background(), t) conState := NewMockaccountConState(ctrl) diff --git a/api/grpcserver/v2alpha1/activation.go b/api/grpcserver/v2alpha1/activation.go index d87e6ed399..9fe94e6bab 100644 --- a/api/grpcserver/v2alpha1/activation.go +++ b/api/grpcserver/v2alpha1/activation.go @@ -67,30 +67,14 @@ func (s *ActivationStreamService) Stream( } } - dbChan := make(chan *types.ActivationTx, 100) - errChan := make(chan error, 1) - ops, err := toAtxOperations(toAtxRequest(request)) if err != nil { return status.Error(codes.InvalidArgument, err.Error()) } - // send db data to chan to avoid buffer overflow - go func() { - defer close(dbChan) - if err := atxs.IterateAtxsOps(s.db, ops, func(atx *types.ActivationTx) bool { - select { - case dbChan <- atx: - return true - case <-ctx.Done(): - // exit if the stream context is canceled - return false - } - }); err != nil { - errChan <- status.Error(codes.Internal, err.Error()) - return - } - }() + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + dbChan, errChan := s.fetchFromDB(ctx, ops) var eventsOut <-chan events.ActivationTx var eventsFull <-chan struct{} @@ -145,6 +129,31 @@ func (s *ActivationStreamService) Stream( } } +func (s *ActivationStreamService) fetchFromDB( + ctx context.Context, + ops builder.Operations, +) (<-chan *types.ActivationTx, <-chan error) { + dbChan := make(chan *types.ActivationTx) + errChan := make(chan error, 1) // buffered to avoid blocking, routine should exit immediately after sending an error + + go func() { + defer close(dbChan) + if err := atxs.IterateAtxsOps(s.db, ops, func(atx *types.ActivationTx) bool { + select { + case dbChan <- atx: + return true + case <-ctx.Done(): + // exit if the context is canceled + return false + } + }); err != nil { + errChan <- status.Error(codes.Internal, err.Error()) + } + }() + + return dbChan, errChan +} + func toAtx(atx *types.ActivationTx) *spacemeshv2alpha1.Activation { return &spacemeshv2alpha1.Activation{ Id: atx.ID().Bytes(), diff --git a/api/grpcserver/v2alpha1/activation_test.go b/api/grpcserver/v2alpha1/activation_test.go index 9153abdba4..d26ad87603 100644 --- a/api/grpcserver/v2alpha1/activation_test.go +++ b/api/grpcserver/v2alpha1/activation_test.go @@ -16,31 +16,34 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActivationService_List(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() + setup := func(t *testing.T) (spacemeshv2alpha1.ActivationServiceClient, []types.ActivationTx) { + db := statesql.InMemoryTest(t) - gen := fixture.NewAtxsGenerator() - activations := make([]types.ActivationTx, 100) - for i := range activations { - atx := gen.Next() - require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) - activations[i] = *atx - } + gen := fixture.NewAtxsGenerator() + activations := make([]types.ActivationTx, 100) + for i := range activations { + atx := gen.Next() + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) + activations[i] = *atx + } - svc := NewActivationService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + svc := NewActivationService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewActivationServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewActivationServiceClient(conn), activations + } t.Run("limit set too high", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{Limit: 200}) + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{Limit: 200}) require.Error(t, err) s, ok := status.FromError(err) @@ -50,7 +53,8 @@ func TestActivationService_List(t *testing.T) { }) t.Run("no limit set", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{}) + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{}) require.Error(t, err) s, ok := status.FromError(err) @@ -60,7 +64,8 @@ func TestActivationService_List(t *testing.T) { }) t.Run("limit and offset", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{ + client, _ := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{ Limit: 25, Offset: 50, }) @@ -69,13 +74,15 @@ func TestActivationService_List(t *testing.T) { }) t.Run("all", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{Limit: 100}) + client, activations := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{Limit: 100}) require.NoError(t, err) require.Equal(t, len(activations), len(list.Activations)) }) t.Run("coinbase", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{ + client, activations := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{ Limit: 1, Coinbase: activations[3].Coinbase.String(), }) @@ -84,7 +91,8 @@ func TestActivationService_List(t *testing.T) { }) t.Run("smesherId", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{ + client, activations := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{ Limit: 1, SmesherId: [][]byte{activations[1].SmesherID.Bytes()}, }) @@ -93,7 +101,8 @@ func TestActivationService_List(t *testing.T) { }) t.Run("id", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.ActivationRequest{ + client, activations := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.ActivationRequest{ Limit: 1, Id: [][]byte{activations[3].ID().Bytes()}, }) @@ -103,29 +112,30 @@ func TestActivationService_List(t *testing.T) { } func TestActivationStreamService_Stream(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() - - gen := fixture.NewAtxsGenerator() - activations := make([]types.ActivationTx, 100) - for i := range activations { - atx := gen.Next() - require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) - activations[i] = *atx - } + setup := func(t *testing.T, db sql.Executor) spacemeshv2alpha1.ActivationStreamServiceClient { + gen := fixture.NewAtxsGenerator() + activations := make([]types.ActivationTx, 100) + for i := range activations { + atx := gen.Next() + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) + activations[i] = *atx + } - svc := NewActivationStreamService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + svc := NewActivationStreamService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewActivationStreamServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewActivationStreamServiceClient(conn) + } t.Run("all", func(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - stream, err := client.Stream(ctx, &spacemeshv2alpha1.ActivationStreamRequest{}) + client := setup(t, statesql.InMemoryTest(t)) + + stream, err := client.Stream(context.Background(), &spacemeshv2alpha1.ActivationStreamRequest{}) require.NoError(t, err) var i int @@ -136,19 +146,22 @@ func TestActivationStreamService_Stream(t *testing.T) { } i++ } - require.Len(t, activations, i) + require.Equal(t, 100, i) }) t.Run("watch", func(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) + db := statesql.InMemoryTest(t) + client := setup(t, db) + const ( start = 100 n = 10 ) - gen = fixture.NewAtxsGenerator().WithEpochs(start, 10) + gen := fixture.NewAtxsGenerator().WithEpochs(start, 10) var streamed []*events.ActivationTx for i := 0; i < n; i++ { atx := gen.Next() @@ -186,7 +199,7 @@ func TestActivationStreamService_Stream(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - stream, err := client.Stream(ctx, tc.request) + stream, err := client.Stream(context.Background(), tc.request) require.NoError(t, err) _, err = stream.Header() require.NoError(t, err) @@ -194,7 +207,7 @@ func TestActivationStreamService_Stream(t *testing.T) { var expect []*types.ActivationTx for _, rst := range streamed { require.NoError(t, events.ReportNewActivation(rst.ActivationTx)) - matcher := atxsMatcher{tc.request, ctx} + matcher := atxsMatcher{tc.request, context.Background()} if matcher.match(rst) { expect = append(expect, rst.ActivationTx) } @@ -211,7 +224,7 @@ func TestActivationStreamService_Stream(t *testing.T) { } func TestActivationService_ActivationsCount(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() genEpoch3 := fixture.NewAtxsGenerator().WithEpochs(3, 1) diff --git a/api/grpcserver/v2alpha1/interface.go b/api/grpcserver/v2alpha1/interface.go new file mode 100644 index 0000000000..d0a03b528c --- /dev/null +++ b/api/grpcserver/v2alpha1/interface.go @@ -0,0 +1,7 @@ +package v2alpha1 + +//go:generate mockgen -typed -package=v2alpha1 -destination=./mocks.go -source=./interface.go + +type malfeasanceInfo interface { + Info(data []byte) (map[string]string, error) +} diff --git a/api/grpcserver/v2alpha1/layer.go b/api/grpcserver/v2alpha1/layer.go index 901bb46a28..590cca84ff 100644 --- a/api/grpcserver/v2alpha1/layer.go +++ b/api/grpcserver/v2alpha1/layer.go @@ -60,29 +60,14 @@ func (s *LayerStreamService) Stream( } } - dbChan := make(chan *spacemeshv2alpha1.Layer, 100) - errChan := make(chan error, 1) - ops, err := toLayerOperations(toLayerRequest(request)) if err != nil { return status.Error(codes.InvalidArgument, err.Error()) } - // send db data to chan to avoid buffer overflow - go func() { - defer close(dbChan) - if err := layers.IterateLayersWithBlockOps(s.db, ops, func(layer *layers.Layer) bool { - select { - case dbChan <- toLayer(layer): - return true - case <-ctx.Done(): - // exit if the stream context is canceled - return false - } - }); err != nil { - errChan <- status.Error(codes.Internal, err.Error()) - return - } - }() + + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + dbChan, errChan := s.fetchFromDB(ctx, ops) var eventsOut <-chan events.LayerUpdate var eventsFull <-chan struct{} @@ -155,6 +140,31 @@ func (s *LayerStreamService) Stream( } } +func (s *LayerStreamService) fetchFromDB( + ctx context.Context, + ops builder.Operations, +) (<-chan *spacemeshv2alpha1.Layer, <-chan error) { + dbChan := make(chan *spacemeshv2alpha1.Layer, 100) + errChan := make(chan error, 1) // buffered to avoid blocking, routine should exit immediately after sending an error + + go func() { + defer close(dbChan) + if err := layers.IterateLayersWithBlockOps(s.db, ops, func(layer *layers.Layer) bool { + select { + case dbChan <- toLayer(layer): + return true + case <-ctx.Done(): + // exit if the stream context is canceled + return false + } + }); err != nil { + errChan <- status.Error(codes.Internal, err.Error()) + } + }() + + return dbChan, errChan +} + func toLayerRequest(filter *spacemeshv2alpha1.LayerStreamRequest) *spacemeshv2alpha1.LayerRequest { req := &spacemeshv2alpha1.LayerRequest{ StartLayer: filter.StartLayer, diff --git a/api/grpcserver/v2alpha1/layer_test.go b/api/grpcserver/v2alpha1/layer_test.go index 8c76cbeb55..9186c0efc1 100644 --- a/api/grpcserver/v2alpha1/layer_test.go +++ b/api/grpcserver/v2alpha1/layer_test.go @@ -23,29 +23,31 @@ import ( ) func TestLayerService_List(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() - - lrs := make([]layers.Layer, 100) - r1 := rand.New(rand.NewSource(time.Now().UnixNano())) - r2 := rand.New(rand.NewSource(time.Now().UnixNano() + 1)) - for i := range lrs { - processed := r1.Intn(2) == 0 - withBlock := r2.Intn(2) == 0 - l, err := generateLayer(db, types.LayerID(i), layerGenProcessed(processed), layerGenWithBlock(withBlock)) - require.NoError(t, err) - lrs[i] = *l - } + setup := func(t *testing.T) spacemeshv2alpha1.LayerServiceClient { + db := statesql.InMemoryTest(t) + + lrs := make([]layers.Layer, 90) + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + r2 := rand.New(rand.NewSource(time.Now().UnixNano() + 1)) + for i := range lrs { + processed := r1.Intn(2) == 0 + withBlock := r2.Intn(2) == 0 + l, err := generateLayer(db, types.LayerID(i), layerGenProcessed(processed), layerGenWithBlock(withBlock)) + require.NoError(t, err) + lrs[i] = *l + } - svc := NewLayerService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + svc := NewLayerService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewLayerServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewLayerServiceClient(conn) + } t.Run("limit set too high", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.LayerRequest{Limit: 200}) + client := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.LayerRequest{Limit: 200}) require.Error(t, err) s, ok := status.FromError(err) @@ -55,7 +57,8 @@ func TestLayerService_List(t *testing.T) { }) t.Run("no limit set", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.LayerRequest{}) + client := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.LayerRequest{}) require.Error(t, err) s, ok := status.FromError(err) @@ -65,7 +68,8 @@ func TestLayerService_List(t *testing.T) { }) t.Run("limit and offset", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.LayerRequest{ + client := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.LayerRequest{ Limit: 25, Offset: 50, }) @@ -74,13 +78,14 @@ func TestLayerService_List(t *testing.T) { }) t.Run("all", func(t *testing.T) { - ls, err := client.List(ctx, &spacemeshv2alpha1.LayerRequest{ + client := setup(t) + ls, err := client.List(context.Background(), &spacemeshv2alpha1.LayerRequest{ StartLayer: 0, EndLayer: 100, Limit: 100, }) require.NoError(t, err) - require.Len(t, lrs, len(ls.Layers)) + require.Len(t, ls.Layers, 90) }) } @@ -99,32 +104,33 @@ func TestLayerConvertEventStatus(t *testing.T) { } func TestLayerStreamService_Stream(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() - - lrs := make([]layers.Layer, 100) - r1 := rand.New(rand.NewSource(time.Now().UnixNano())) - r2 := rand.New(rand.NewSource(time.Now().UnixNano() + 1)) - for i := range lrs { - processed := r1.Intn(2) == 0 - withBlock := r2.Intn(2) == 0 - l, err := generateLayer(db, types.LayerID(i), layerGenProcessed(processed), layerGenWithBlock(withBlock)) - require.NoError(t, err) - lrs[i] = *l - } + setup := func(t *testing.T, db sql.StateDatabase) (spacemeshv2alpha1.LayerStreamServiceClient, []layers.Layer) { + lrs := make([]layers.Layer, 100) + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + r2 := rand.New(rand.NewSource(time.Now().UnixNano() + 1)) + for i := range lrs { + processed := r1.Intn(2) == 0 + withBlock := r2.Intn(2) == 0 + l, err := generateLayer(db, types.LayerID(i), layerGenProcessed(processed), layerGenWithBlock(withBlock)) + require.NoError(t, err) + lrs[i] = *l + } - svc := NewLayerStreamService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + svc := NewLayerStreamService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewLayerStreamServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewLayerStreamServiceClient(conn), lrs + } t.Run("all", func(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - stream, err := client.Stream(ctx, &spacemeshv2alpha1.LayerStreamRequest{}) + client, lrs := setup(t, statesql.InMemoryTest(t)) + + stream, err := client.Stream(context.Background(), &spacemeshv2alpha1.LayerStreamRequest{}) require.NoError(t, err) var i int @@ -133,7 +139,7 @@ func TestLayerStreamService_Stream(t *testing.T) { if errors.Is(err, io.EOF) { break } - assert.Equal(t, toLayer(&lrs[i]).String(), l.String()) + require.Equal(t, toLayer(&lrs[i]).String(), l.String()) i++ } require.Len(t, lrs, i) @@ -143,6 +149,9 @@ func TestLayerStreamService_Stream(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) + db := statesql.InMemoryTest(t) + client, _ := setup(t, db) + const ( start = 100 n = 10 @@ -169,7 +178,7 @@ func TestLayerStreamService_Stream(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - stream, err := client.Stream(ctx, tc.request) + stream, err := client.Stream(context.Background(), tc.request) require.NoError(t, err) _, err = stream.Header() require.NoError(t, err) @@ -191,7 +200,7 @@ func TestLayerStreamService_Stream(t *testing.T) { } require.NoError(t, events.ReportLayerUpdate(lu)) - matcher := layersMatcher{tc.request, ctx} + matcher := layersMatcher{tc.request, context.Background()} if matcher.match(&lu) { expect = append(expect, &rst) } diff --git a/api/grpcserver/v2alpha1/malfeasance.go b/api/grpcserver/v2alpha1/malfeasance.go new file mode 100644 index 0000000000..bd0b2d7417 --- /dev/null +++ b/api/grpcserver/v2alpha1/malfeasance.go @@ -0,0 +1,327 @@ +package v2alpha1 + +import ( + "bytes" + "context" + "errors" + "io" + "slices" + "strconv" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + spacemeshv2alpha1 "github.com/spacemeshos/api/release/go/spacemesh/v2alpha1" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/identities" +) + +const ( + Malfeasance = "malfeasance_v2alpha1" + MalfeasanceStream = "malfeasance_stream_v2alpha1" +) + +func NewMalfeasanceService(db sql.Executor, malfeasanceHandler malfeasanceInfo) *MalfeasanceService { + return &MalfeasanceService{ + db: db, + info: malfeasanceHandler, + } +} + +type MalfeasanceService struct { + db sql.Executor + info malfeasanceInfo +} + +func (s *MalfeasanceService) RegisterService(server *grpc.Server) { + spacemeshv2alpha1.RegisterMalfeasanceServiceServer(server, s) +} + +func (s *MalfeasanceService) RegisterHandlerService(mux *runtime.ServeMux) error { + return spacemeshv2alpha1.RegisterMalfeasanceServiceHandlerServer(context.Background(), mux, s) +} + +func (s *MalfeasanceService) String() string { + return "MalfeasanceService" +} + +func (s *MalfeasanceService) List( + ctx context.Context, + request *spacemeshv2alpha1.MalfeasanceRequest, +) (*spacemeshv2alpha1.MalfeasanceList, error) { + switch { + case request.Limit > 100: + return nil, status.Error(codes.InvalidArgument, "limit is capped at 100") + case request.Limit == 0: + return nil, status.Error(codes.InvalidArgument, "limit must be set to <= 100") + } + + ops, err := toMalfeasanceOps(request) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + proofs := make([]*spacemeshv2alpha1.MalfeasanceProof, 0, request.Limit) + if err := identities.IterateMaliciousOps(s.db, ops, func(id types.NodeID, proof []byte, received time.Time) bool { + rst := toProof(ctx, s.info, id, proof) + if rst == nil { + return true + } + proofs = append(proofs, rst) + return true + }); err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + return &spacemeshv2alpha1.MalfeasanceList{Proofs: proofs}, nil +} + +func NewMalfeasanceStreamService(db sql.Executor, malfeasanceHandler malfeasanceInfo) *MalfeasanceStreamService { + return &MalfeasanceStreamService{ + db: db, + info: malfeasanceHandler, + } +} + +type MalfeasanceStreamService struct { + db sql.Executor + info malfeasanceInfo +} + +func (s *MalfeasanceStreamService) RegisterService(server *grpc.Server) { + spacemeshv2alpha1.RegisterMalfeasanceStreamServiceServer(server, s) +} + +func (s *MalfeasanceStreamService) RegisterHandlerService(mux *runtime.ServeMux) error { + return spacemeshv2alpha1.RegisterMalfeasanceStreamServiceHandlerServer(context.Background(), mux, s) +} + +func (s *MalfeasanceStreamService) String() string { + return "MalfeasanceStreamService" +} + +func (s *MalfeasanceStreamService) Stream( + request *spacemeshv2alpha1.MalfeasanceStreamRequest, + stream spacemeshv2alpha1.MalfeasanceStreamService_StreamServer, +) error { + var sub *events.BufferedSubscription[events.EventMalfeasance] + if request.Watch { + matcher := malfeasanceMatcher{request} + var err error + sub, err = events.SubscribeMatched(matcher.match) + if err != nil { + return status.Error(codes.Internal, err.Error()) + } + defer sub.Close() + if err := stream.SendHeader(metadata.MD{}); err != nil { + return status.Errorf(codes.Unavailable, "can't send header") + } + } + + ops, err := toMalfeasanceOps(&spacemeshv2alpha1.MalfeasanceRequest{ + SmesherId: request.SmesherId, + }) + if err != nil { + return status.Error(codes.InvalidArgument, err.Error()) + } + + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + dbChan, errChan := s.fetchFromDB(ctx, ops) + + var eventsOut <-chan events.EventMalfeasance + var eventsFull <-chan struct{} + if sub != nil { + eventsOut = sub.Out() + eventsFull = sub.Full() + } + + for { + select { + // process events first + case rst := <-eventsOut: + proof := toProof(stream.Context(), s.info, rst.Smesher, rst.Proof) + if proof == nil { + continue + } + err = stream.Send(proof) + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return status.Error(codes.Internal, err.Error()) + } + default: + select { + case rst := <-eventsOut: + proof := toProof(stream.Context(), s.info, rst.Smesher, rst.Proof) + if proof == nil { + continue + } + err = stream.Send(proof) + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return status.Error(codes.Internal, err.Error()) + } + case <-eventsFull: + return status.Error(codes.Canceled, "buffer overflow") + case rst, ok := <-dbChan: + if !ok { + dbChan = nil + if sub == nil { + return nil + } + continue + } + err = stream.Send(rst) + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return status.Error(codes.Internal, err.Error()) + } + case err := <-errChan: + return err + case <-stream.Context().Done(): + return nil + } + } + } +} + +func (s *MalfeasanceStreamService) fetchFromDB( + ctx context.Context, + ops builder.Operations, +) (<-chan *spacemeshv2alpha1.MalfeasanceProof, <-chan error) { + dbChan := make(chan *spacemeshv2alpha1.MalfeasanceProof) + errChan := make(chan error, 1) // buffered to avoid blocking, routine should exit immediately after sending an error + + go func() { + defer close(dbChan) + if err := identities.IterateMaliciousOps(s.db, ops, + func(id types.NodeID, proof []byte, received time.Time) bool { + rst := toProof(ctx, s.info, id, proof) + if rst == nil { + return true + } + + select { + case dbChan <- rst: + return true + case <-ctx.Done(): + // exit if the context is canceled + return false + } + }, + ); err != nil { + errChan <- status.Error(codes.Internal, err.Error()) + } + }() + return dbChan, errChan +} + +func toProof( + ctx context.Context, + info malfeasanceInfo, + id types.NodeID, + proof []byte, +) *spacemeshv2alpha1.MalfeasanceProof { + properties, err := info.Info(proof) + if err != nil { + ctxzap.Debug(ctx, "failed to get malfeasance info", + zap.String("smesher", id.String()), + zap.Error(err), + ) + return nil + } + domain, err := strconv.ParseUint(properties["domain"], 10, 64) + if err != nil { + ctxzap.Debug(ctx, "failed to parse proof domain", + zap.String("smesher", id.String()), + zap.String("domain", properties["domain"]), + zap.Error(err), + ) + return nil + } + delete(properties, "domain") + proofType, err := strconv.ParseUint(properties["type"], 10, 32) + if err != nil { + ctxzap.Debug(ctx, "failed to parse proof type", + zap.String("smesher", id.String()), + zap.String("type", properties["type"]), + zap.Error(err), + ) + return nil + } + delete(properties, "type") + return &spacemeshv2alpha1.MalfeasanceProof{ + Smesher: id.Bytes(), + Domain: spacemeshv2alpha1.MalfeasanceProof_MalfeasanceDomain(domain), + Type: uint32(proofType), + Properties: properties, + } +} + +func toMalfeasanceOps(filter *spacemeshv2alpha1.MalfeasanceRequest) (builder.Operations, error) { + ops := builder.Operations{} + ops.Filter = append(ops.Filter, builder.Op{ + Field: builder.Proof, + Token: builder.IsNotNull, + }) + ops.Modifiers = append(ops.Modifiers, builder.Modifier{ + Key: builder.OrderBy, + Value: builder.Smesher, + }) + + if filter == nil { + return ops, nil + } + + if len(filter.SmesherId) > 0 { + ops.Filter = append(ops.Filter, builder.Op{ + Field: builder.Smesher, + Token: builder.In, + Value: filter.SmesherId, + }) + } + + if filter.Limit != 0 { + ops.Modifiers = append(ops.Modifiers, builder.Modifier{ + Key: builder.Limit, + Value: int64(filter.Limit), + }) + } + if filter.Offset != 0 { + ops.Modifiers = append(ops.Modifiers, builder.Modifier{ + Key: builder.Offset, + Value: int64(filter.Offset), + }) + } + + return ops, nil +} + +type malfeasanceMatcher struct { + *spacemeshv2alpha1.MalfeasanceStreamRequest +} + +func (m *malfeasanceMatcher) match(event *events.EventMalfeasance) bool { + if len(m.SmesherId) > 0 { + idx := slices.IndexFunc(m.SmesherId, func(id []byte) bool { return bytes.Equal(id, event.Smesher.Bytes()) }) + if idx == -1 { + return false + } + } + return true +} diff --git a/api/grpcserver/v2alpha1/malfeasance_test.go b/api/grpcserver/v2alpha1/malfeasance_test.go new file mode 100644 index 0000000000..22744ad648 --- /dev/null +++ b/api/grpcserver/v2alpha1/malfeasance_test.go @@ -0,0 +1,212 @@ +package v2alpha1 + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "testing" + "time" + + spacemeshv2alpha1 "github.com/spacemeshos/api/release/go/spacemesh/v2alpha1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +type malInfo struct { + ID types.NodeID + Proof []byte + + Properties map[string]string +} + +func TestMalfeasanceService_List(t *testing.T) { + setup := func(t *testing.T) (spacemeshv2alpha1.MalfeasanceServiceClient, []malInfo) { + db := statesql.InMemoryTest(t) + ctrl := gomock.NewController(t) + info := NewMockmalfeasanceInfo(ctrl) + + proofs := make([]malInfo, 90) + for i := range proofs { + proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} + proofs[i].Properties = map[string]string{ + "domain": "0", + "type": strconv.FormatUint(uint64(i%4+1), 10), + fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), + } + info.EXPECT().Info(proofs[i].Proof).Return(proofs[i].Properties, nil).AnyTimes() + + require.NoError(t, identities.SetMalicious(db, proofs[i].ID, proofs[i].Proof, time.Now())) + } + + svc := NewMalfeasanceService(db, info) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) + + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewMalfeasanceServiceClient(conn), proofs + } + + t.Run("limit set too high", func(t *testing.T) { + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.MalfeasanceRequest{Limit: 200}) + require.Error(t, err) + + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, s.Code()) + require.Equal(t, "limit is capped at 100", s.Message()) + }) + + t.Run("no limit set", func(t *testing.T) { + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.MalfeasanceRequest{}) + require.Error(t, err) + + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, s.Code()) + require.Equal(t, "limit must be set to <= 100", s.Message()) + }) + + t.Run("limit and offset", func(t *testing.T) { + client, _ := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.MalfeasanceRequest{ + Limit: 25, + Offset: 50, + }) + require.NoError(t, err) + require.Len(t, list.Proofs, 25) + }) + + t.Run("all", func(t *testing.T) { + client, _ := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.MalfeasanceRequest{Limit: 100}) + require.NoError(t, err) + require.Len(t, list.Proofs, 90) + }) + + t.Run("smesherId", func(t *testing.T) { + client, proofs := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.MalfeasanceRequest{ + Limit: 1, + SmesherId: [][]byte{proofs[1].ID.Bytes()}, + }) + require.NoError(t, err) + require.Equal(t, proofs[1].ID.Bytes(), list.GetProofs()[0].GetSmesher()) + }) +} + +func TestMalfeasanceStreamService_Stream(t *testing.T) { + setup := func( + t *testing.T, + db sql.Executor, + info *MockmalfeasanceInfo, + ) spacemeshv2alpha1.MalfeasanceStreamServiceClient { + proofs := make([]malInfo, 90) + for i := range proofs { + proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} + proofs[i].Properties = map[string]string{ + "domain": "0", + "type": strconv.FormatUint(uint64(i%4+1), 10), + fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), + } + info.EXPECT().Info(proofs[i].Proof).Return(proofs[i].Properties, nil).AnyTimes() + + require.NoError(t, identities.SetMalicious(db, proofs[i].ID, proofs[i].Proof, time.Now())) + } + + svc := NewMalfeasanceStreamService(db, info) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) + + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewMalfeasanceStreamServiceClient(conn) + } + + t.Run("all", func(t *testing.T) { + events.InitializeReporter() + t.Cleanup(events.CloseEventReporter) + + ctrl := gomock.NewController(t) + info := NewMockmalfeasanceInfo(ctrl) + client := setup(t, statesql.InMemoryTest(t), info) + + stream, err := client.Stream(context.Background(), &spacemeshv2alpha1.MalfeasanceStreamRequest{}) + require.NoError(t, err) + + var i int + for { + _, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + i++ + } + require.Equal(t, 90, i) + }) + + t.Run("watch", func(t *testing.T) { + events.InitializeReporter() + t.Cleanup(events.CloseEventReporter) + + db := statesql.InMemoryTest(t) + ctrl := gomock.NewController(t) + info := NewMockmalfeasanceInfo(ctrl) + client := setup(t, db, info) + + const ( + start = 100 + n = 10 + ) + + var streamed []*events.EventMalfeasance + for i := 0; i < n; i++ { + smesher := types.RandomNodeID() + streamed = append(streamed, &events.EventMalfeasance{ + Smesher: smesher, + Proof: types.RandomBytes(100), + }) + properties := map[string]string{ + "domain": "0", + "type": strconv.FormatUint(uint64(i%4+1), 10), + fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), + } + info.EXPECT().Info(streamed[i].Proof).Return(properties, nil).AnyTimes() + } + + request := &spacemeshv2alpha1.MalfeasanceStreamRequest{ + SmesherId: [][]byte{streamed[3].Smesher.Bytes()}, + Watch: true, + } + stream, err := client.Stream(context.Background(), request) + require.NoError(t, err) + _, err = stream.Header() + require.NoError(t, err) + + var expect []types.NodeID + for _, rst := range streamed { + events.ReportMalfeasance(rst.Smesher, rst.Proof) + matcher := malfeasanceMatcher{request} + if matcher.match(rst) { + expect = append(expect, rst.Smesher) + } + } + + for _, rst := range expect { + received, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, rst.Bytes(), received.Smesher) + } + }) +} diff --git a/api/grpcserver/v2alpha1/mocks.go b/api/grpcserver/v2alpha1/mocks.go new file mode 100644 index 0000000000..d57a76b28e --- /dev/null +++ b/api/grpcserver/v2alpha1/mocks.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -typed -package=v2alpha1 -destination=./mocks.go -source=./interface.go +// + +// Package v2alpha1 is a generated GoMock package. +package v2alpha1 + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockmalfeasanceInfo is a mock of malfeasanceInfo interface. +type MockmalfeasanceInfo struct { + ctrl *gomock.Controller + recorder *MockmalfeasanceInfoMockRecorder +} + +// MockmalfeasanceInfoMockRecorder is the mock recorder for MockmalfeasanceInfo. +type MockmalfeasanceInfoMockRecorder struct { + mock *MockmalfeasanceInfo +} + +// NewMockmalfeasanceInfo creates a new mock instance. +func NewMockmalfeasanceInfo(ctrl *gomock.Controller) *MockmalfeasanceInfo { + mock := &MockmalfeasanceInfo{ctrl: ctrl} + mock.recorder = &MockmalfeasanceInfoMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmalfeasanceInfo) EXPECT() *MockmalfeasanceInfoMockRecorder { + return m.recorder +} + +// Info mocks base method. +func (m *MockmalfeasanceInfo) Info(data []byte) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Info", data) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Info indicates an expected call of Info. +func (mr *MockmalfeasanceInfoMockRecorder) Info(data any) *MockmalfeasanceInfoInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockmalfeasanceInfo)(nil).Info), data) + return &MockmalfeasanceInfoInfoCall{Call: call} +} + +// MockmalfeasanceInfoInfoCall wrap *gomock.Call +type MockmalfeasanceInfoInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockmalfeasanceInfoInfoCall) Return(arg0 map[string]string, arg1 error) *MockmalfeasanceInfoInfoCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockmalfeasanceInfoInfoCall) Do(f func([]byte) (map[string]string, error)) *MockmalfeasanceInfoInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockmalfeasanceInfoInfoCall) DoAndReturn(f func([]byte) (map[string]string, error)) *MockmalfeasanceInfoInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/api/grpcserver/v2alpha1/reward.go b/api/grpcserver/v2alpha1/reward.go index ea54c66cb7..56e90d6753 100644 --- a/api/grpcserver/v2alpha1/reward.go +++ b/api/grpcserver/v2alpha1/reward.go @@ -61,30 +61,14 @@ func (s *RewardStreamService) Stream( } } - dbChan := make(chan *types.Reward, 100) - errChan := make(chan error, 1) - ops, err := toRewardOperations(toRewardRequest(request)) if err != nil { return status.Error(codes.InvalidArgument, err.Error()) } - // send db data to chan to avoid buffer overflow - go func() { - defer close(dbChan) - if err := rewards.IterateRewardsOps(s.db, ops, func(rwd *types.Reward) bool { - select { - case dbChan <- rwd: - return true - case <-ctx.Done(): - // exit if the stream context is canceled - return false - } - }); err != nil { - errChan <- status.Error(codes.Internal, err.Error()) - return - } - }() + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + dbChan, errChan := s.fetchFromDB(ctx, ops) var eventsOut <-chan types.Reward var eventsFull <-chan struct{} @@ -139,6 +123,31 @@ func (s *RewardStreamService) Stream( } } +func (s *RewardStreamService) fetchFromDB( + ctx context.Context, + ops builder.Operations, +) (<-chan *types.Reward, <-chan error) { + dbChan := make(chan *types.Reward) + errChan := make(chan error, 1) // buffered to avoid blocking, routine should exit immediately after sending an error + + go func() { + defer close(dbChan) + if err := rewards.IterateRewardsOps(s.db, ops, func(rwd *types.Reward) bool { + select { + case dbChan <- rwd: + return true + case <-ctx.Done(): + // exit if the context is canceled + return false + } + }); err != nil { + errChan <- status.Error(codes.Internal, err.Error()) + } + }() + + return dbChan, errChan +} + func (s *RewardStreamService) String() string { return "RewardStreamService" } diff --git a/api/grpcserver/v2alpha1/reward_test.go b/api/grpcserver/v2alpha1/reward_test.go index 8f9368fc00..752c9984b5 100644 --- a/api/grpcserver/v2alpha1/reward_test.go +++ b/api/grpcserver/v2alpha1/reward_test.go @@ -20,26 +20,28 @@ import ( ) func TestRewardService_List(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() - - gen := fixture.NewRewardsGenerator().WithAddresses(100).WithUniqueCoinbase() - rwds := make([]types.Reward, 100) - for i := range rwds { - rwd := gen.Next() - require.NoError(t, rewards.Add(db, rwd)) - rwds[i] = *rwd - } + setup := func(t *testing.T) (spacemeshv2alpha1.RewardServiceClient, []types.Reward) { + db := statesql.InMemoryTest(t) + + gen := fixture.NewRewardsGenerator().WithAddresses(90).WithUniqueCoinbase() + rwds := make([]types.Reward, 90) + for i := range rwds { + rwd := gen.Next() + require.NoError(t, rewards.Add(db, rwd)) + rwds[i] = *rwd + } - svc := NewRewardService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + svc := NewRewardService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewRewardServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewRewardServiceClient(conn), rwds + } t.Run("limit set too high", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{Limit: 200}) + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{Limit: 200}) require.Error(t, err) s, ok := status.FromError(err) @@ -49,7 +51,8 @@ func TestRewardService_List(t *testing.T) { }) t.Run("no limit set", func(t *testing.T) { - _, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{}) + client, _ := setup(t) + _, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{}) require.Error(t, err) s, ok := status.FromError(err) @@ -59,7 +62,8 @@ func TestRewardService_List(t *testing.T) { }) t.Run("limit and offset", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{ + client, _ := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{ Limit: 25, Offset: 50, }) @@ -68,13 +72,15 @@ func TestRewardService_List(t *testing.T) { }) t.Run("all", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{Limit: 100}) + client, rwds := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{Limit: 100}) require.NoError(t, err) - require.Len(t, rwds, len(list.Rewards)) + require.Len(t, list.Rewards, len(rwds)) }) t.Run("coinbase", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{ + client, rwds := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{ Limit: 1, StartLayer: rwds[3].Layer.Uint32(), EndLayer: rwds[3].Layer.Uint32(), @@ -88,7 +94,8 @@ func TestRewardService_List(t *testing.T) { }) t.Run("smesher", func(t *testing.T) { - list, err := client.List(ctx, &spacemeshv2alpha1.RewardRequest{ + client, rwds := setup(t) + list, err := client.List(context.Background(), &spacemeshv2alpha1.RewardRequest{ Limit: 1, StartLayer: rwds[4].Layer.Uint32(), EndLayer: rwds[4].Layer.Uint32(), @@ -103,29 +110,32 @@ func TestRewardService_List(t *testing.T) { } func TestRewardStreamService_Stream(t *testing.T) { - db := statesql.InMemory() - ctx := context.Background() - - gen := fixture.NewRewardsGenerator() - rwds := make([]types.Reward, 100) - for i := range rwds { - rwd := gen.Next() - require.NoError(t, rewards.Add(db, rwd)) - rwds[i] = *rwd - } + setup := func(t *testing.T) spacemeshv2alpha1.RewardStreamServiceClient { + db := statesql.InMemoryTest(t) - svc := NewRewardStreamService(db) - cfg, cleanup := launchServer(t, svc) - t.Cleanup(cleanup) + gen := fixture.NewRewardsGenerator() + rwds := make([]types.Reward, 100) + for i := range rwds { + rwd := gen.Next() + require.NoError(t, rewards.Add(db, rwd)) + rwds[i] = *rwd + } + + svc := NewRewardStreamService(db) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) - conn := dialGrpc(t, cfg) - client := spacemeshv2alpha1.NewRewardStreamServiceClient(conn) + conn := dialGrpc(t, cfg) + return spacemeshv2alpha1.NewRewardStreamServiceClient(conn) + } t.Run("all", func(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - stream, err := client.Stream(ctx, &spacemeshv2alpha1.RewardStreamRequest{}) + client := setup(t) + + stream, err := client.Stream(context.Background(), &spacemeshv2alpha1.RewardStreamRequest{}) require.NoError(t, err) var i int @@ -136,19 +146,21 @@ func TestRewardStreamService_Stream(t *testing.T) { } i++ } - require.Len(t, rwds, i) + require.Equal(t, 100, i) }) t.Run("watch", func(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) + client := setup(t) + const ( start = 100 n = 10 ) - gen = fixture.NewRewardsGenerator().WithLayers(start, 10) + gen := fixture.NewRewardsGenerator().WithLayers(start, 10) var streamed []types.Reward for i := 0; i < n; i++ { rwd := gen.Next() @@ -183,7 +195,7 @@ func TestRewardStreamService_Stream(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - stream, err := client.Stream(ctx, tc.request) + stream, err := client.Stream(context.Background(), tc.request) require.NoError(t, err) _, err = stream.Header() require.NoError(t, err) @@ -191,7 +203,7 @@ func TestRewardStreamService_Stream(t *testing.T) { var expect []*types.Reward for _, rst := range streamed { require.NoError(t, events.ReportRewardReceived(rst)) - matcher := rewardsMatcher{tc.request, ctx} + matcher := rewardsMatcher{tc.request, context.Background()} if matcher.match(&rst) { expect = append(expect, &rst) } diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index c94a5287c9..c743c6e842 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -38,7 +38,7 @@ import ( func TestTransactionService_List(t *testing.T) { types.SetLayersPerEpoch(5) - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() gen := fixture.NewTransactionResultGenerator().WithAddresses(2) @@ -223,7 +223,7 @@ func TestTransactionService_List(t *testing.T) { func TestTransactionService_EstimateGas(t *testing.T) { types.SetLayersPerEpoch(5) - db := statesql.InMemory() + db := statesql.InMemoryTest(t) vminst := vm.New(db) ctx := context.Background() @@ -289,7 +289,7 @@ func TestTransactionService_EstimateGas(t *testing.T) { func TestTransactionService_ParseTransaction(t *testing.T) { types.SetLayersPerEpoch(5) - db := statesql.InMemory() + db := statesql.InMemoryTest(t) vminst := vm.New(db) ctx := context.Background() @@ -408,7 +408,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemoryTest(t), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -451,7 +451,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemoryTest(t), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -488,7 +488,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemoryTest(t), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) diff --git a/beacon/handlers_test.go b/beacon/handlers_test.go index 3fe1bcdd98..875c430a3c 100644 --- a/beacon/handlers_test.go +++ b/beacon/handlers_test.go @@ -1095,7 +1095,7 @@ func Test_HandleFirstVotes_FailedToVerifySig(t *testing.T) { tpd.mClock.EXPECT().CurrentLayer().Return(epoch.FirstLayer()) tpd.mClock.EXPECT().LayerToTime(gomock.Any()).Return(time.Now()).AnyTimes() got := tpd.HandleFirstVotes(context.Background(), "peerID", msgBytes) - require.Contains(t, got.Error(), fmt.Sprintf("verify signature %s: failed", msg.Signature)) + require.ErrorContains(t, got, fmt.Sprintf("verify signature %s: failed", msg.Signature)) checkVoted(t, tpd.ProtocolDriver, epoch, signer, types.FirstRound, false) checkFirstIncomingVotes(t, tpd.ProtocolDriver, epoch, map[types.NodeID]proposalList{}) } @@ -1366,7 +1366,7 @@ func Test_handleFollowingVotes_FailedToVerifySig(t *testing.T) { tpd.mClock.EXPECT().CurrentLayer().Return(epoch.FirstLayer()) tpd.mClock.EXPECT().LayerToTime(gomock.Any()).Return(time.Now()).AnyTimes() got := tpd.HandleFollowingVotes(context.Background(), "peerID", msgBytes) - require.Contains(t, got.Error(), fmt.Sprintf("verify signature %s: failed", msg.Signature)) + require.ErrorContains(t, got, fmt.Sprintf("verify signature %s: failed", msg.Signature)) checkVoted(t, tpd.ProtocolDriver, epoch, signer, round, false) checkVoteMargins(t, tpd.ProtocolDriver, epoch, emptyVoteMargins(plist)) } diff --git a/common/util/json_test.go b/common/util/json_test.go index cfbfd776b1..b8d3ef8a51 100644 --- a/common/util/json_test.go +++ b/common/util/json_test.go @@ -17,28 +17,14 @@ package util import ( - "bytes" "encoding/hex" "encoding/json" "errors" + "strconv" "testing" -) -func checkError(t *testing.T, input string, got, want error) bool { - if got == nil { - if want != nil { - t.Errorf("input %s: got no error, want %q", input, want) - return false - } - return true - } - if want == nil { - t.Errorf("input %s: unexpected error %q", input, got) - } else if got.Error() != want.Error() { - t.Errorf("input %s: got error %q, want %q", input, got, want) - } - return false -} + "github.com/stretchr/testify/require" +) func referenceBytes(s string) []byte { b, err := hex.DecodeString(s) @@ -48,18 +34,18 @@ func referenceBytes(s string) []byte { return b } -var errJSONEOF = errors.New("unexpected end of JSON input") - -var unmarshalBytesTests = []unmarshalTest{ +var unmarshalBytesErrorTests = []unmarshalTest{ // invalid encoding - {input: "", wantErr: errJSONEOF}, + {input: "", wantErr: errors.New("unexpected end of JSON input")}, {input: "null", wantErr: errNonString(bytesT)}, {input: "10", wantErr: errNonString(bytesT)}, {input: `"0"`, wantErr: wrapTypeError(ErrMissingPrefix, bytesT)}, {input: `"0x0"`, wantErr: wrapTypeError(ErrOddLength, bytesT)}, {input: `"0xxx"`, wantErr: wrapTypeError(ErrSyntax, bytesT)}, {input: `"0x01zz01"`, wantErr: wrapTypeError(ErrSyntax, bytesT)}, +} +var unmarshalBytesTests = []unmarshalTest{ // valid encoding {input: `""`, want: referenceBytes("")}, {input: `"0x"`, want: referenceBytes("")}, @@ -73,16 +59,17 @@ var unmarshalBytesTests = []unmarshalTest{ } func TestUnmarshalBytes(t *testing.T) { + for _, test := range unmarshalBytesErrorTests { + var v Bytes + err := json.Unmarshal([]byte(test.input), &v) + require.EqualError(t, err, test.wantErr.Error()) + } + for _, test := range unmarshalBytesTests { var v Bytes err := json.Unmarshal([]byte(test.input), &v) - if !checkError(t, test.input, err, test.wantErr) { - continue - } - if !bytes.Equal(test.want.([]byte), []byte(v)) { - t.Errorf("input %s: value mismatch: got %x, want %x", test.input, &v, test.want) - continue - } + require.NoError(t, err) + require.Equal(t, test.want.([]byte), []byte(v)) } } @@ -100,17 +87,8 @@ func TestMarshalBytes(t *testing.T) { for _, test := range encodeBytesTests { in := test.input.([]byte) out, err := json.Marshal(Bytes(in)) - if err != nil { - t.Errorf("%x: %v", in, err) - continue - } - if want := `"` + test.want + `"`; string(out) != want { - t.Errorf("%x: MarshalJSON output mismatch: got %q, want %q", in, out, want) - continue - } - if out := Bytes(in).String(); out != test.want { - t.Errorf("%x: String mismatch: got %q, want %q", in, out, test.want) - continue - } + require.NoError(t, err) + require.Equal(t, strconv.Quote(test.want), string(out)) + require.Equal(t, test.want, Bytes(in).String()) } } diff --git a/datastore/store.go b/datastore/store.go index 3eaa6db83c..aee925cb02 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -10,9 +10,7 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/atxsdata" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" @@ -45,7 +43,7 @@ type CachedDB struct { // used to coordinate db update and cache mu sync.Mutex - malfeasanceCache *lru.Cache[types.NodeID, *wire.MalfeasanceProof] + malfeasanceCache *lru.Cache[types.NodeID, []byte] } type Config struct { @@ -95,7 +93,7 @@ func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { lg.Fatal("failed to create atx cache", zap.Error(err)) } - malfeasanceCache, err := lru.New[types.NodeID, *wire.MalfeasanceProof](o.cfg.MalfeasanceSize) + malfeasanceCache, err := lru.New[types.NodeID, []byte](o.cfg.MalfeasanceSize) if err != nil { lg.Fatal("failed to create malfeasance cache", zap.Error(err)) } @@ -120,7 +118,7 @@ func (db *CachedDB) MalfeasanceCacheSize() int { } // GetMalfeasanceProof gets the malfeasance proof associated with the NodeID. -func (db *CachedDB) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { +func (db *CachedDB) MalfeasanceProof(id types.NodeID) ([]byte, error) { if id == types.EmptyNodeID { panic("invalid argument to GetMalfeasanceProof") } @@ -139,13 +137,11 @@ func (db *CachedDB) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, e if err != nil && err != sql.ErrNotFound { return nil, err } - proof := &wire.MalfeasanceProof{} - codec.MustDecode(blob.Bytes, proof) - db.malfeasanceCache.Add(id, proof) - return proof, err + db.malfeasanceCache.Add(id, blob.Bytes) + return blob.Bytes, err } -func (db *CachedDB) CacheMalfeasanceProof(id types.NodeID, proof *wire.MalfeasanceProof) { +func (db *CachedDB) CacheMalfeasanceProof(id types.NodeID, proof []byte) { if id == types.EmptyNodeID { panic("invalid argument to CacheMalfeasanceProof") } @@ -200,7 +196,7 @@ func (db *CachedDB) Previous(id types.ATXID) ([]types.ATXID, error) { } func (db *CachedDB) IterateMalfeasanceProofs( - iter func(types.NodeID, *wire.MalfeasanceProof) error, + iter func(types.NodeID, []byte) error, ) error { ids, err := identities.GetMalicious(db) if err != nil { diff --git a/datastore/store_test.go b/datastore/store_test.go index a5f474a342..4517fb5682 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -56,19 +56,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) { cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) - // a bad guy - proof := &mwire.MalfeasanceProof{ - Layer: types.LayerID(11), - Proof: mwire.Proof{ - Type: mwire.MultipleBallots, - Data: &mwire.BallotProof{ - Messages: [2]mwire.BallotProofMsg{ - {}, - {}, - }, - }, - }, - } + proof := types.RandomBytes(100) nodeID1 := types.NodeID{1} cdb.CacheMalfeasanceProof(nodeID1, proof) diff --git a/events/events.go b/events/events.go index 30107a733b..e13cd05c80 100644 --- a/events/events.go +++ b/events/events.go @@ -256,14 +256,14 @@ func EmitProposal(nodeID types.NodeID, layer types.LayerID, proposal types.Propo ) } -func EmitOwnMalfeasanceProof(nodeID types.NodeID, mp *wire.MalfeasanceProof) { +func EmitOwnMalfeasanceProof(nodeID types.NodeID, proof []byte) { const help = "Node committed malicious behavior. Identity will be canceled." emitUserEvent( help, false, &pb.Event_Malfeasance{ Malfeasance: &pb.EventMalfeasance{ - Proof: ToMalfeasancePB(nodeID, mp, false), + Proof: ToMalfeasancePB(nodeID, proof, false), }, }, ) @@ -284,8 +284,9 @@ func emitUserEvent(help string, failure bool, details pb.IsEventDetails) { } } -func ToMalfeasancePB(nodeID types.NodeID, mp *wire.MalfeasanceProof, includeProof bool) *pb.MalfeasanceProof { - if mp == nil { +func ToMalfeasancePB(nodeID types.NodeID, proof []byte, includeProof bool) *pb.MalfeasanceProof { + mp := &wire.MalfeasanceProof{} + if err := codec.Decode(proof, mp); err != nil { return &pb.MalfeasanceProof{} } kind := pb.MalfeasanceProof_MALFEASANCE_UNSPECIFIED @@ -308,8 +309,7 @@ func ToMalfeasancePB(nodeID types.NodeID, mp *wire.MalfeasanceProof, includeProo DebugInfo: wire.MalfeasanceInfo(nodeID, mp), } if includeProof { - data, _ := codec.Encode(mp) - result.Proof = data + result.Proof = proof } return result } diff --git a/events/malfeasance.go b/events/malfeasance.go index 073a191b99..4ea0ef6ea8 100644 --- a/events/malfeasance.go +++ b/events/malfeasance.go @@ -3,13 +3,12 @@ package events import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/log" - "github.com/spacemeshos/go-spacemesh/malfeasance/wire" ) // EventMalfeasance includes the malfeasance proof. type EventMalfeasance struct { Smesher types.NodeID - Proof *wire.MalfeasanceProof + Proof []byte } // SubscribeMalfeasance subscribes malfeasance events. @@ -27,11 +26,11 @@ func SubscribeMalfeasance() Subscription { } // ReportMalfeasance reports a malfeasance proof. -func ReportMalfeasance(nodeID types.NodeID, mp *wire.MalfeasanceProof) { +func ReportMalfeasance(nodeID types.NodeID, proof []byte) { mu.RLock() defer mu.RUnlock() if reporter != nil { - if err := reporter.malfeasanceEmitter.Emit(EventMalfeasance{Smesher: nodeID, Proof: mp}); err != nil { + if err := reporter.malfeasanceEmitter.Emit(EventMalfeasance{Smesher: nodeID, Proof: proof}); err != nil { log.With().Error("failed to emit malfeasance proof", log.Err(err)) } } diff --git a/go.mod b/go.mod index 4b42b0f526..7990fd3209 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 github.com/slok/go-http-metrics v0.12.0 - github.com/spacemeshos/api/release/go v1.52.0 + github.com/spacemeshos/api/release/go v1.53.0 github.com/spacemeshos/economics v0.1.3 github.com/spacemeshos/fixed v0.1.1 github.com/spacemeshos/go-scale v1.2.0 diff --git a/go.sum b/go.sum index a44068cc0b..d30408d585 100644 --- a/go.sum +++ b/go.sum @@ -603,8 +603,8 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spacemeshos/api/release/go v1.52.0 h1:3cohOoFIk0RLF5fdL0y6pFgZ7Ngg1Yht+aeN3Xm5Qn8= -github.com/spacemeshos/api/release/go v1.52.0/go.mod h1:Qr/pVPMmN5Q5qLHSXqVMDKDCu6LkHWzGPNflylE0u00= +github.com/spacemeshos/api/release/go v1.53.0 h1:NGPCPNMwkQtPgx2P/bmhvUdGFyv78UMSTGYvkj1JIlo= +github.com/spacemeshos/api/release/go v1.53.0/go.mod h1:nIWzRxJe365XvjN51AhfwR6Lf0EDaMcnmAb6hd2D0xw= github.com/spacemeshos/economics v0.1.3 h1:ACkq3mTebIky4Zwbs9SeSSRZrUCjU/Zk0wq9Z0BTh2A= github.com/spacemeshos/economics v0.1.3/go.mod h1:FH7u0FzTIm6Kpk+X5HOZDvpkgNYBKclmH86rVwYaDAo= github.com/spacemeshos/fixed v0.1.1 h1:N1y4SUpq1EV+IdJrWJwUCt1oBFzeru/VKVcBsvPc2Fk= diff --git a/hare3/hare.go b/hare3/hare.go index a4c3dd3326..5eb085e092 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -292,7 +292,7 @@ func (h *Hare) Running() int { return len(h.sessions) } -func (h *Hare) Handler(ctx context.Context, peer p2p.Peer, buf []byte) error { +func (h *Hare) Handler(ctx context.Context, _ p2p.Peer, buf []byte) error { msg := &Message{} if err := codec.Decode(buf, msg); err != nil { malformedError.Inc() diff --git a/hare3/malfeasance.go b/hare3/malfeasance.go index e71a9c6985..e7128ea3bd 100644 --- a/hare3/malfeasance.go +++ b/hare3/malfeasance.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" @@ -52,6 +53,20 @@ func NewMalfeasanceHandler( return mh } +func (mh *MalfeasanceHandler) Info(data wire.ProofData) (map[string]string, error) { + hp, ok := data.(*wire.HareProof) + if !ok { + return nil, errors.New("wrong message type for hare equivocation") + } + return map[string]string{ + "msg1": hp.Messages[0].InnerMsg.MsgHash.String(), + "msg2": hp.Messages[1].InnerMsg.MsgHash.String(), + "layer": hp.Messages[0].InnerMsg.Layer.String(), + "round": strconv.FormatUint(uint64(hp.Messages[0].InnerMsg.Round), 10), + "smesher_id": hp.Messages[0].SmesherID.String(), + }, nil +} + func (mh *MalfeasanceHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { hp, ok := data.(*wire.HareProof) if !ok { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 45acdfcd54..5d4eea62a9 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "slices" + "strconv" "time" "go.uber.org/zap" @@ -32,31 +33,22 @@ var ( type MalfeasanceType byte const ( - // V1 types. MultipleATXs MalfeasanceType = MalfeasanceType(wire.MultipleATXs) MultipleBallots = MalfeasanceType(wire.MultipleBallots) HareEquivocation = MalfeasanceType(wire.HareEquivocation) InvalidPostIndex = MalfeasanceType(wire.InvalidPostIndex) InvalidPrevATX = MalfeasanceType(wire.InvalidPrevATX) - - // V2 types - // TODO(mafa): for future use. - InvalidActivation MalfeasanceType = iota + 10 - InvalidBallot - InvalidHareMsg ) // Handler processes MalfeasanceProof from gossip and, if deems it valid, propagates it to peers. type Handler struct { - logger *zap.Logger - cdb *datastore.CachedDB - - handlersV1 map[MalfeasanceType]HandlerV1 - handlersV2 map[MalfeasanceType]HandlerV2 - + logger *zap.Logger + cdb *datastore.CachedDB self p2p.Peer nodeIDs []types.NodeID tortoise tortoise + + handlers map[MalfeasanceType]MalfeasanceHandler } func NewHandler( @@ -73,33 +65,50 @@ func NewHandler( nodeIDs: nodeID, tortoise: tortoise, - handlersV1: make(map[MalfeasanceType]HandlerV1), - handlersV2: make(map[MalfeasanceType]HandlerV2), + handlers: make(map[MalfeasanceType]MalfeasanceHandler), } } -func (h *Handler) RegisterHandlerV1(malfeasanceType MalfeasanceType, handler HandlerV1) { - h.handlersV1[malfeasanceType] = handler -} +func (h *Handler) RegisterHandler(malfeasanceType MalfeasanceType, handler MalfeasanceHandler) { + if _, ok := h.handlers[malfeasanceType]; ok { + h.logger.Panic("handler already registered", zap.Int("malfeasanceType", int(malfeasanceType))) + } -func (h *Handler) RegisterHandlerV2(malfeasanceType MalfeasanceType, handler HandlerV2) { - h.handlersV2[malfeasanceType] = handler + h.handlers[malfeasanceType] = handler } -func (h *Handler) reportMalfeasance(smesher types.NodeID, mp *wire.MalfeasanceProof) { +func (h *Handler) reportMalfeasance(smesher types.NodeID, proof []byte) { h.tortoise.OnMalfeasance(smesher) - events.ReportMalfeasance(smesher, mp) + events.ReportMalfeasance(smesher, proof) if slices.Contains(h.nodeIDs, smesher) { - events.EmitOwnMalfeasanceProof(smesher, mp) + events.EmitOwnMalfeasanceProof(smesher, proof) } } func (h *Handler) countProof(mp *wire.MalfeasanceProof) { - h.handlersV1[MalfeasanceType(mp.Proof.Type)].ReportProof(numProofs) + h.handlers[MalfeasanceType(mp.Proof.Type)].ReportProof(numProofs) } func (h *Handler) countInvalidProof(p *wire.MalfeasanceProof) { - h.handlersV1[MalfeasanceType(p.Proof.Type)].ReportInvalidProof(numInvalidProofs) + h.handlers[MalfeasanceType(p.Proof.Type)].ReportInvalidProof(numInvalidProofs) +} + +func (h *Handler) Info(data []byte) (map[string]string, error) { + var p wire.MalfeasanceProof + if err := codec.Decode(data, &p); err != nil { + return nil, fmt.Errorf("decode malfeasance proof: %w", err) + } + mh, ok := h.handlers[MalfeasanceType(p.Proof.Type)] + if !ok { + return nil, fmt.Errorf("unknown malfeasance type %d", p.Proof.Type) + } + properties, err := mh.Info(p.Proof.Data) + if err != nil { + return nil, fmt.Errorf("malfeasance info: %w", err) + } + properties["domain"] = "0" // for malfeasance V1 there are no domains + properties["type"] = strconv.FormatUint(uint64(p.Proof.Type), 10) + return properties, nil } // HandleSyncedMalfeasanceProof is the sync validator for MalfeasanceProof. @@ -115,7 +124,7 @@ func (h *Handler) HandleSyncedMalfeasanceProof( h.logger.Error("malformed message (sync)", log.ZContext(ctx), zap.Error(err)) return errMalformedData } - nodeID, err := h.validateAndSave(ctx, &wire.MalfeasanceGossip{MalfeasanceProof: p}) + nodeID, err := h.validateAndSave(ctx, &p) if err == nil && types.Hash32(nodeID) != expHash { return fmt.Errorf( "%w: malfeasance proof want %s, got %s", @@ -135,52 +144,48 @@ func (h *Handler) HandleMalfeasanceProof(ctx context.Context, peer p2p.Peer, dat h.logger.Error("malformed message", log.ZContext(ctx), zap.Error(err)) return errMalformedData } + if p.Eligibility != nil { + numMalformed.Inc() + return fmt.Errorf("%w: eligibility field was deprecated with hare3", pubsub.ErrValidationReject) + } if peer == h.self { - id, err := h.Validate(ctx, &p) + id, err := h.Validate(ctx, &p.MalfeasanceProof) if err != nil { h.countInvalidProof(&p.MalfeasanceProof) return err } - h.reportMalfeasance(id, &p.MalfeasanceProof) + h.reportMalfeasance(id, codec.MustEncode(&p.MalfeasanceProof)) // node saves malfeasance proof eagerly/atomically with the malicious data. // it has validated the proof before saving to db. h.countProof(&p.MalfeasanceProof) return nil } - _, err := h.validateAndSave(ctx, &p) + _, err := h.validateAndSave(ctx, &p.MalfeasanceProof) return err } -func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip) (types.NodeID, error) { - if p.Eligibility != nil { - numMalformed.Inc() - return types.EmptyNodeID, fmt.Errorf( - "%w: eligibility field was deprecated with hare3", - pubsub.ErrValidationReject, - ) - } +func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceProof) (types.NodeID, error) { + p.SetReceived(time.Now()) nodeID, err := h.Validate(ctx, p) switch { case errors.Is(err, errUnknownProof): numMalformed.Inc() return types.EmptyNodeID, err case err != nil: - h.countInvalidProof(&p.MalfeasanceProof) + h.countInvalidProof(p) return types.EmptyNodeID, errors.Join(err, pubsub.ErrValidationReject) } + proofBytes := codec.MustEncode(p) if err := h.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { malicious, err := identities.IsMalicious(dbtx, nodeID) if err != nil { return fmt.Errorf("check known malicious: %w", err) - } else if malicious { + } + if malicious { h.logger.Debug("known malicious identity", log.ZContext(ctx), zap.Stringer("smesher", nodeID)) return ErrKnownProof } - encoded, err := codec.Encode(&p.MalfeasanceProof) - if err != nil { - h.logger.Panic("failed to encode MalfeasanceProof", zap.Error(err)) - } - if err := identities.SetMalicious(dbtx, nodeID, encoded, time.Now()); err != nil { + if err := identities.SetMalicious(dbtx, nodeID, proofBytes, time.Now()); err != nil { return fmt.Errorf("add malfeasance proof: %w", err) } return nil @@ -193,11 +198,11 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip zap.Error(err), ) } - return types.EmptyNodeID, err + return nodeID, err } - h.reportMalfeasance(nodeID, &p.MalfeasanceProof) - h.cdb.CacheMalfeasanceProof(nodeID, &p.MalfeasanceProof) - h.countProof(&p.MalfeasanceProof) + h.reportMalfeasance(nodeID, proofBytes) + h.cdb.CacheMalfeasanceProof(nodeID, proofBytes) + h.countProof(p) h.logger.Debug("new malfeasance proof", log.ZContext(ctx), zap.Stringer("smesher", nodeID), @@ -206,8 +211,8 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip return nodeID, nil } -func (h *Handler) Validate(ctx context.Context, p *wire.MalfeasanceGossip) (types.NodeID, error) { - mh, ok := h.handlersV1[MalfeasanceType(p.Proof.Type)] +func (h *Handler) Validate(ctx context.Context, p *wire.MalfeasanceProof) (types.NodeID, error) { + mh, ok := h.handlers[MalfeasanceType(p.Proof.Type)] if !ok { return types.EmptyNodeID, fmt.Errorf("%w: unknown malfeasance type", errUnknownProof) } diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 0a9d09d40f..bdc5768537 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -3,6 +3,8 @@ package malfeasance import ( "context" "errors" + "fmt" + "strconv" "testing" "time" @@ -92,7 +94,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { h := newHandler(t) ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) @@ -100,7 +102,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { }, ) handler.EXPECT().ReportInvalidProof(gomock.Any()) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) gossip := &wire.MalfeasanceGossip{ MalfeasanceProof: wire.MalfeasanceProof{ @@ -122,7 +124,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) @@ -130,7 +132,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { }, ) handler.EXPECT().ReportProof(gomock.Any()) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) gossip := &wire.MalfeasanceGossip{ MalfeasanceProof: wire.MalfeasanceProof{ @@ -165,14 +167,14 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { identities.SetMalicious(h.db, nodeID, codec.MustEncode(proof), time.Now()) ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) return nodeID, nil }, ) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) gossip := &wire.MalfeasanceGossip{ MalfeasanceProof: wire.MalfeasanceProof{ @@ -233,7 +235,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) @@ -241,7 +243,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { }, ) handler.EXPECT().ReportProof(gomock.Any()) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) proof := &wire.MalfeasanceProof{ Layer: types.LayerID(22), @@ -267,7 +269,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) @@ -275,7 +277,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { }, ) handler.EXPECT().ReportInvalidProof(gomock.Any()) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) proof := &wire.MalfeasanceProof{ Layer: types.LayerID(22), @@ -300,7 +302,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) @@ -308,7 +310,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { }, ) handler.EXPECT().ReportProof(gomock.Any()) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) proof := &wire.MalfeasanceProof{ Layer: types.LayerID(22), @@ -343,14 +345,14 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { identities.SetMalicious(h.db, nodeID, proofBytes, time.Now()) ctrl := gomock.NewController(t) - handler := NewMockHandlerV1(ctrl) + handler := NewMockMalfeasanceHandler(ctrl) handler.EXPECT().Validate(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, data wire.ProofData) (types.NodeID, error) { require.IsType(t, &wire.AtxProof{}, data) return nodeID, nil }, ) - h.RegisterHandlerV1(MultipleATXs, handler) + h.RegisterHandler(MultipleATXs, handler) newProof := &wire.MalfeasanceProof{ Layer: types.LayerID(22), @@ -370,3 +372,85 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { require.Equal(t, proofBytes, blob.Bytes) }) } + +func TestHandler_Info(t *testing.T) { + t.Run("malformed data", func(t *testing.T) { + h := newHandler(t) + + info, err := h.Info(types.RandomBytes(32)) + require.ErrorContains(t, err, "decode malfeasance proof:") + require.Nil(t, info) + }) + + t.Run("unknown malfeasance type", func(t *testing.T) { + h := newHandler(t) + + proof := &wire.MalfeasanceProof{ + Layer: types.LayerID(22), + Proof: wire.Proof{ + Type: wire.MultipleATXs, + Data: &wire.AtxProof{}, + }, + } + proofBytes := codec.MustEncode(proof) + + info, err := h.Info(proofBytes) + require.ErrorContains(t, err, fmt.Sprintf("unknown malfeasance type %d", wire.MultipleATXs)) + require.Nil(t, info) + }) + + t.Run("invalid proof", func(t *testing.T) { + h := newHandler(t) + + ctrl := gomock.NewController(t) + handler := NewMockMalfeasanceHandler(ctrl) + handler.EXPECT().Info(gomock.Any()).Return(nil, errors.New("invalid proof")) + h.RegisterHandler(MultipleATXs, handler) + + proof := &wire.MalfeasanceProof{ + Layer: types.LayerID(22), + Proof: wire.Proof{ + Type: wire.MultipleATXs, + Data: &wire.AtxProof{}, + }, + } + proofBytes := codec.MustEncode(proof) + + info, err := h.Info(proofBytes) + require.ErrorContains(t, err, "invalid proof") + require.Nil(t, info) + }) + + t.Run("valid proof", func(t *testing.T) { + h := newHandler(t) + + properties := map[string]string{ + "key": "value", + } + + ctrl := gomock.NewController(t) + handler := NewMockMalfeasanceHandler(ctrl) + handler.EXPECT().Info(gomock.Any()).Return(properties, nil) + h.RegisterHandler(MultipleATXs, handler) + + proof := &wire.MalfeasanceProof{ + Layer: types.LayerID(22), + Proof: wire.Proof{ + Type: wire.MultipleATXs, + Data: &wire.AtxProof{}, + }, + } + proofBytes := codec.MustEncode(proof) + expectedProperties := map[string]string{ + "domain": "0", + "type": strconv.FormatUint(uint64(wire.MultipleATXs), 10), + } + for k, v := range properties { + expectedProperties[k] = v + } + + info, err := h.Info(proofBytes) + require.NoError(t, err) + require.Equal(t, expectedProperties, info) + }) +} diff --git a/malfeasance/interface.go b/malfeasance/interface.go index 46b3cdb9f9..3486d5878f 100644 --- a/malfeasance/interface.go +++ b/malfeasance/interface.go @@ -15,12 +15,9 @@ type tortoise interface { OnMalfeasance(types.NodeID) } -type HandlerV1 interface { +type MalfeasanceHandler interface { Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) + Info(data wire.ProofData) (map[string]string, error) ReportProof(vec *prometheus.CounterVec) ReportInvalidProof(vec *prometheus.CounterVec) } - -type HandlerV2 interface { - Validate(ctx context.Context, data []byte) (types.NodeID, error) -} diff --git a/malfeasance/mocks.go b/malfeasance/mocks.go index d0be0c3a1a..f123149e61 100644 --- a/malfeasance/mocks.go +++ b/malfeasance/mocks.go @@ -78,165 +78,142 @@ func (c *MocktortoiseOnMalfeasanceCall) DoAndReturn(f func(types.NodeID)) *Mockt return c } -// MockHandlerV1 is a mock of HandlerV1 interface. -type MockHandlerV1 struct { +// MockMalfeasanceHandler is a mock of MalfeasanceHandler interface. +type MockMalfeasanceHandler struct { ctrl *gomock.Controller - recorder *MockHandlerV1MockRecorder + recorder *MockMalfeasanceHandlerMockRecorder } -// MockHandlerV1MockRecorder is the mock recorder for MockHandlerV1. -type MockHandlerV1MockRecorder struct { - mock *MockHandlerV1 +// MockMalfeasanceHandlerMockRecorder is the mock recorder for MockMalfeasanceHandler. +type MockMalfeasanceHandlerMockRecorder struct { + mock *MockMalfeasanceHandler } -// NewMockHandlerV1 creates a new mock instance. -func NewMockHandlerV1(ctrl *gomock.Controller) *MockHandlerV1 { - mock := &MockHandlerV1{ctrl: ctrl} - mock.recorder = &MockHandlerV1MockRecorder{mock} +// NewMockMalfeasanceHandler creates a new mock instance. +func NewMockMalfeasanceHandler(ctrl *gomock.Controller) *MockMalfeasanceHandler { + mock := &MockMalfeasanceHandler{ctrl: ctrl} + mock.recorder = &MockMalfeasanceHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandlerV1) EXPECT() *MockHandlerV1MockRecorder { +func (m *MockMalfeasanceHandler) EXPECT() *MockMalfeasanceHandlerMockRecorder { return m.recorder } -// ReportInvalidProof mocks base method. -func (m *MockHandlerV1) ReportInvalidProof(vec *prometheus.CounterVec) { +// Info mocks base method. +func (m *MockMalfeasanceHandler) Info(data wire.ProofData) (map[string]string, error) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReportInvalidProof", vec) + ret := m.ctrl.Call(m, "Info", data) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// ReportInvalidProof indicates an expected call of ReportInvalidProof. -func (mr *MockHandlerV1MockRecorder) ReportInvalidProof(vec any) *MockHandlerV1ReportInvalidProofCall { +// Info indicates an expected call of Info. +func (mr *MockMalfeasanceHandlerMockRecorder) Info(data any) *MockMalfeasanceHandlerInfoCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportInvalidProof", reflect.TypeOf((*MockHandlerV1)(nil).ReportInvalidProof), vec) - return &MockHandlerV1ReportInvalidProofCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockMalfeasanceHandler)(nil).Info), data) + return &MockMalfeasanceHandlerInfoCall{Call: call} } -// MockHandlerV1ReportInvalidProofCall wrap *gomock.Call -type MockHandlerV1ReportInvalidProofCall struct { +// MockMalfeasanceHandlerInfoCall wrap *gomock.Call +type MockMalfeasanceHandlerInfoCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockHandlerV1ReportInvalidProofCall) Return() *MockHandlerV1ReportInvalidProofCall { - c.Call = c.Call.Return() +func (c *MockMalfeasanceHandlerInfoCall) Return(arg0 map[string]string, arg1 error) *MockMalfeasanceHandlerInfoCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockHandlerV1ReportInvalidProofCall) Do(f func(*prometheus.CounterVec)) *MockHandlerV1ReportInvalidProofCall { +func (c *MockMalfeasanceHandlerInfoCall) Do(f func(wire.ProofData) (map[string]string, error)) *MockMalfeasanceHandlerInfoCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockHandlerV1ReportInvalidProofCall) DoAndReturn(f func(*prometheus.CounterVec)) *MockHandlerV1ReportInvalidProofCall { +func (c *MockMalfeasanceHandlerInfoCall) DoAndReturn(f func(wire.ProofData) (map[string]string, error)) *MockMalfeasanceHandlerInfoCall { c.Call = c.Call.DoAndReturn(f) return c } -// ReportProof mocks base method. -func (m *MockHandlerV1) ReportProof(vec *prometheus.CounterVec) { +// ReportInvalidProof mocks base method. +func (m *MockMalfeasanceHandler) ReportInvalidProof(vec *prometheus.CounterVec) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReportProof", vec) + m.ctrl.Call(m, "ReportInvalidProof", vec) } -// ReportProof indicates an expected call of ReportProof. -func (mr *MockHandlerV1MockRecorder) ReportProof(vec any) *MockHandlerV1ReportProofCall { +// ReportInvalidProof indicates an expected call of ReportInvalidProof. +func (mr *MockMalfeasanceHandlerMockRecorder) ReportInvalidProof(vec any) *MockMalfeasanceHandlerReportInvalidProofCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportProof", reflect.TypeOf((*MockHandlerV1)(nil).ReportProof), vec) - return &MockHandlerV1ReportProofCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportInvalidProof", reflect.TypeOf((*MockMalfeasanceHandler)(nil).ReportInvalidProof), vec) + return &MockMalfeasanceHandlerReportInvalidProofCall{Call: call} } -// MockHandlerV1ReportProofCall wrap *gomock.Call -type MockHandlerV1ReportProofCall struct { +// MockMalfeasanceHandlerReportInvalidProofCall wrap *gomock.Call +type MockMalfeasanceHandlerReportInvalidProofCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockHandlerV1ReportProofCall) Return() *MockHandlerV1ReportProofCall { +func (c *MockMalfeasanceHandlerReportInvalidProofCall) Return() *MockMalfeasanceHandlerReportInvalidProofCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do -func (c *MockHandlerV1ReportProofCall) Do(f func(*prometheus.CounterVec)) *MockHandlerV1ReportProofCall { +func (c *MockMalfeasanceHandlerReportInvalidProofCall) Do(f func(*prometheus.CounterVec)) *MockMalfeasanceHandlerReportInvalidProofCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockHandlerV1ReportProofCall) DoAndReturn(f func(*prometheus.CounterVec)) *MockHandlerV1ReportProofCall { +func (c *MockMalfeasanceHandlerReportInvalidProofCall) DoAndReturn(f func(*prometheus.CounterVec)) *MockMalfeasanceHandlerReportInvalidProofCall { c.Call = c.Call.DoAndReturn(f) return c } -// Validate mocks base method. -func (m *MockHandlerV1) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { +// ReportProof mocks base method. +func (m *MockMalfeasanceHandler) ReportProof(vec *prometheus.CounterVec) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Validate", ctx, data) - ret0, _ := ret[0].(types.NodeID) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.Call(m, "ReportProof", vec) } -// Validate indicates an expected call of Validate. -func (mr *MockHandlerV1MockRecorder) Validate(ctx, data any) *MockHandlerV1ValidateCall { +// ReportProof indicates an expected call of ReportProof. +func (mr *MockMalfeasanceHandlerMockRecorder) ReportProof(vec any) *MockMalfeasanceHandlerReportProofCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockHandlerV1)(nil).Validate), ctx, data) - return &MockHandlerV1ValidateCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportProof", reflect.TypeOf((*MockMalfeasanceHandler)(nil).ReportProof), vec) + return &MockMalfeasanceHandlerReportProofCall{Call: call} } -// MockHandlerV1ValidateCall wrap *gomock.Call -type MockHandlerV1ValidateCall struct { +// MockMalfeasanceHandlerReportProofCall wrap *gomock.Call +type MockMalfeasanceHandlerReportProofCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockHandlerV1ValidateCall) Return(arg0 types.NodeID, arg1 error) *MockHandlerV1ValidateCall { - c.Call = c.Call.Return(arg0, arg1) +func (c *MockMalfeasanceHandlerReportProofCall) Return() *MockMalfeasanceHandlerReportProofCall { + c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do -func (c *MockHandlerV1ValidateCall) Do(f func(context.Context, wire.ProofData) (types.NodeID, error)) *MockHandlerV1ValidateCall { +func (c *MockMalfeasanceHandlerReportProofCall) Do(f func(*prometheus.CounterVec)) *MockMalfeasanceHandlerReportProofCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockHandlerV1ValidateCall) DoAndReturn(f func(context.Context, wire.ProofData) (types.NodeID, error)) *MockHandlerV1ValidateCall { +func (c *MockMalfeasanceHandlerReportProofCall) DoAndReturn(f func(*prometheus.CounterVec)) *MockMalfeasanceHandlerReportProofCall { c.Call = c.Call.DoAndReturn(f) return c } -// MockHandlerV2 is a mock of HandlerV2 interface. -type MockHandlerV2 struct { - ctrl *gomock.Controller - recorder *MockHandlerV2MockRecorder -} - -// MockHandlerV2MockRecorder is the mock recorder for MockHandlerV2. -type MockHandlerV2MockRecorder struct { - mock *MockHandlerV2 -} - -// NewMockHandlerV2 creates a new mock instance. -func NewMockHandlerV2(ctrl *gomock.Controller) *MockHandlerV2 { - mock := &MockHandlerV2{ctrl: ctrl} - mock.recorder = &MockHandlerV2MockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandlerV2) EXPECT() *MockHandlerV2MockRecorder { - return m.recorder -} - // Validate mocks base method. -func (m *MockHandlerV2) Validate(ctx context.Context, data []byte) (types.NodeID, error) { +func (m *MockMalfeasanceHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Validate", ctx, data) ret0, _ := ret[0].(types.NodeID) @@ -245,31 +222,31 @@ func (m *MockHandlerV2) Validate(ctx context.Context, data []byte) (types.NodeID } // Validate indicates an expected call of Validate. -func (mr *MockHandlerV2MockRecorder) Validate(ctx, data any) *MockHandlerV2ValidateCall { +func (mr *MockMalfeasanceHandlerMockRecorder) Validate(ctx, data any) *MockMalfeasanceHandlerValidateCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockHandlerV2)(nil).Validate), ctx, data) - return &MockHandlerV2ValidateCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockMalfeasanceHandler)(nil).Validate), ctx, data) + return &MockMalfeasanceHandlerValidateCall{Call: call} } -// MockHandlerV2ValidateCall wrap *gomock.Call -type MockHandlerV2ValidateCall struct { +// MockMalfeasanceHandlerValidateCall wrap *gomock.Call +type MockMalfeasanceHandlerValidateCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockHandlerV2ValidateCall) Return(arg0 types.NodeID, arg1 error) *MockHandlerV2ValidateCall { +func (c *MockMalfeasanceHandlerValidateCall) Return(arg0 types.NodeID, arg1 error) *MockMalfeasanceHandlerValidateCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockHandlerV2ValidateCall) Do(f func(context.Context, []byte) (types.NodeID, error)) *MockHandlerV2ValidateCall { +func (c *MockMalfeasanceHandlerValidateCall) Do(f func(context.Context, wire.ProofData) (types.NodeID, error)) *MockMalfeasanceHandlerValidateCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockHandlerV2ValidateCall) DoAndReturn(f func(context.Context, []byte) (types.NodeID, error)) *MockHandlerV2ValidateCall { +func (c *MockMalfeasanceHandlerValidateCall) DoAndReturn(f func(context.Context, wire.ProofData) (types.NodeID, error)) *MockMalfeasanceHandlerValidateCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/malfeasance/wire/malfeasance.go b/malfeasance/wire/malfeasance.go index 0ffd8e4228..32509721a8 100644 --- a/malfeasance/wire/malfeasance.go +++ b/malfeasance/wire/malfeasance.go @@ -101,8 +101,6 @@ type Proof struct { type ProofData interface { scale.Type - - isProof() } func (e *Proof) EncodeScale(enc *scale.Encoder) (int, error) { @@ -199,8 +197,6 @@ type AtxProof struct { Messages [2]AtxProofMsg } -func (ap *AtxProof) isProof() {} - func (ap *AtxProof) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddObject("first", &ap.Messages[0].InnerMsg) encoder.AddObject("second", &ap.Messages[1].InnerMsg) @@ -211,8 +207,6 @@ type BallotProof struct { Messages [2]BallotProofMsg } -func (bp *BallotProof) isProof() {} - func (bp *BallotProof) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddObject("first", &bp.Messages[0].InnerMsg) encoder.AddObject("second", &bp.Messages[1].InnerMsg) @@ -223,8 +217,6 @@ type HareProof struct { Messages [2]HareProofMsg } -func (hp *HareProof) isProof() {} - func (hp *HareProof) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddObject("first", &hp.Messages[0].InnerMsg) encoder.AddObject("second", &hp.Messages[1].InnerMsg) @@ -260,8 +252,6 @@ type InvalidPostIndexProof struct { InvalidIdx uint32 } -func (p *InvalidPostIndexProof) isProof() {} - type BallotProofMsg struct { InnerMsg types.BallotMetadata @@ -322,8 +312,6 @@ type InvalidPrevATXProof struct { Atx2 wire.ActivationTxV1 } -func (p *InvalidPrevATXProof) isProof() {} - func MalfeasanceInfo(smesher types.NodeID, mp *MalfeasanceProof) string { var b strings.Builder b.WriteString(fmt.Sprintf("generate layer: %v\n", mp.Layer)) diff --git a/mesh/malfeasance.go b/mesh/malfeasance.go index ad3f090e2a..fc074aa221 100644 --- a/mesh/malfeasance.go +++ b/mesh/malfeasance.go @@ -50,6 +50,19 @@ func NewMalfeasanceHandler( return mh } +func (mh *MalfeasanceHandler) Info(data wire.ProofData) (map[string]string, error) { + bp, ok := data.(*wire.BallotProof) + if !ok { + return nil, errors.New("wrong message type for multi ballots") + } + return map[string]string{ + "msg1": bp.Messages[0].InnerMsg.MsgHash.String(), + "msg2": bp.Messages[1].InnerMsg.MsgHash.String(), + "layer": bp.Messages[0].InnerMsg.Layer.String(), + "smesher_id": bp.Messages[0].SmesherID.String(), + }, nil +} + func (mh *MalfeasanceHandler) Validate(ctx context.Context, data wire.ProofData) (types.NodeID, error) { bp, ok := data.(*wire.BallotProof) if !ok { diff --git a/node/flags/string_to_uint64_test.go b/node/flags/string_to_uint64_test.go index cac3d11903..935c662694 100644 --- a/node/flags/string_to_uint64_test.go +++ b/node/flags/string_to_uint64_test.go @@ -53,8 +53,7 @@ func TestStringToUint64Value(t *testing.T) { full += arg + "," err := parser.Set(arg) if len(tc.err) > 0 { - require.Error(t, err) - require.Contains(t, err.Error(), tc.err) + require.ErrorContains(t, err, tc.err) } else { require.NoError(t, err) } @@ -67,8 +66,7 @@ func TestStringToUint64Value(t *testing.T) { require.Equal(t, tc.expected, value) require.Equal(t, tc.expected, valFull) } else { - require.Error(t, err) - require.Contains(t, err.Error(), tc.err) + require.ErrorContains(t, err, tc.err) } }) } diff --git a/node/node.go b/node/node.go index efd4f84110..3c66b0dee5 100644 --- a/node/node.go +++ b/node/node.go @@ -370,52 +370,53 @@ func New(opts ...Option) *App { // App is the cli app singleton. type App struct { *cobra.Command - fileLock *flock.Flock - signers []*signing.EdSigner - Config *config.Config - db sql.StateDatabase - cachedDB *datastore.CachedDB - dbMetrics *dbmetrics.DBMetricsCollector - localDB sql.LocalDatabase - grpcPublicServer *grpcserver.Server - grpcPrivateServer *grpcserver.Server - grpcPostServer *grpcserver.Server - grpcTLSServer *grpcserver.Server - jsonAPIServer *grpcserver.JSONHTTPServer - grpcServices map[grpcserver.Service]grpcserver.ServiceAPI - pprofService *http.Server - profilerService *pyroscope.Profiler - syncer *syncer.Syncer - proposalListener *proposals.Handler - proposalBuilder *miner.ProposalBuilder - mesh *mesh.Mesh - atxsdata *atxsdata.Data - clock *timesync.NodeClock - hare3 *hare3.Hare - hare4 *hare4.Hare - hareResultsChan chan hare4.ConsensusOutput - hOracle *eligibility.Oracle - blockGen *blocks.Generator - certifier *blocks.Certifier - atxBuilder *activation.Builder - nipostBuilder *activation.NIPostBuilder - atxHandler *activation.Handler - txHandler *txs.TxHandler - validator *activation.Validator - edVerifier *signing.EdVerifier - beaconProtocol *beacon.ProtocolDriver - log log.Log - syncLogger log.Log - svm *vm.VM - conState *txs.ConservativeState - fetcher *fetch.Fetch - ptimesync *peersync.Sync - tortoise *tortoise.Tortoise - updater *bootstrap.Updater - poetDb *activation.PoetDb - postVerifier activation.PostVerifier - postSupervisor *activation.PostSupervisor - errCh chan error + fileLock *flock.Flock + signers []*signing.EdSigner + Config *config.Config + db sql.StateDatabase + cachedDB *datastore.CachedDB + dbMetrics *dbmetrics.DBMetricsCollector + localDB sql.LocalDatabase + grpcPublicServer *grpcserver.Server + grpcPrivateServer *grpcserver.Server + grpcPostServer *grpcserver.Server + grpcTLSServer *grpcserver.Server + jsonAPIServer *grpcserver.JSONHTTPServer + grpcServices map[grpcserver.Service]grpcserver.ServiceAPI + pprofService *http.Server + profilerService *pyroscope.Profiler + syncer *syncer.Syncer + proposalListener *proposals.Handler + proposalBuilder *miner.ProposalBuilder + mesh *mesh.Mesh + atxsdata *atxsdata.Data + clock *timesync.NodeClock + hare3 *hare3.Hare + hare4 *hare4.Hare + hareResultsChan chan hare4.ConsensusOutput + hOracle *eligibility.Oracle + blockGen *blocks.Generator + certifier *blocks.Certifier + atxBuilder *activation.Builder + nipostBuilder *activation.NIPostBuilder + atxHandler *activation.Handler + txHandler *txs.TxHandler + validator *activation.Validator + edVerifier *signing.EdVerifier + beaconProtocol *beacon.ProtocolDriver + log log.Log + syncLogger log.Log + svm *vm.VM + conState *txs.ConservativeState + fetcher *fetch.Fetch + ptimesync *peersync.Sync + tortoise *tortoise.Tortoise + updater *bootstrap.Updater + poetDb *activation.PoetDb + postVerifier activation.PostVerifier + postSupervisor *activation.PostSupervisor + malfeasanceHandler *malfeasance.Handler + errCh chan error host *p2p.Host @@ -1143,18 +1144,18 @@ func (app *App) initServices(ctx context.Context) error { for _, s := range app.signers { nodeIDs = append(nodeIDs, s.NodeID()) } - malfeasanceHandler := malfeasance.NewHandler( + app.malfeasanceHandler = malfeasance.NewHandler( app.cachedDB, malfeasanceLogger, app.host.ID(), nodeIDs, trtl, ) - malfeasanceHandler.RegisterHandlerV1(malfeasance.MultipleATXs, activationMH) - malfeasanceHandler.RegisterHandlerV1(malfeasance.MultipleBallots, meshMH) - malfeasanceHandler.RegisterHandlerV1(malfeasance.HareEquivocation, hareMH) - malfeasanceHandler.RegisterHandlerV1(malfeasance.InvalidPostIndex, invalidPostMH) - malfeasanceHandler.RegisterHandlerV1(malfeasance.InvalidPrevATX, invalidPrevMH) + app.malfeasanceHandler.RegisterHandler(malfeasance.MultipleATXs, activationMH) + app.malfeasanceHandler.RegisterHandler(malfeasance.MultipleBallots, meshMH) + app.malfeasanceHandler.RegisterHandler(malfeasance.HareEquivocation, hareMH) + app.malfeasanceHandler.RegisterHandler(malfeasance.InvalidPostIndex, invalidPostMH) + app.malfeasanceHandler.RegisterHandler(malfeasance.InvalidPrevATX, invalidPrevMH) fetcher.SetValidators( fetch.ValidatorFunc( @@ -1199,7 +1200,7 @@ func (app *App) initServices(ctx context.Context) error { ), fetch.ValidatorFunc( pubsub.DropPeerOnSyncValidationReject( - malfeasanceHandler.HandleSyncedMalfeasanceProof, + app.malfeasanceHandler.HandleSyncedMalfeasanceProof, app.host, lg.Zap(), ), @@ -1260,7 +1261,7 @@ func (app *App) initServices(ctx context.Context) error { ) app.host.Register( pubsub.MalfeasanceProof, - pubsub.ChainGossipHandler(atxSyncHandler, malfeasanceHandler.HandleMalfeasanceProof), + pubsub.ChainGossipHandler(atxSyncHandler, app.malfeasanceHandler.HandleMalfeasanceProof), ) app.proposalBuilder = proposalBuilder @@ -1558,10 +1559,19 @@ func (app *App) grpcService(svc grpcserver.Service, lg log.Log) (grpcserver.Serv service := v2alpha1.NewRewardStreamService(app.db) app.grpcServices[svc] = service return service, nil + case v2alpha1.Malfeasance: + service := v2alpha1.NewMalfeasanceService(app.db, app.malfeasanceHandler) + app.grpcServices[svc] = service + return service, nil + case v2alpha1.MalfeasanceStream: + service := v2alpha1.NewMalfeasanceStreamService(app.db, app.malfeasanceHandler) + app.grpcServices[svc] = service + return service, nil case v2alpha1.Network: service := v2alpha1.NewNetworkService( app.clock.GenesisTime(), - app.Config) + app.Config, + ) app.grpcServices[svc] = service return service, nil case v2alpha1.Node: diff --git a/proposals/handler.go b/proposals/handler.go index c1bac83869..c5729dc357 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -490,7 +490,7 @@ func (h *Handler) checkBallotSyntacticValidity(ctx context.Context, b *types.Bal decoded, err := h.tortoise.DecodeBallot(b.ToTortoiseData()) if err != nil { return nil, fmt.Errorf( - "%w: failed to decode ballot id %s. %v", + "%w: failed to decode ballot id %s: %v", fetch.ErrIgnore, b.ID().AsHash32().ShortString(), err, diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 8a466fba35..9d3208f74c 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -808,7 +808,7 @@ func TestBallot_DecodeBeforeVotesConsistency(t *testing.T) { th.md.EXPECT().DecodeBallot(decoded.BallotTortoiseData).Return(decoded, expected) err := th.HandleSyncedBallot(context.Background(), b.ID().AsHash32(), peer, data) require.ErrorIs(t, err, fetch.ErrIgnore) - require.Contains(t, err.Error(), expected.Error()) + require.ErrorContains(t, err, expected.Error()) } func TestBallot_DecodedStoreFailure(t *testing.T) { diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index a94f3afa73..21b3ed0815 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -780,7 +780,8 @@ func IterateAtxsOps( _, err := db.Exec( fullQuery+builder.FilterFrom(operations), builder.BindingsFrom(operations), - decoder(fn)) + decoder(fn), + ) return err } diff --git a/sql/builder/builder.go b/sql/builder/builder.go index d7342ed727..4af899b3c7 100644 --- a/sql/builder/builder.go +++ b/sql/builder/builder.go @@ -11,13 +11,14 @@ import ( type token string const ( - Eq token = "=" - NotEq token = "!=" - Gt token = ">" - Gte token = ">=" - Lt token = "<" - Lte token = "<=" - In token = "in" + Eq token = "=" + NotEq token = "!=" + Gt token = ">" + Gte token = ">=" + Lt token = "<" + Lte token = "<=" + In token = "in" + IsNotNull token = "is not null" ) type operator string @@ -37,6 +38,7 @@ const ( Layer field = "layer" Address field = "address" Principal field = "principal" + Proof field = "proof" ) type modifier string @@ -135,22 +137,26 @@ func FilterFrom(operations Operations) string { } } queryBuilder.WriteString(" )") - } else { - if op.Token == In { - values, ok := op.Value.([][]byte) - if !ok { - panic("value for 'In' token must be a slice of []byte") - } - params := make([]string, len(values)) - for j := range values { - params[j] = fmt.Sprintf("?%d", bindIndex) - bindIndex++ - } - fmt.Fprintf(&queryBuilder, " %s%s %s (%s)", op.Prefix, op.Field, op.Token, strings.Join(params, ", ")) - } else { - fmt.Fprintf(&queryBuilder, " %s%s %s ?%d", op.Prefix, op.Field, op.Token, bindIndex) + continue + } + + switch op.Token { + case In: + values, ok := op.Value.([][]byte) + if !ok { + panic("value for 'In' token must be a slice of []byte") + } + params := make([]string, len(values)) + for j := range values { + params[j] = fmt.Sprintf("?%d", bindIndex) bindIndex++ } + fmt.Fprintf(&queryBuilder, " %s%s %s (%s)", op.Prefix, op.Field, op.Token, strings.Join(params, ", ")) + case IsNotNull: + fmt.Fprintf(&queryBuilder, " %s%s %s", op.Prefix, op.Field, op.Token) + default: + fmt.Fprintf(&queryBuilder, " %s%s %s ?%d", op.Prefix, op.Field, op.Token, bindIndex) + bindIndex++ } } @@ -191,6 +197,8 @@ func bindValue(stmt *sql.Statement, bindIndex int, value any) int { stmt.BindBytes(bindIndex, v) bindIndex++ } + case nil: + // do nothing default: panic(fmt.Sprintf("unexpected type %T", value)) } diff --git a/sql/builder/builder_test.go b/sql/builder/builder_test.go index 74988bb802..11a654ba8c 100644 --- a/sql/builder/builder_test.go +++ b/sql/builder/builder_test.go @@ -3,6 +3,8 @@ package builder import ( "testing" + "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/common/types" ) @@ -16,10 +18,7 @@ func TestFilterFrom_WithSingleFilter(t *testing.T) { expected := " where epoch = ?1" actual := FilterFrom(operations) - - if actual != expected { - t.Errorf("Expected '%s', but got '%s'", expected, actual) - } + require.Equal(t, expected, actual) } func TestFilterFrom_WithMultipleFilters(t *testing.T) { @@ -33,10 +32,7 @@ func TestFilterFrom_WithMultipleFilters(t *testing.T) { expected := " where epoch = ?1 and pubkey = ?2" actual := FilterFrom(operations) - - if actual != expected { - t.Errorf("Expected '%s', but got '%s'", expected, actual) - } + require.Equal(t, expected, actual) } func TestFilterFrom_WithGroupFilters(t *testing.T) { @@ -55,10 +51,7 @@ func TestFilterFrom_WithGroupFilters(t *testing.T) { expected := " where ( epoch = ?1 and pubkey = ?2 )" actual := FilterFrom(operations) - - if actual != expected { - t.Errorf("Expected '%s', but got '%s'", expected, actual) - } + require.Equal(t, expected, actual) } func TestFilterFrom_WithInToken(t *testing.T) { @@ -71,10 +64,7 @@ func TestFilterFrom_WithInToken(t *testing.T) { expected := " where epoch in (?1, ?2)" actual := FilterFrom(operations) - - if actual != expected { - t.Errorf("Expected '%s', but got '%s'", expected, actual) - } + require.Equal(t, expected, actual) } func TestFilterFrom_WithModifiers(t *testing.T) { @@ -91,8 +81,19 @@ func TestFilterFrom_WithModifiers(t *testing.T) { expected := " where epoch = ?1 order by epoch limit 10" actual := FilterFrom(operations) + require.Equal(t, expected, actual) +} - if actual != expected { - t.Errorf("Expected '%s', but got '%s'", expected, actual) +func TestFilterFrom_NotNull(t *testing.T) { + t.Parallel() + operations := Operations{ + Filter: []Op{ + {Field: Proof, Token: IsNotNull}, + {Field: Id, Token: Eq, Value: 1}, + }, } + + expected := " where proof is not null and id = ?1" + actual := FilterFrom(operations) + require.Equal(t, expected, actual) } diff --git a/sql/database_test.go b/sql/database_test.go index 0104c08813..4b528d704f 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -563,7 +563,7 @@ func TestSchemaDrift(t *testing.T) { WithLogger(logger), ) require.Error(t, err) - require.Contains(t, err.Error(), "newtbl") + require.ErrorContains(t, err, "newtbl") require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") db, err = Open("file:"+dbFile, diff --git a/sql/identities/identities.go b/sql/identities/identities.go index 257bec03ee..df34f50296 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/builder" ) // SetMalicious records identity as malicious. @@ -62,6 +63,26 @@ func LoadMalfeasanceBlob(_ context.Context, db sql.Executor, nodeID []byte, blob return err } +func IterateMaliciousOps( + db sql.Executor, + operations builder.Operations, + fn func(types.NodeID, []byte, time.Time) bool, +) error { + _, err := db.Exec( + "select pubkey, proof, received from identities"+builder.FilterFrom(operations), + builder.BindingsFrom(operations), + func(stmt *sql.Statement) bool { + var id types.NodeID + stmt.ColumnBytes(0, id[:]) + proof := make([]byte, stmt.ColumnLen(1)) + stmt.ColumnBytes(1, proof) + received := time.Unix(0, stmt.ColumnInt64(2)) + return fn(id, proof, received) + }, + ) + return err +} + // IterateMalicious invokes the specified callback for each malicious node ID. // It stops if the callback returns an error. func IterateMalicious( diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 555a9be270..44f3001673 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -11,6 +11,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -302,3 +303,96 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { require.Empty(t, set) }) } + +func Test_IterateMaliciousOps(t *testing.T) { + db := statesql.InMemory() + tt := []struct { + id types.NodeID + proof []byte + }{ + { + types.RandomNodeID(), + types.RandomBytes(11), + }, + { + types.RandomNodeID(), + types.RandomBytes(11), + }, + { + types.RandomNodeID(), + types.RandomBytes(11), + }, + } + + for _, tc := range tt { + err := identities.SetMalicious(db, tc.id, tc.proof, time.Now()) + require.NoError(t, err) + } + + var got []struct { + id types.NodeID + proof []byte + } + err := identities.IterateMaliciousOps(db, builder.Operations{}, + func(id types.NodeID, proof []byte, _ time.Time) bool { + got = append(got, struct { + id types.NodeID + proof []byte + }{id, proof}) + return true + }) + require.NoError(t, err) + require.ElementsMatch(t, tt, got) +} + +func Test_IterateMaliciousOpsWithFilter(t *testing.T) { + db := statesql.InMemory() + tt := []struct { + id types.NodeID + proof []byte + }{ + { + types.RandomNodeID(), + types.RandomBytes(11), + }, + { + types.RandomNodeID(), + nil, + }, + { + types.RandomNodeID(), + types.RandomBytes(11), + }, + } + + for _, tc := range tt { + err := identities.SetMalicious(db, tc.id, tc.proof, time.Now()) + require.NoError(t, err) + } + + var got []struct { + id types.NodeID + proof []byte + } + ops := builder.Operations{} + ops.Filter = append(ops.Filter, builder.Op{ + Field: builder.Smesher, + Token: builder.In, + Value: [][]byte{tt[0].id.Bytes(), tt[1].id.Bytes()}, // first two ids + }) + ops.Filter = append(ops.Filter, builder.Op{ + Field: builder.Proof, + Token: builder.IsNotNull, // only entries which have a proof + }) + + err := identities.IterateMaliciousOps(db, ops, func(id types.NodeID, proof []byte, _ time.Time) bool { + got = append(got, struct { + id types.NodeID + proof []byte + }{id, proof}) + return true + }) + require.NoError(t, err) + // only the first element should be in the result + require.ElementsMatch(t, tt[:1], got) +} diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 633e4a284a..f35e2175cb 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -283,7 +283,7 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan case <-ctx.Done(): return nil // TODO(ivan4th) this has to be randomized in a followup - // when sync will be schedulled in advance, in order to smooth out request rate across the network + // when sync will be scheduled in advance, in order to smooth out request rate across the network case <-s.clock.After(interval): } }