diff --git a/execute/report/roots.go b/execute/report/roots.go index d3550bc37..09bb11b70 100644 --- a/execute/report/roots.go +++ b/execute/report/roots.go @@ -16,7 +16,7 @@ func ConstructMerkleTree( lggr logger.Logger, ) (*merklemulti.Tree[[32]byte], error) { // Ensure we have the expected number of messages - numMsgs := int(report.SequenceNumberRange.End() - report.SequenceNumberRange.Start() + 1) + numMsgs := report.SequenceNumberRange.Length() if numMsgs != len(report.Messages) { return nil, fmt.Errorf( "malformed report %s, unexpected number of messages: expected %d, got %d", diff --git a/pkg/reader/ccip.go b/pkg/reader/ccip.go index b0993a989..e55323b03 100644 --- a/pkg/reader/ccip.go +++ b/pkg/reader/ccip.go @@ -265,7 +265,7 @@ func (r *ccipChainReader) ExecutedMessageRanges( query.LimitAndSort{ SortBy: []query.SortBy{query.NewSortBySequence(query.Asc)}, Limit: query.Limit{ - Count: uint64(seqNumRange.End() - seqNumRange.Start() + 1), + Count: uint64(seqNumRange.Length()), }, }, &dataTyp, @@ -332,7 +332,7 @@ func (r *ccipChainReader) MsgsBetweenSeqNums( query.NewSortBySequence(query.Asc), }, Limit: query.Limit{ - Count: uint64(seqNumRange.End() - seqNumRange.Start() + 1), + Count: uint64(seqNumRange.Length()), }, }, &SendRequestedEvent{}, diff --git a/pkg/types/ccipocr3/generic_types.go b/pkg/types/ccipocr3/generic_types.go index 4895e4e3c..ae8c8de62 100644 --- a/pkg/types/ccipocr3/generic_types.go +++ b/pkg/types/ccipocr3/generic_types.go @@ -38,6 +38,9 @@ func (s SeqNum) String() string { } func NewSeqNumRange(start, end SeqNum) SeqNumRange { + if end < start { + start, end = end, start + } return SeqNumRange{start, end} } @@ -65,7 +68,7 @@ func (s *SeqNumRange) SetEnd(v SeqNum) { func (s *SeqNumRange) Limit(n uint64) SeqNumRange { limitedRange := NewSeqNumRange(s.Start(), s.End()) - numElems := s.End() - s.Start() + 1 + numElems := s.Length() if numElems <= 0 { return limitedRange } @@ -96,6 +99,9 @@ func (s SeqNumRange) String() string { } func (s SeqNumRange) Length() int { + if s.End() < s.Start() { + s[0], s[1] = s[1], s[0] + } return int(s.End() - s.Start() + 1) } diff --git a/pkg/types/ccipocr3/generic_types_test.go b/pkg/types/ccipocr3/generic_types_test.go index 207a6016a..d82d3fd8e 100644 --- a/pkg/types/ccipocr3/generic_types_test.go +++ b/pkg/types/ccipocr3/generic_types_test.go @@ -34,6 +34,10 @@ func TestSeqNumRange(t *testing.T) { assert.Equal(t, "[1 -> 2]", NewSeqNumRange(1, 2).String()) assert.Equal(t, "[0 -> 0]", SeqNumRange{}.String()) }) + + t.Run("end before start", func(t *testing.T) { + assert.Equal(t, NewSeqNumRange(10, 20), NewSeqNumRange(20, 10)) + }) } func TestSeqNumRange_Overlap(t *testing.T) { @@ -118,10 +122,10 @@ func TestSeqNumRangeLimit(t *testing.T) { want: NewSeqNumRange(0, 0), }, { - name: "wrong range", + name: "wrong range is repaired", rng: NewSeqNumRange(20, 15), n: 3, - want: NewSeqNumRange(20, 15), + want: NewSeqNumRange(15, 17), }, }