forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_chunk_localizer.py
136 lines (105 loc) · 4.38 KB
/
test_chunk_localizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import pytest
from openhands.utils.chunk_localizer import (
Chunk,
create_chunks,
get_top_k_chunk_matches,
normalized_lcs,
)
def test_chunk_creation():
    """A Chunk built from text and a line range stores both and has no LCS score."""
    created = Chunk(text='test chunk', line_range=(1, 1))

    assert (created.text, created.line_range) == ('test chunk', (1, 1))
    # The score attribute is expected to be unset right after construction.
    assert created.normalized_lcs is None
def test_chunk_visualization(capsys):
    """visualize() renders every line as '<line number>|<line>' plus a newline.

    The ``capsys`` fixture is requested but never read by this test.
    """
    two_liner = Chunk(text='line1\nline2', line_range=(1, 2))
    rendered = two_liner.visualize()
    assert rendered == '1|line1\n2|line2\n'
def test_create_chunks_raw_string():
    """Five lines split with size=2 give two full chunks and one trailing line."""
    source = 'line1\nline2\nline3\nline4\nline5'
    produced = create_chunks(source, size=2)
    assert len(produced) == 3

    # (expected text, expected 1-based inclusive line range) per chunk, in order.
    expected = [
        ('line1\nline2', (1, 2)),
        ('line3\nline4', (3, 4)),
        ('line5', (5, 5)),
    ]
    for index, (want_text, want_range) in enumerate(expected):
        assert produced[index].text == want_text
        assert produced[index].line_range == want_range
def test_normalized_lcs():
    """'abcdef' against 'abcxyz' (common subsequence 'abc') is expected to score 0.5."""
    score = normalized_lcs('abcdef', 'abcxyz')
    assert score == 0.5
def test_get_top_k_chunk_matches():
    """The exact-matching line ranks first, followed by the next-best line.

    Scores are compared with ``pytest.approx`` so the test does not depend on
    the exact floating-point arithmetic used to compute the ratio.
    """
    text = 'chunk1\nchunk2\nchunk3\nchunk4'
    query = 'chunk2'
    matches = get_top_k_chunk_matches(text, query, k=2, max_chunk_size=1)
    assert len(matches) == 2

    best, runner_up = matches[0], matches[1]

    # The line identical to the query is the top hit with a perfect score.
    assert best.text == 'chunk2'
    assert best.line_range == (2, 2)
    assert best.normalized_lcs == pytest.approx(1.0)

    # 'chunk1' shares five of its six characters with the query.
    assert runner_up.text == 'chunk1'
    assert runner_up.line_range == (1, 1)
    assert runner_up.normalized_lcs == pytest.approx(5 / 6)

    assert best.normalized_lcs > runner_up.normalized_lcs
def test_create_chunks_with_empty_lines():
    """Empty lines count as lines: six lines with size=2 give three chunks."""
    source = 'line1\n\nline3\n\n\nline6'
    produced = create_chunks(source, size=2)
    assert len(produced) == 3

    # Each chunk pairs a line with its neighbouring empty line.
    expected = [
        ('line1\n', (1, 2)),
        ('line3\n', (3, 4)),
        ('\nline6', (5, 6)),
    ]
    for index, (want_text, want_range) in enumerate(expected):
        assert produced[index].text == want_text
        assert produced[index].line_range == want_range
def test_create_chunks_with_large_size():
    """A chunk size larger than the input yields one chunk holding all the text."""
    source = 'line1\nline2\nline3'
    produced = create_chunks(source, size=10)
    assert len(produced) == 1

    only_chunk = produced[0]
    assert only_chunk.text == source
    assert only_chunk.line_range == (1, 3)
def test_create_chunks_with_last_chunk_smaller():
    """Three lines with size=2 leave a final chunk that holds a single line."""
    source = 'line1\nline2\nline3'
    produced = create_chunks(source, size=2)
    assert len(produced) == 2

    full_chunk, short_chunk = produced[0], produced[1]
    assert full_chunk.text == 'line1\nline2'
    assert full_chunk.line_range == (1, 2)
    assert short_chunk.text == 'line3'
    assert short_chunk.line_range == (3, 3)
def test_normalized_lcs_edge_cases():
    """Empty operands score 0.0; a strict subsequence scores a fractional value."""
    # Whenever either side (or both) is empty the score is expected to be exactly 0.0.
    for chunk, edit_draft in (('', ''), ('a', ''), ('', 'a')):
        assert normalized_lcs(chunk, edit_draft) == 0.0

    # 'ace' is a subsequence of 'abcde'; compare with a tolerance instead of
    # exact float equality so the check does not hinge on rounding details.
    assert normalized_lcs('abcde', 'ace') == pytest.approx(0.6)
def test_get_top_k_chunk_matches_with_ties():
    """Equally scored lines are all returned, and duplicate text is not excluded.

    The input contains 'chunk1' twice, so the three results cover the set
    {'chunk1', 'chunk2', 'chunk3'} only if they are the first three lines.
    """
    text = 'chunk1\nchunk2\nchunk3\nchunk1'
    query = 'chunk'
    matches = get_top_k_chunk_matches(text, query, k=3, max_chunk_size=1)
    assert len(matches) == 3

    # Every line contains the five-character query plus one extra character.
    # Use a tolerance rather than exact float equality for the ratio.
    expected_score = pytest.approx(5 / 6)
    assert all(match.normalized_lcs == expected_score for match in matches)
    assert {match.text for match in matches} == {'chunk1', 'chunk2', 'chunk3'}
def test_get_top_k_chunk_matches_with_large_k():
    """Asking for more matches than there are chunks returns every chunk."""
    found = get_top_k_chunk_matches(
        'chunk1\nchunk2\nchunk3',
        'chunk',
        k=10,
        max_chunk_size=1,
    )
    # Only three one-line chunks exist, so k=10 is capped at three results.
    assert len(found) == 3
@pytest.mark.parametrize('chunk_size', [1, 2, 3, 4])
def test_create_chunks_different_sizes(chunk_size):
    """For each chunk size the chunk count is ceil(4 / size) and no line is lost."""
    line_count = 4
    produced = create_chunks('line1\nline2\nline3\nline4', size=chunk_size)

    # Ceiling division written with floor division on negated operands.
    expected_chunks = -(-line_count // chunk_size)
    assert len(produced) == expected_chunks

    # Splitting every chunk back into lines must account for all four lines.
    lines_per_chunk = [len(piece.text.split('\n')) for piece in produced]
    assert sum(lines_per_chunk) == line_count
def test_chunk_visualization_with_special_characters():
    """Tabs and carriage returns inside lines are passed through by visualize()."""
    tricky = Chunk(text='line1\nline2\t\nline3\r', line_range=(1, 3))
    rendered = tricky.visualize()
    # The '\t' and '\r' stay in place; only the number prefix and newline are added.
    assert rendered == '1|line1\n2|line2\t\n3|line3\r\n'
def test_normalized_lcs_with_unicode():
    """A chunk containing non-ASCII text partially matches an ASCII draft.

    Both strings start with 'Hello, ' and end with '!' but differ in between,
    and the test expects a score strictly between 0 and 1.
    """
    # The CJK characters U+4E16 U+754C are written as escapes so this test
    # data cannot be corrupted when the file is read with the wrong encoding.
    chunk = 'Hello, \u4e16\u754c!'
    edit_draft = 'Hello, world!'
    assert 0 < normalized_lcs(chunk, edit_draft) < 1
def test_get_top_k_chunk_matches_with_overlapping_chunks():
    """A query straddling two chunks matches both of them with equal scores."""
    source = 'chunk1\nchunk2\nchunk3\nchunk4'
    # The query covers the last line of the first chunk and the first line of
    # the second chunk when the text is split two lines at a time.
    straddling_query = 'chunk2\nchunk3'
    found = get_top_k_chunk_matches(source, straddling_query, k=2, max_chunk_size=2)
    assert len(found) == 2

    first, second = found[0], found[1]
    assert first.text == 'chunk1\nchunk2'
    assert first.line_range == (1, 2)
    assert second.text == 'chunk3\nchunk4'
    assert second.line_range == (3, 4)
    assert first.normalized_lcs == second.normalized_lcs