diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 4caa98654e9..42c6ddde9f4 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -608,7 +608,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( // validate all niposts var smesherCommitment *types.ATXID - for _, niposts := range atx.NIPosts { + for idx, niposts := range atx.NIPosts { for _, post := range niposts.Posts { id := equivocationSet[post.MarriageIndex] var commitment types.ATXID @@ -632,21 +632,14 @@ func (h *HandlerV2) syntacticallyValidateDeps( id, commitment, wire.PostFromWireV1(&post.Post), - niposts.Challenge[:], + niposts.Challenge.Bytes(), post.NumUnits, PostSubset([]byte(h.local)), ) invalidIdx := &verifying.ErrInvalidIndex{} if errors.As(err, invalidIdx) { - h.logger.Debug( - "ATX with invalid post index", - zap.Stringer("id", atx.ID()), - zap.Int("index", invalidIdx.Index), - ) - // TODO(mafa): publish solo or merged invalid post malfeasance proof - var proof wire.Proof - if err := h.malPublisher.Publish(ctx, id, proof); err != nil { - return nil, fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) + if err := h.publishInvalidPostProof(ctx, atx, id, idx, uint32(invalidIdx.Index)); err != nil { + return nil, fmt.Errorf("publishing invalid post proof: %w", err) } } if err != nil { @@ -674,6 +667,49 @@ func (h *HandlerV2) syntacticallyValidateDeps( return &result, nil } +func (h *HandlerV2) publishInvalidPostProof( + ctx context.Context, + atx *wire.ActivationTxV2, + nodeID types.NodeID, + nipostIndex int, + invalidPostIndex uint32, +) error { + h.logger.Debug( + "ATX with invalid post index", + zap.Stringer("id", atx.ID()), + zap.Uint32("index", invalidPostIndex), + ) + initialAtx := &wire.ActivationTxV2{} + if atx.Initial != nil { + initialAtx = atx + } else { + initialID, err := atxs.GetFirstIDByNodeID(h.cdb, nodeID) + if err != nil { + return fmt.Errorf("fetch initial ATX for ID %s: %w", nodeID.ShortString(), err) + } + + var initialAtxBytes sql.Blob + v, err := atxs.LoadBlob(ctx, h.cdb, initialID.Bytes(), &initialAtxBytes) + if err != nil { + return fmt.Errorf("fetch initial ATX blob for ID %s: %w", nodeID.ShortString(), err) + } + if v != types.AtxV2 { + // TODO(mafa): this needs to be fixed + return fmt.Errorf("initial ATX is not V2 for ID %s", nodeID.ShortString()) + } + codec.MustDecode(initialAtxBytes.Bytes, initialAtx) + } + + proof, err := wire.NewInvalidPostProof(h.cdb, atx, initialAtx, nodeID, nipostIndex, invalidPostIndex) + if err != nil { + return fmt.Errorf("creating invalid post proof: %w", err) + } + if err := h.malPublisher.Publish(ctx, nodeID, proof); err != nil { + return fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) + } + return nil +} + func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { malicious, err := malfeasance.IsMalicious(tx, atx.SmesherID) if err != nil { @@ -808,6 +844,7 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, at zap.Stringer("smesher_id", atx.SmesherID), ) + // TODO(mafa): finish proof var proof wire.Proof return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 03dec1db2cb..1eb6b550b46 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1523,6 +1523,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { require.ErrorContains(t, err, "post failure") }) t.Run("invalid PoST index - generates a malfeasance proof", func(t *testing.T) { + // TODO(mafa): add such a test for solo and merged ATXs atxHandler := newV2TestHandler(t, golden) atx := newInitialATXv2(t, golden) @@ -1531,9 +1532,9 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atxHandler.mValidator.EXPECT().PoetMembership(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) atxHandler.mValidator.EXPECT(). PostV2( - gomock.Any(), - sig.NodeID(), - golden, + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), atx.NIPosts[0].Challenge.Bytes(), atx.TotalNumUnits(), @@ -1541,8 +1542,36 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { ). Return(verifying.ErrInvalidIndex{Index: 7}) - // TODO(mafa): update assertion to expect a malfeasance proof that can be verified - atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), sig.NodeID(), gomock.Any()) + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + 7, + ).Return(errors.New("invalid post index")) + + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.Cond(func(data wire.Proof) bool { + _, ok := data.(*wire.ProofInvalidPost) + return ok + }), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPost) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nId) + return nil + }) _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) vErr := &verifying.ErrInvalidIndex{} require.ErrorAs(t, err, vErr) @@ -1961,6 +1990,9 @@ func newInitialATXv2(tb testing.TB, golden types.ATXID) *wire.ActivationTxV2 { Initial: &wire.InitialAtxPartsV2{CommitmentATX: golden}, NIPosts: []wire.NIPostV2{ { + Membership: wire.MerkleProofV2{ + Nodes: make([]types.Hash32, 32), + }, Challenge: types.RandomHash(), Posts: []wire.SubPostV2{ { diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index 3b809c2244a..08a1390f5d5 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -464,9 +464,15 @@ func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { binary.LittleEndian.PutUint32(marriageIndex[:], sp.MarriageIndex) tree.AddLeaf(marriageIndex.Bytes()) - if int(sp.PrevATXIndex) < len(prevATXs) { - // if prevATXIndex is out of range, it will be detected by syntactical validation + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + tree.AddLeaf(types.EmptyATXID.Bytes()) + case int(sp.PrevATXIndex) < len(prevATXs): tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) + default: + // prevATXIndex is out of range, don't fail ATXID generation + // will be detected by syntactical validation + tree.AddLeaf(types.EmptyATXID.Bytes()) } var leafIndex types.Hash32