Skip to content

Commit

Permalink
Update embedding_fn signature to newest chroma db's (#969)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidmohanty11 authored Nov 21, 2023
1 parent 9fcf213 commit 85f3ac4
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions embedchain/embedder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
from embedchain.config.embedder.base import BaseEmbedderConfig

try:
from chromadb.api.types import Documents, Embeddings
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
except RuntimeError:
from embedchain.utils import use_pysqlite3

use_pysqlite3()
from chromadb.api.types import Documents, Embeddings
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction


class EmbeddingFunc(EmbeddingFunction):
def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
self.embedding_fn = embedding_fn

def __call__(self, input: Embeddable) -> Embeddings:
return self.embedding_fn(input)


class BaseEmbedder:
Expand Down Expand Up @@ -66,7 +74,4 @@ def _langchain_default_concept(embeddings: Any):
:rtype: Callable
"""

def embed_function(texts: Documents) -> Embeddings:
return embeddings.embed_documents(texts)

return embed_function
return EmbeddingFunc(embeddings.embed_documents)

0 comments on commit 85f3ac4

Please sign in to comment.