diff --git a/examples/getting_started.py b/examples/getting_started.py
index af0d5012fa..afe76ffa20 100644
--- a/examples/getting_started.py
+++ b/examples/getting_started.py
@@ -1,21 +1,27 @@
+import logging
+
+from typing import Optional
+
 from haystack.document_stores import InMemoryDocumentStore
 from haystack.utils import build_pipeline, add_example_data, print_answers
 
+logger = logging.getLogger(__name__)
 
-def getting_started(provider, API_KEY):
+
+def getting_started(provider, API_KEY, API_BASE: Optional[str] = None):
     """
     This getting_started example shows you how to use LLMs with your data with a technique called Retrieval Augmented Generation - RAG.
 
     :param provider: We are model agnostic :) Here, you can choose from: "anthropic", "cohere", "huggingface", and "openai".
     :param API_KEY: The API key matching the provider.
+    :param API_BASE: The URL of a custom OpenAI-compatible endpoint, e.g., when using LM Studio. Only the openai provider is supported. The URL must end with /v1 (e.g., http://localhost:1234/v1).
 
     """
-
     # We support many different databases. Here we load a simple and lightweight in-memory database.
     document_store = InMemoryDocumentStore(use_bm25=True)
 
     # Pipelines are the main abstraction in Haystack, they connect components like LLMs and databases.
-    pipeline = build_pipeline(provider, API_KEY, document_store)
+    pipeline = build_pipeline(provider, API_KEY, API_BASE, document_store)
 
     # Download and add Game of Thrones TXT articles to Haystack's database.
     # You can also provide a folder with your local documents.
@@ -23,7 +29,7 @@ def getting_started(provider, API_KEY):
     add_example_data(document_store, "data/GoT_getting_started")
 
     # Ask a question on the data you just added.
-    result = pipeline.run(query="Who is the father of Arya Stark?")
+    result = pipeline.run(query="Who is the father of Arya Stark?", debug=True)
 
     # For details such as which documents were used to generate the answer, look into the object.
     print_answers(result, details="medium")
@@ -31,4 +37,5 @@ def getting_started(provider, API_KEY):
 
 
 if __name__ == "__main__":
+    # getting_started(provider="openai", API_KEY="NOT NEEDED", API_BASE="http://192.168.1.100:1234/v1")
     getting_started(provider="openai", API_KEY="ADD KEY HERE")
diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py
index c8071e49f6..62d28b0d5c 100644
--- a/haystack/nodes/prompt/prompt_model.py
+++ b/haystack/nodes/prompt/prompt_model.py
@@ -37,6 +37,7 @@ def __init__(
         model_name_or_path: str = "google/flan-t5-base",
         max_length: Optional[int] = 100,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         timeout: Optional[float] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
@@ -65,6 +66,7 @@ def __init__(
         self.model_name_or_path = model_name_or_path
         self.max_length = max_length
         self.api_key = api_key
+        self.api_base = api_base
         self.timeout = timeout
         self.use_auth_token = use_auth_token
         self.use_gpu = use_gpu
@@ -83,6 +85,9 @@ def create_invocation_layer(
             "use_gpu": self.use_gpu,
             "devices": self.devices,
         }
+        if self.api_base is not None:
+            kwargs["api_base"] = self.api_base
+
         all_kwargs = {**self.model_kwargs, **kwargs}
 
         if isinstance(invocation_layer_class, str):
diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 92ec069acb..0b77c89682 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -57,6 +57,7 @@ def __init__(
         output_variable: Optional[str] = None,
         max_length: Optional[int] = 100,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         timeout: Optional[float] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
@@ -114,6 +115,7 @@ def __init__(
             model_name_or_path=model_name_or_path,
             max_length=max_length,
             api_key=api_key,
+            api_base=api_base,
             timeout=timeout,
             use_auth_token=use_auth_token,
             use_gpu=use_gpu,
diff --git a/haystack/utils/getting_started.py b/haystack/utils/getting_started.py
index cd54e7169d..b20c3dc539 100644
--- a/haystack/utils/getting_started.py
+++ b/haystack/utils/getting_started.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)
 
 
-def build_pipeline(provider, API_KEY, document_store):
+def build_pipeline(provider, API_KEY, API_BASE, document_store):
     # Importing top-level causes a circular import
     from haystack.nodes import AnswerParser, PromptNode, PromptTemplate, BM25Retriever
     from haystack.pipelines import Pipeline
@@ -42,6 +42,7 @@ def build_pipeline(provider, API_KEY, document_store):
         prompt_node = PromptNode(
             model_name_or_path="gpt-3.5-turbo-0301",
             api_key=API_KEY,
+            api_base=API_BASE,
             default_prompt_template=question_answering_with_references,
         )
     else:
diff --git a/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml b/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml
new file mode 100644
index 0000000000..bf0eb5b3c5
--- /dev/null
+++ b/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml
@@ -0,0 +1,8 @@
+---
+enhancements:
+  - |
+    API_BASE can now be passed as an optional parameter in the getting_started sample. Only the openai provider is supported in this set of changes.
+    PromptNode and PromptModel were enhanced to allow passing of this parameter.
+    This allows RAG against a local endpoint (e.g., http://localhost:1234/v1), as long as it is OpenAI-compatible (such as LM Studio).
+
+    Logging in the getting started sample was made more verbose to make it easier to see what is happening under the covers.
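For anyone who wants to try the override outside the getting_started sample, here is a minimal sketch of how the new api_base parameter can be used with PromptNode directly. The endpoint URL, model name, and dummy API key are placeholders, not values from this change: any OpenAI-compatible server (such as LM Studio) should work, as long as the URL ends with /v1.

```python
# Illustrative sketch only; URL, model name, and key are placeholders.
from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo-0301",
    api_key="NOT NEEDED",                 # most local servers ignore the key
    api_base="http://localhost:1234/v1",  # custom OpenAI-compatible endpoint, must end with /v1
)

# PromptNode can be invoked directly with a prompt string.
print(prompt_node("Who is the father of Arya Stark?"))
```

Equivalently, the full RAG sample can be run against a local endpoint via the commented-out call at the bottom of examples/getting_started.py.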