forked from SciPhi-AI/R2R
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembedding.py
75 lines (61 loc) · 2.25 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Abstract base class for embedding pipelines.
"""
import logging
import uuid
from abc import abstractmethod
from typing import Any, Optional
from ..providers.embedding import EmbeddingProvider
from ..providers.logging import LoggingDatabaseConnection
from ..providers.vector_db import VectorDBProvider, VectorEntry
from .pipeline import Pipeline
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class EmbeddingPipeline(Pipeline):
    """Abstract base for pipelines that turn documents into stored embeddings.

    Subclasses implement the individual stages (text extraction, text
    transformation, chunking, chunk transformation, embedding, storage);
    `run` drives those stages in a fixed order for every input document.
    """

    def __init__(
        self,
        embedding_model: str,
        embeddings_provider: EmbeddingProvider,
        db: VectorDBProvider,
        logging_provider: Optional[LoggingDatabaseConnection] = None,
        **kwargs,
    ):
        """Keep references to the collaborators and initialise the base class.

        Args:
            embedding_model: Embedding model identifier, stored as given.
            embeddings_provider: Provider object, stored as given.
            db: Vector database provider, stored as given.
            logging_provider: Optional logging connection forwarded to the
                base `Pipeline` constructor.
            **kwargs: Extra keyword arguments forwarded to the base
                `Pipeline` constructor.
        """
        self.embedding_model = embedding_model
        self.embeddings_provider = embeddings_provider
        self.db = db
        super().__init__(logging_provider=logging_provider, **kwargs)

    def initialize_pipeline(self) -> None:
        """Reset `pipeline_run_info` with a fresh run id and type "embedding"."""
        self.pipeline_run_info = {"run_id": uuid.uuid4(), "type": "embedding"}

    @abstractmethod
    def extract_text(self, document: Any) -> str:
        """First stage: called by `run` with each input document."""

    @abstractmethod
    def transform_text(self, text: str) -> str:
        """Second stage: called by `run` with the result of `extract_text`."""

    @abstractmethod
    def chunk_text(self, text: str) -> list[str]:
        """Third stage: called by `run` with the result of `transform_text`."""

    @abstractmethod
    def transform_chunks(
        self, chunks: list[Any], metadatas: list[dict]
    ) -> list[Any]:
        """Fourth stage: called by `run` with the chunks and a metadata list."""

    @abstractmethod
    def embed_chunks(self, chunks: list[Any]) -> list[list[float]]:
        """Fifth stage: called by `run` with the result of `transform_chunks`."""

    @abstractmethod
    def store_chunks(self, chunks: list[VectorEntry], **kwargs) -> None:
        """Final stage: called by `run` with the result of `embed_chunks`."""

    def run(self, document: Any, **kwargs):
        """Push one document, or a list of documents, through every stage.

        Args:
            document: A single document or a list of documents. Anything that
                is not a `list` is treated as one document.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self.initialize_pipeline()
        run_info = self.pipeline_run_info
        # Report the concrete subclass instead of a hard-coded class name,
        # and let the logger format lazily so disabled debug output is free.
        logger.debug(
            "Running the `%s` with id=%s.",
            type(self).__name__,
            run_info["run_id"],
        )
        logger.debug("Pipeline run type: %s", run_info["type"])

        documents = document if isinstance(document, list) else [document]
        # A separate loop variable keeps the `document` argument intact.
        for doc in documents:
            text = self.extract_text(doc)
            transformed_text = self.transform_text(text)
            chunks = self.chunk_text(transformed_text)
            # NOTE(review): an empty metadata list is always passed here.
            transformed_chunks = self.transform_chunks(chunks, [])
            embeddings = self.embed_chunks(transformed_chunks)
            # NOTE(review): `store_chunks` is annotated for `VectorEntry`
            # items but receives the raw `embed_chunks` output - confirm
            # against the concrete subclasses.
            self.store_chunks(embeddings)