diff --git a/embedchain/apps/app.py b/embedchain/apps/app.py index 03dd3027fd..56d1353ead 100644 --- a/embedchain/apps/app.py +++ b/embedchain/apps/app.py @@ -10,7 +10,7 @@ from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.openai import OpenAIEmbedder from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm from embedchain.llm.openai import OpenAILlm from embedchain.utils import validate_yaml_config diff --git a/embedchain/bots/base.py b/embedchain/bots/base.py index a804b71b0a..384ca47e39 100644 --- a/embedchain/bots/base.py +++ b/embedchain/bots/base.py @@ -3,8 +3,8 @@ from embedchain import Pipeline as App from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig from embedchain.embedder.openai import OpenAIEmbedder -from embedchain.helper.json_serializable import (JSONSerializable, - register_deserializable) +from embedchain.helpers.json_serializable import (JSONSerializable, + register_deserializable) from embedchain.llm.openai import OpenAILlm from embedchain.vectordb.chroma import ChromaDB diff --git a/embedchain/bots/discord.py b/embedchain/bots/discord.py index 2d2d482e45..adbf3b7c74 100644 --- a/embedchain/bots/discord.py +++ b/embedchain/bots/discord.py @@ -2,7 +2,7 @@ import logging import os -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .base import BaseBot diff --git a/embedchain/bots/poe.py b/embedchain/bots/poe.py index ebd71824eb..762d090b15 100644 --- a/embedchain/bots/poe.py +++ b/embedchain/bots/poe.py @@ -3,7 +3,7 @@ import os from typing import List, Optional -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .base import BaseBot diff --git a/embedchain/bots/slack.py b/embedchain/bots/slack.py index 21834d8d48..43e39f2614 100644 --- a/embedchain/bots/slack.py +++ b/embedchain/bots/slack.py @@ -5,7 +5,7 @@ import sys from embedchain import App -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .base import BaseBot diff --git a/embedchain/bots/whatsapp.py b/embedchain/bots/whatsapp.py index c9193f6f6f..5106d40dca 100644 --- a/embedchain/bots/whatsapp.py +++ b/embedchain/bots/whatsapp.py @@ -4,7 +4,7 @@ import signal import sys -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .base import BaseBot diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index f7b5cb09f9..7130d30bf8 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -1,6 +1,6 @@ import hashlib -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable from embedchain.models.data_type import DataType diff --git a/embedchain/chunkers/common_chunker.py b/embedchain/chunkers/common_chunker.py index 7d607357fa..1527e33966 100644 --- a/embedchain/chunkers/common_chunker.py +++ b/embedchain/chunkers/common_chunker.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/discourse.py b/embedchain/chunkers/discourse.py index f78c616e1c..14898bf013 100644 --- a/embedchain/chunkers/discourse.py +++ b/embedchain/chunkers/discourse.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/docs_site.py b/embedchain/chunkers/docs_site.py index 1b7c00d9c3..d51dc8ee2e 100644 --- a/embedchain/chunkers/docs_site.py +++ b/embedchain/chunkers/docs_site.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/docx_file.py b/embedchain/chunkers/docx_file.py index 77fee55cb4..1452349e81 100644 --- a/embedchain/chunkers/docx_file.py +++ b/embedchain/chunkers/docx_file.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/gmail.py b/embedchain/chunkers/gmail.py index 59e87e7ec3..6b804f5461 100644 --- a/embedchain/chunkers/gmail.py +++ b/embedchain/chunkers/gmail.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/json.py b/embedchain/chunkers/json.py index 4eeee7ebe5..ebc5254195 100644 --- a/embedchain/chunkers/json.py +++ b/embedchain/chunkers/json.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/mdx.py b/embedchain/chunkers/mdx.py index 05225fac71..1c277dda7b 100644 --- a/embedchain/chunkers/mdx.py +++ b/embedchain/chunkers/mdx.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/mysql.py b/embedchain/chunkers/mysql.py index 7a2ec7fc0c..2b1c11acef 100644 --- a/embedchain/chunkers/mysql.py +++ b/embedchain/chunkers/mysql.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/notion.py b/embedchain/chunkers/notion.py index e0e2569cfe..190d59b57b 100644 --- a/embedchain/chunkers/notion.py +++ b/embedchain/chunkers/notion.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/pdf_file.py b/embedchain/chunkers/pdf_file.py index eb2355a4e1..56bae064ee 100644 --- a/embedchain/chunkers/pdf_file.py +++ b/embedchain/chunkers/pdf_file.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/postgres.py b/embedchain/chunkers/postgres.py index 168b6fcd52..7c6859bd01 100644 --- a/embedchain/chunkers/postgres.py +++ b/embedchain/chunkers/postgres.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/qna_pair.py b/embedchain/chunkers/qna_pair.py index 30f2b52f90..c0d8277b11 100644 --- a/embedchain/chunkers/qna_pair.py +++ b/embedchain/chunkers/qna_pair.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/sitemap.py b/embedchain/chunkers/sitemap.py index 6405001131..64e773742d 100644 --- a/embedchain/chunkers/sitemap.py +++ b/embedchain/chunkers/sitemap.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/slack.py b/embedchain/chunkers/slack.py index 93453b3b3a..595682bebe 100644 --- a/embedchain/chunkers/slack.py +++ b/embedchain/chunkers/slack.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/substack.py b/embedchain/chunkers/substack.py index 85f04fa622..92cacd6cb0 100644 --- a/embedchain/chunkers/substack.py +++ b/embedchain/chunkers/substack.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/text.py b/embedchain/chunkers/text.py index 0c36c863bf..f33d60c46a 100644 --- a/embedchain/chunkers/text.py +++ b/embedchain/chunkers/text.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/unstructured_file.py b/embedchain/chunkers/unstructured_file.py index ab0322b028..d55f23ef0a 100644 --- a/embedchain/chunkers/unstructured_file.py +++ b/embedchain/chunkers/unstructured_file.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/web_page.py b/embedchain/chunkers/web_page.py index b3da9ea3f8..253b2b4132 100644 --- a/embedchain/chunkers/web_page.py +++ b/embedchain/chunkers/web_page.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/xml.py b/embedchain/chunkers/xml.py index cae519ab95..c1bab0a77a 100644 --- a/embedchain/chunkers/xml.py +++ b/embedchain/chunkers/xml.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/chunkers/youtube_video.py b/embedchain/chunkers/youtube_video.py index d2ca025c79..bde0a8f781 100644 --- a/embedchain/chunkers/youtube_video.py +++ b/embedchain/chunkers/youtube_video.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.add_config import ChunkerConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/add_config.py b/embedchain/config/add_config.py index 16e99d6877..6695511849 100644 --- a/embedchain/config/add_config.py +++ b/embedchain/config/add_config.py @@ -3,7 +3,7 @@ from typing import Callable, Optional from embedchain.config.base_config import BaseConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/apps/app_config.py b/embedchain/config/apps/app_config.py index f0efaff9f1..4ab90b022c 100644 --- a/embedchain/config/apps/app_config.py +++ b/embedchain/config/apps/app_config.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .base_app_config import BaseAppConfig diff --git a/embedchain/config/apps/base_app_config.py b/embedchain/config/apps/base_app_config.py index d12b1c6309..f3a864700e 100644 --- a/embedchain/config/apps/base_app_config.py +++ b/embedchain/config/apps/base_app_config.py @@ -2,7 +2,7 @@ from typing import Optional from embedchain.config.base_config import BaseConfig -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/config/base_config.py b/embedchain/config/base_config.py index b02801f470..ff672f19b6 100644 --- a/embedchain/config/base_config.py +++ b/embedchain/config/base_config.py @@ -1,6 +1,6 @@ from typing import Any, Dict -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable class BaseConfig(JSONSerializable): diff --git a/embedchain/config/embedder/base.py b/embedchain/config/embedder/base.py index 8e8501e744..9227e07912 100644 --- a/embedchain/config/embedder/base.py +++ b/embedchain/config/embedder/base.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py index a98c1b6d58..6dbfdb90fd 100644 --- a/embedchain/config/llm/base.py +++ b/embedchain/config/llm/base.py @@ -1,9 +1,9 @@ import re from string import Template -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from embedchain.config.base_config import BaseConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable DEFAULT_PROMPT = """ Use the following pieces of context to answer the query at the end. @@ -68,6 +68,7 @@ def __init__( system_prompt: Optional[str] = None, where: Dict[str, Any] = None, query_type: Optional[str] = None, + callbacks: Optional[List] = None, ): """ Initializes a configuration class instance for the LLM. @@ -98,6 +99,8 @@ def __init__( :type system_prompt: Optional[str], optional :param where: A dictionary of key-value pairs to filter the database results., defaults to None :type where: Dict[str, Any], optional + :param callbacks: Langchain callback functions to use, defaults to None + :type callbacks: Optional[List], optional :raises ValueError: If the template is not valid as template should contain $context and $query (and optionally $history) :raises ValueError: Stream is not boolean @@ -113,6 +116,7 @@ def __init__( self.deployment_name = deployment_name self.system_prompt = system_prompt self.query_type = query_type + self.callbacks = callbacks if type(template) is str: template = Template(template) diff --git a/embedchain/config/pipeline_config.py b/embedchain/config/pipeline_config.py index e46456a1b5..8bfd4b4eaa 100644 --- a/embedchain/config/pipeline_config.py +++ b/embedchain/config/pipeline_config.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from .apps.base_app_config import BaseAppConfig diff --git a/embedchain/config/vectordb/chroma.py b/embedchain/config/vectordb/chroma.py index 38bf092134..d25de1c342 100644 --- a/embedchain/config/vectordb/chroma.py +++ b/embedchain/config/vectordb/chroma.py @@ -1,7 +1,7 @@ from typing import Optional from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/elasticsearch.py b/embedchain/config/vectordb/elasticsearch.py index 75498bd245..77d54a16c1 100644 --- a/embedchain/config/vectordb/elasticsearch.py +++ b/embedchain/config/vectordb/elasticsearch.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Union from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/opensearch.py b/embedchain/config/vectordb/opensearch.py index 85d517b2da..d8dc9a109b 100644 --- a/embedchain/config/vectordb/opensearch.py +++ b/embedchain/config/vectordb/opensearch.py @@ -1,7 +1,7 @@ from typing import Dict, Optional, Tuple from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index 7bd462ae46..e9165fdca5 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vectordb/qdrant.py index 4468c7b24b..9802212e8c 100644 --- a/embedchain/config/vectordb/qdrant.py +++ b/embedchain/config/vectordb/qdrant.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/weaviate.py b/embedchain/config/vectordb/weaviate.py index 4035877b57..2db24134b0 100644 --- a/embedchain/config/vectordb/weaviate.py +++ b/embedchain/config/vectordb/weaviate.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/config/vectordb/zilliz.py b/embedchain/config/vectordb/zilliz.py index fbd6ec0256..ba91922c02 100644 --- a/embedchain/config/vectordb/zilliz.py +++ b/embedchain/config/vectordb/zilliz.py @@ -2,7 +2,7 @@ from typing import Optional from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable @register_deserializable diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 1c218d8b8e..368e432598 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -4,7 +4,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig from embedchain.config.add_config import ChunkerConfig, LoaderConfig -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable from embedchain.loaders.base_loader import BaseLoader from embedchain.models.data_type import DataType diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 88c098a6d0..dfd1627e2a 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -13,7 +13,7 @@ from embedchain.constants import SQLITE_PATH from embedchain.data_formatter import DataFormatter from embedchain.embedder.base import BaseEmbedder -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader from embedchain.models.data_type import (DataType, DirectDataType, diff --git a/embedchain/embedder/base.py b/embedchain/embedder/base.py index 50ed475b94..14941b2f18 100644 --- a/embedchain/embedder/base.py +++ b/embedchain/embedder/base.py @@ -3,12 +3,12 @@ from embedchain.config.embedder.base import BaseEmbedderConfig try: - from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction + from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings except RuntimeError: from embedchain.utils import use_pysqlite3 use_pysqlite3() - from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction + from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings class EmbeddingFunc(EmbeddingFunction): diff --git a/embedchain/helpers/__init__.py b/embedchain/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/embedchain/helpers/callbacks.py b/embedchain/helpers/callbacks.py new file mode 100644 index 0000000000..3c7ab356a9 --- /dev/null +++ b/embedchain/helpers/callbacks.py @@ -0,0 +1,73 @@ +import queue +from typing import Any, Dict, List, Union + +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.schema import LLMResult + +STOP_ITEM = "[END]" +""" +This is a special item that is used to signal the end of the stream. +""" + + +class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler): + """ + This is a callback handler that yields the tokens as they are generated. + For a usage example, see the :func:`generate` function below. + """ + + q: queue.Queue + """ + The queue to write the tokens to as they are generated. + """ + + def __init__(self, q: queue.Queue) -> None: + """ + Initialize the callback handler. + q: The queue to write the tokens to as they are generated. + """ + super().__init__() + self.q = q + + def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None: + """Run when LLM starts running.""" + with self.q.mutex: + self.q.queue.clear() + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + self.q.put(token) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.q.put(STOP_ITEM) + + def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: + """Run when LLM errors.""" + self.q.put("%s: %s" % (type(error).__name__, str(error))) + self.q.put(STOP_ITEM) + + +def generate(rq: queue.Queue): + """ + This is a generator that yields the items in the queue until it reaches the stop item. + + Usage example: + ``` + def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield): + llm = OpenAI(streaming=True, callbacks=[callback_fn]) + return llm(prompt="Write a poem about a tree.") + + @app.route("/", methods=["GET"]) + def generate_output(): + q = Queue() + callback_fn = StreamingStdOutCallbackHandlerYield(q) + threading.Thread(target=askQuestion, args=(callback_fn,)).start() + return Response(generate(q), mimetype="text/event-stream") + ``` + """ + while True: + result: str = rq.get() + if result == STOP_ITEM or result is None: + break + yield result diff --git a/embedchain/helper/json_serializable.py b/embedchain/helpers/json_serializable.py similarity index 100% rename from embedchain/helper/json_serializable.py rename to embedchain/helpers/json_serializable.py diff --git a/embedchain/llm/anthropic.py b/embedchain/llm/anthropic.py index 492854f6f1..aab38640dc 100644 --- a/embedchain/llm/anthropic.py +++ b/embedchain/llm/anthropic.py @@ -3,7 +3,7 @@ from typing import Optional from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/azure_openai.py b/embedchain/llm/azure_openai.py index 34c4f08156..99588aa0ba 100644 --- a/embedchain/llm/azure_openai.py +++ b/embedchain/llm/azure_openai.py @@ -2,7 +2,7 @@ from typing import Optional from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index 8bb38833ec..dc2471e42c 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -7,7 +7,7 @@ from embedchain.config.llm.base import (DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE) -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable from embedchain.memory.base import ECChatMemory from embedchain.memory.message import ChatMessage diff --git a/embedchain/llm/cohere.py b/embedchain/llm/cohere.py index 0811c0672d..e996db6b7f 100644 --- a/embedchain/llm/cohere.py +++ b/embedchain/llm/cohere.py @@ -5,7 +5,7 @@ from langchain.llms import Cohere from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/gpt4all.py b/embedchain/llm/gpt4all.py index 699e9d339e..73950bd7c0 100644 --- a/embedchain/llm/gpt4all.py +++ b/embedchain/llm/gpt4all.py @@ -4,7 +4,7 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/huggingface.py b/embedchain/llm/huggingface.py index 97adc67b9c..4da6d5176f 100644 --- a/embedchain/llm/huggingface.py +++ b/embedchain/llm/huggingface.py @@ -5,7 +5,7 @@ from langchain.llms import HuggingFaceHub from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py index 2af5a7980b..b09c0fd3eb 100644 --- a/embedchain/llm/jina.py +++ b/embedchain/llm/jina.py @@ -5,7 +5,7 @@ from langchain.schema import HumanMessage, SystemMessage from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/llama2.py b/embedchain/llm/llama2.py index d4f4aa2f5d..fa2587bb85 100644 --- a/embedchain/llm/llama2.py +++ b/embedchain/llm/llama2.py @@ -5,7 +5,7 @@ from langchain.llms import Replicate from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py index 9e69085c0c..9efa019b35 100644 --- a/embedchain/llm/openai.py +++ b/embedchain/llm/openai.py @@ -4,7 +4,7 @@ from langchain.schema import HumanMessage, SystemMessage from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm @@ -34,7 +34,8 @@ def _get_answer(prompt: str, config: BaseLlmConfig) -> str: from langchain.callbacks.streaming_stdout import \ StreamingStdOutCallbackHandler - chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) + callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()] + chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks) else: chat = ChatOpenAI(**kwargs) return chat(messages).content diff --git a/embedchain/llm/vertex_ai.py b/embedchain/llm/vertex_ai.py index 224c65f166..c453c67b52 100644 --- a/embedchain/llm/vertex_ai.py +++ b/embedchain/llm/vertex_ai.py @@ -3,7 +3,7 @@ from typing import Optional from embedchain.config import BaseLlmConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm diff --git a/embedchain/loaders/base_loader.py b/embedchain/loaders/base_loader.py index bba58b813e..26da91a2aa 100644 --- a/embedchain/loaders/base_loader.py +++ b/embedchain/loaders/base_loader.py @@ -1,4 +1,4 @@ -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable class BaseLoader(JSONSerializable): diff --git a/embedchain/loaders/docs_site_loader.py b/embedchain/loaders/docs_site_loader.py index 2b880722c8..ea76442e53 100644 --- a/embedchain/loaders/docs_site_loader.py +++ b/embedchain/loaders/docs_site_loader.py @@ -12,7 +12,7 @@ ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/docx_file.py b/embedchain/loaders/docx_file.py index 59846c0516..d72e64b297 100644 --- a/embedchain/loaders/docx_file.py +++ b/embedchain/loaders/docx_file.py @@ -6,7 +6,7 @@ raise ImportError( 'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py index ffaa6feaa2..1158d4a994 100644 --- a/embedchain/loaders/local_qna_pair.py +++ b/embedchain/loaders/local_qna_pair.py @@ -1,6 +1,6 @@ import hashlib -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py index 118cbd3afd..e03ee12b28 100644 --- a/embedchain/loaders/local_text.py +++ b/embedchain/loaders/local_text.py @@ -1,6 +1,6 @@ import hashlib -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/mdx.py b/embedchain/loaders/mdx.py index 9d73b6cad2..45b112f14d 100644 --- a/embedchain/loaders/mdx.py +++ b/embedchain/loaders/mdx.py @@ -1,6 +1,6 @@ import hashlib -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py index 7ff84ed583..e0e981e80e 100644 --- a/embedchain/loaders/notion.py +++ b/embedchain/loaders/notion.py @@ -10,7 +10,7 @@ ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/loaders/pdf_file.py b/embedchain/loaders/pdf_file.py index 6b03554ba7..03495edbe0 100644 --- a/embedchain/loaders/pdf_file.py +++ b/embedchain/loaders/pdf_file.py @@ -6,7 +6,7 @@ raise ImportError( 'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index 8b449f212a..06e7e23906 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -13,7 +13,7 @@ 'Sitemap requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.web_page import WebPageLoader from embedchain.utils import is_readable diff --git a/embedchain/loaders/substack.py b/embedchain/loaders/substack.py index 4dcc609eeb..0d46b6d1a8 100644 --- a/embedchain/loaders/substack.py +++ b/embedchain/loaders/substack.py @@ -4,7 +4,7 @@ import requests -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import is_readable diff --git a/embedchain/loaders/unstructured_file.py b/embedchain/loaders/unstructured_file.py index be8cd931fc..294c596c27 100644 --- a/embedchain/loaders/unstructured_file.py +++ b/embedchain/loaders/unstructured_file.py @@ -1,6 +1,6 @@ import hashlib -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 931031826f..ecf03e9df7 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -10,7 +10,7 @@ 'Webpage requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/loaders/xml.py b/embedchain/loaders/xml.py index d200ffb25b..00fe477017 100644 --- a/embedchain/loaders/xml.py +++ b/embedchain/loaders/xml.py @@ -6,7 +6,7 @@ raise ImportError( 'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/loaders/youtube_video.py b/embedchain/loaders/youtube_video.py index c3262822b5..2aa0802473 100644 --- a/embedchain/loaders/youtube_video.py +++ b/embedchain/loaders/youtube_video.py @@ -6,7 +6,7 @@ raise ImportError( 'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string diff --git a/embedchain/memory/message.py b/embedchain/memory/message.py index 383b7c0f34..99081d2d5c 100644 --- a/embedchain/memory/message.py +++ b/embedchain/memory/message.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, Optional -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable class BaseMessage(JSONSerializable): diff --git a/embedchain/pipeline.py b/embedchain/pipeline.py index 91faf83c40..00a8a92df8 100644 --- a/embedchain/pipeline.py +++ b/embedchain/pipeline.py @@ -15,7 +15,7 @@ from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.openai import OpenAIEmbedder from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm from embedchain.llm.openai import OpenAILlm from embedchain.telemetry.posthog import AnonymousTelemetry diff --git a/embedchain/vectordb/base.py b/embedchain/vectordb/base.py index 77d6efc7a8..12f9693f9e 100644 --- a/embedchain/vectordb/base.py +++ b/embedchain/vectordb/base.py @@ -1,6 +1,6 @@ from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.embedder.base import BaseEmbedder -from embedchain.helper.json_serializable import JSONSerializable +from embedchain.helpers.json_serializable import JSONSerializable class BaseVectorDB(JSONSerializable): diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index f5f1cbbe1d..f32cf525af 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -6,7 +6,7 @@ from tqdm import tqdm from embedchain.config import ChromaDbConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB try: diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index b273708022..5ae6fd7c41 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -10,7 +10,7 @@ ) from None from embedchain.config import ElasticsearchDBConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 1ffe881a8d..a86f6292bb 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -16,7 +16,7 @@ from langchain.vectorstores import OpenSearchVectorSearch from embedchain.config import OpenSearchDBConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 86a817ac91..c3420c0928 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -9,7 +9,7 @@ ) from None from embedchain.config.vectordb.pinecone import PineconeDBConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index fde91caf6b..6ff329cbc5 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -10,7 +10,7 @@ ) from None from embedchain.config.vectordb.weaviate import WeaviateDBConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 7779f34495..0608c12fef 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union from embedchain.config import ZillizDBConfig -from embedchain.helper.json_serializable import register_deserializable +from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB try: diff --git a/pyproject.toml b/pyproject.toml index 167459bd85..dc6167b625 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.18" +version = "0.1.19" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py index ba06a005ff..3cbe27634d 100644 --- a/tests/helper_classes/test_json_serializable.py +++ b/tests/helper_classes/test_json_serializable.py @@ -4,8 +4,8 @@ from embedchain import App from embedchain.config import AppConfig, BaseLlmConfig -from embedchain.helper.json_serializable import (JSONSerializable, - register_deserializable) +from embedchain.helpers.json_serializable import (JSONSerializable, + register_deserializable) class TestJsonSerializable(unittest.TestCase):