-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy path_embed.py
203 lines (187 loc) · 10.3 KB
/
_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""String embedder."""
from functools import partial
from typing import Literal
import numpy as np
from litellm import embedding
from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
from tqdm.auto import tqdm, trange
from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector
def _embed_sentences_with_late_chunking(  # noqa: PLR0915
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with late chunking.

    The sentences are grouped into consecutive segments. Each segment has a preamble (earlier
    sentences included only as context) followed by content (the sentences whose embeddings are
    kept). Every segment is embedded once at the token level, and each content sentence's
    embedding is the mean of the token embeddings attributed to it.

    Args:
        sentences: The document's sentences in order. They are concatenated without a
            separator before embedding, so each sentence should carry its own whitespace.
        config: The RAGLite config. Its embedder must be a llama-cpp-python model, because
            token-level embeddings are required.

    Returns:
        A float16 matrix with one row per sentence, unit-normalised when
        `config.embedder_normalize` is set. An empty `sentences` list yields a matrix with
        zero rows.
    """

    def _count_tokens(
        sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
    ) -> list[int]:
        """Estimate the token count of each sentence with a single tokenizer call.

        Every sentence after the first is also credited with the sentinel token that precedes
        it, so the counts are approximate.
        """
        # Join the sentences with the sentinel character and tokenise the result.
        sentences_tokens = np.asarray(
            embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
        )
        # Map all sentinel token variants to the first one.
        for sentinel_token in sentinel_tokens[1:]:
            sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
        # Count how many tokens there are in between sentinel tokens to recover the token counts.
        sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
        num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
        assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
        num_tokens_list: list[int] = num_tokens.tolist()
        return num_tokens_list

    def _create_segment(
        content_start_index: int,
        max_tokens_preamble: int,
        max_tokens_content: int,
        num_tokens: IntVector,
    ) -> tuple[int, int]:
        """Return `(segment_start_index, segment_end_index)` for content at `content_start_index`.

        The returned segment always contains at least one content sentence, so a caller that
        advances `content_start_index` to `segment_end_index` is guaranteed to make progress.
        """
        # Compute the segment sentence start index so that the segment preamble has no more than
        # max_tokens_preamble tokens between [segment_start_index, content_start_index).
        cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
        offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
        segment_start_index = content_start_index - int(offset_preamble)
        # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
        preamble_tokens = int(np.sum(num_tokens[segment_start_index:content_start_index]))
        content_budget = max_tokens_content + (max_tokens_preamble - preamble_tokens)
        # Compute the segment sentence end index so that the segment content has no more than
        # content_budget tokens between [content_start_index, segment_end_index).
        cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
        offset_segment = int(np.searchsorted(cumsum_forwards, content_budget, side="right"))
        if offset_segment == 0:
            # The first content sentence alone exceeds the content budget. Returning an empty
            # content range here would stall the caller's loop forever. Instead, drop the
            # preamble so the content gets the entire token budget, and always take at least one
            # sentence. Any overflow that remains is truncated by the embedder.
            segment_start_index = content_start_index
            full_budget = max_tokens_preamble + max_tokens_content
            offset_segment = max(
                1, int(np.searchsorted(cumsum_forwards, full_budget, side="right"))
            )
        segment_end_index = content_start_index + offset_segment
        return segment_start_index, segment_end_index

    # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
    # support outputting token-level embeddings.
    config = config or RAGLiteConfig()
    assert config.embedder.startswith("llama-cpp-python")
    embedder = LlamaCppPythonLLM.llm(
        config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
    )
    # Without this guard, an empty document would fail below when indexing the (empty) list
    # of segments, so return a zero-row matrix of the embedder's width instead.
    if not sentences:
        return np.empty((0, embedder.n_embd()), dtype=np.float16)
    n_ctx = embedder.n_ctx()
    n_batch = embedder.n_batch
    # Identify the tokens corresponding to a sentinel character. The test string places the
    # sentinel in several contexts (between letters, between spaces, after a newline) so that
    # each tokenisation variant of the sentinel is collected.
    sentinel_char = "⊕"
    sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
    sentinel_tokens = [
        token
        for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
        if sentinel_char in embedder.detokenize([token]).decode()
    ]
    assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
    # Compute the number of tokens per sentence. We use a method based on a sentinel token to
    # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
    # (presumably to load the tokenizer) [1]. Sentences are tokenised in batches that are
    # flushed once their character count exceeds n_ctx // 2.
    # TODO: Make token counting faster and more robust once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
    num_tokens_list: list[int] = []
    sentence_batch: list[str] = []
    sentence_batch_len = 0
    for i, sentence in enumerate(sentences):
        sentence_batch.append(sentence)
        sentence_batch_len += len(sentence)
        if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
            num_tokens_list.extend(
                _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
            )
            sentence_batch, sentence_batch_len = [], 0
    num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
    # Compute the maximum number of tokens for each segment's preamble and content.
    # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
    # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch. The extra
    # margin of 16 tokens is presumably headroom for special tokens — TODO confirm.
    # TODO: Improve the context window size once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
    max_tokens = min(n_ctx, n_batch) - 16
    max_tokens_preamble = round(0.382 * max_tokens)  # Golden ratio.
    max_tokens_content = max_tokens - max_tokens_preamble
    # Compute a list of (segment_start, content_start, segment_end) sentence index triples.
    # _create_segment always returns a non-empty content range, so this loop terminates.
    segments: list[tuple[int, int, int]] = []
    content_start_index = 0
    while content_start_index < len(sentences):
        segment_start_index, segment_end_index = _create_segment(
            content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
        )
        segments.append((segment_start_index, content_start_index, segment_end_index))
        content_start_index = segment_end_index
    # Embed the segments and apply late chunking. Show a progress bar for larger documents.
    sentence_embeddings_list: list[FloatMatrix] = []
    show_progress = len(segments) > 1 or segments[0][2] > 128  # noqa: PLR2004
    segment_iterable = (
        tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
        if show_progress
        else segments
    )
    for segment_start_index, content_start_index, segment_end_index in segment_iterable:
        # Get the token embeddings of the entire segment, including preamble and content.
        segment_sentences = sentences[segment_start_index:segment_end_index]
        segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
        # Split the token embeddings into one matrix per sentence, in proportion to each
        # sentence's estimated token count. The cumulative boundaries are rounded, rather than
        # each sentence's size, so every split point stays within [0, len(segment_embedding)].
        # This prevents rounding drift from leaving trailing sentences with no token embeddings
        # and therefore a NaN mean.
        segment_tokens = num_tokens[segment_start_index:segment_end_index]
        split_points = np.round(
            len(segment_embedding) * (np.cumsum(segment_tokens)[:-1] / np.sum(segment_tokens))
        ).astype(np.intp)
        sentence_matrices = np.split(segment_embedding, split_points)
        # Compute the content sentence embeddings by averaging their token embeddings; the
        # preamble sentences' matrices are discarded.
        content_sentence_embeddings = [
            np.mean(sentence_matrix, axis=0, keepdims=True)
            for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
        ]
        sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
    sentence_embeddings = np.vstack(sentence_embeddings_list)
    # Normalise the sentence embeddings to unit norm and cast to half precision.
    if config.embedder_normalize:
        sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
    sentence_embeddings = sentence_embeddings.astype(np.float16)
    return sentence_embeddings
def _embed_sentences_with_windowing(
sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
"""Embed a document's sentences with windowing."""
def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
# Embed the batch of strings.
if config.embedder.startswith("llama-cpp-python"):
# LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
# Additionally, we explicitly manually pool the token embeddings to obtain sentence
# embeddings because token embeddings are universally supported, while sequence
# embeddings are only supported by some models.
embedder = LlamaCppPythonLLM.llm(
config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
)
embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
else:
# Use LiteLLM's API to embed the batch of strings.
response = embedding(config.embedder, string_batch)
embeddings = np.asarray([item["embedding"] for item in response["data"]])
# Normalise the embeddings to unit norm and cast to half precision.
if config.embedder_normalize:
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings.astype(np.float16)
return embeddings
# Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
config = config or RAGLiteConfig()
sentence_windows = [
"".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
for i in range(len(sentences))
]
# Embed the sentence windows in batches.
batch_size = 64
batch_range = (
partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
if len(sentence_windows) > batch_size
else range
)
batch_embeddings = [
_embed_string_batch(sentence_windows[i : i + batch_size], config=config)
for i in batch_range(0, len(sentence_windows), batch_size) # type: ignore[operator]
]
sentence_embeddings = np.vstack(batch_embeddings)
return sentence_embeddings
def sentence_embedding_type(
    *,
    config: RAGLiteConfig | None = None,
) -> Literal["late_chunking", "windowing"]:
    """Return the sentence embedding strategy used for the configured embedder.

    llama-cpp-python embedders can output token-level embeddings and therefore use late
    chunking; every other embedder uses windowing.
    """
    resolved_config = config if config else RAGLiteConfig()
    if resolved_config.embedder.startswith("llama-cpp-python"):
        return "late_chunking"
    return "windowing"
def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
    """Embed a document's sentences as a NumPy matrix with one row per sentence.

    The strategy (late chunking or windowing) is selected by `sentence_embedding_type`.
    """
    config = config or RAGLiteConfig()
    embed_fn = (
        _embed_sentences_with_late_chunking
        if sentence_embedding_type(config=config) == "late_chunking"
        else _embed_sentences_with_windowing
    )
    return embed_fn(sentences, config=config)