From 11c0bcad0205f7558273aca11f4b23443a2e95fa Mon Sep 17 00:00:00 2001
From: rchan
Date: Tue, 12 Sep 2023 18:28:04 +0100
Subject: [PATCH 1/9] llamacpp query engine first commit

---
 slack_bot/run.py                          | 113 ++++++++++++++++------
 slack_bot/slack_bot/models/__init__.py    |   8 +-
 slack_bot/slack_bot/models/llama_index.py |  58 ++++++++++-
 3 files changed, 146 insertions(+), 33 deletions(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index 7b28b176..45a075c7 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -10,6 +10,13 @@
 from slack_bot import MODELS, Bot
 
+DEFAULT_LLAMA_CPP_GGUF_MODEL = (
+    "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve"
+    "/main/llama-2-13b-chat.Q6_K.gguf"
+)
+DEFAULT_HF_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"
+
+
 if __name__ == "__main__":
     # Parse command line arguments
     parser = argparse.ArgumentParser()
@@ -17,32 +24,61 @@
         "--model", "-m", help="Select which model to use", default=None, choices=MODELS
     )
     parser.add_argument(
-        "--hf_model",
-        "-hf",
-        help="""Select which HuggingFace model to use
-        (ignored if not using llama-huggingface model)""",
-        default="StabilityAI/stablelm-tuned-alpha-3b",
+        "--model_name",
+        "-n",
+        type=str | None,
+        help=(
+            "Select which LlamaCPP or HuggingFace model to use "
+            "(ignored if not using llama-index-llama-cpp or llama-index-hf). "
+            "Default model for llama-index-llama-cpp is downloaded from "
+            f"{DEFAULT_LLAMA_CPP_GGUF_MODEL}. "
+            "Default model for llama-index-hf is downloaded from "
+            f"{DEFAULT_HF_MODEL}."
+        ),
+        default=None,
+    )
+    parser.add_argument(
+        "--path",
+        "-p",
+        help=(
+            "Whether or not the model_name passed is a path to the model "
+            "(ignored if not using llama-index-llama-cpp)"
+        ),
+        action="store_true",
     )
     parser.add_argument(
         "--max_input_size",
         "-max",
-        help="""Select maximum input size for HuggingFace model
-        (ignored if not using llama-huggingface model)""",
+        type=int,
+        help=(
+            "Select maximum input size for LlamaCPP or HuggingFace model "
+            "(ignored if not using llama-index-llama-cpp or llama-index-hf)"
+        ),
         default=4096,
     )
+    parser.add_argument(
+        "--n_gpu_layers",
+        "-ngl",
+        help=(
+            "Select number of GPU layers for LlamaCPP model "
+            "(ignored if not using llama-index-llama-cpp)"
+        ),
+        default=0,
+    )
     parser.add_argument(
         "--device",
         "-dev",
-        help="""Select device for HuggingFace model
-        (ignored if not using llama-huggingface model)""",
+        help=(
+            "Select device for HuggingFace model "
+            "(ignored if not using llama-index-hf model)"
+        ),
         default="auto",
     )
     parser.add_argument(
         "--force-new-index",
         "-f",
         help="Recreate the index vector store or not",
-        action=argparse.BooleanOptionalAction,
-        default=False,
+        action="store_true",
     )
     parser.add_argument(
         "--data-dir",
@@ -53,11 +89,13 @@
     parser.add_argument(
         "--which-index",
         "-w",
-        help="""Specifies the directory name for looking up/writing indices.
-        Currently supports 'all_data', 'public' and 'handbook'.
-        If regenerating index, 'all_data' will use all .txt .md. and .csv
-        files in the data directory, 'handbook' will
-        only use 'handbook.csv' file.""",
+        help=(
+            "Specifies the directory name for looking up/writing indices. "
+            "Currently supports 'all_data', 'public' and 'handbook'. "
+            "If regenerating index, 'all_data' will use all .txt .md. and .csv "
+            "files in the data directory, 'handbook' will "
+            "only use 'handbook.csv' file."
+        ),
         default="all_data",
         choices=["all_data", "public", "handbook"],
     )
@@ -107,24 +145,39 @@
         logging.error(f"Model {model_name} was not recognised")
         sys.exit(1)
 
+    # Initialise LLM reponse model
     logging.info(f"Initialising bot with model: {model_name}")
-    if model_name == "llama-index-hf":
-        response_model = model(
-            model_name=args.hf_model,
-            max_input_size=args.max_input_size,
-            device=args.device,
-            force_new_index=force_new_index,
-            data_dir=data_dir,
-            which_index=which_index,
-        )
+    # Set up any model args that are required
+    if model_name == "llama-index-llama-cpp":
+        if args.model_name is None:
+            args.model_name = DEFAULT_LLAMA_CPP_GGUF_MODEL
+
+        model_args = {
+            "model_name": args.model_name,
+            "path": args.path,
+            "max_input_size": args.max_input_size,
+        }
+    elif model_name == "llama-index-hf":
+        if args.model_name is None:
+            args.model_name = DEFAULT_HF_MODEL
+
+        model_args = {
+            "model_name": args.model_name,
+            "max_input_size": args.max_input_size,
+            "device": args.device,
+        }
     else:
-        response_model = model(
-            force_new_index=force_new_index,
-            data_dir=data_dir,
-            which_index=which_index,
-        )
+        model_args = {}
+
+    response_model = model(
+        force_new_index=force_new_index,
+        data_dir=data_dir,
+        which_index=which_index,
+        **model_args,
+    )
+    # Initialise Bot with response model
     logging.info(f"Initalising bot with model: {response_model}")
     slack_bot = Bot(response_model)

diff --git a/slack_bot/slack_bot/models/__init__.py b/slack_bot/slack_bot/models/__init__.py
index 1878cb8b..e5d28844 100644
--- a/slack_bot/slack_bot/models/__init__.py
+++ b/slack_bot/slack_bot/models/__init__.py
@@ -1,7 +1,12 @@
 from .base import ResponseModel
 from .chat_completion import ChatCompletionAzure, ChatCompletionOpenAI
 from .hello import Hello
-from .llama_index import LlamaIndexGPTAzure, LlamaIndexGPTOpenAI, LlamaIndexHF
+from .llama_index import (
+    LlamaIndexGPTAzure,
+    LlamaIndexGPTOpenAI,
+    LlamaIndexHF,
+    LlamaIndexLlamaCPP,
+)
 
 # Please ensure that any models needing OPENAI_API_KEY are named *openai*
 # Please ensure that any models needing OPENAI_AZURE_API_BASE and OPENAI_AZURE_API_KEY are named *azure*
@@ -9,6 +14,7 @@
     "chat-completion-azure": ChatCompletionAzure,
     "chat-completion-openai": ChatCompletionOpenAI,
     "hello": Hello,
+    "llama-index-llama-cpp": LlamaIndexLlamaCPP,
     "llama-index-hf": LlamaIndexHF,
     "llama-index-gpt-azure": LlamaIndexGPTAzure,
     "llama-index-gpt-openai": LlamaIndexGPTOpenAI,

diff --git a/slack_bot/slack_bot/models/llama_index.py b/slack_bot/slack_bot/models/llama_index.py
index 43dd322c..1b73636b 100644
--- a/slack_bot/slack_bot/models/llama_index.py
+++ b/slack_bot/slack_bot/models/llama_index.py
@@ -17,8 +17,9 @@
     load_index_from_storage,
 )
 
 from llama_index.indices.vector_store.base import VectorStoreIndex
-from llama_index.llms import AzureOpenAI, HuggingFaceLLM, OpenAI
+from llama_index.llms import AzureOpenAI, HuggingFaceLLM, LlamaCPP, OpenAI
 from llama_index.llms.base import LLM
+from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt
 from llama_index.prompts import PromptTemplate
 from llama_index.response.schema import RESPONSE_TYPE
@@ -332,6 +333,59 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
         return MessageResponse(backend_response)
 
 
+class LlamaIndexLlamaCPP(LlamaIndex):
+    def __init__(
+        self,
+        model_name: str,
+        path: bool,
+        n_gpu_layers: int = 0,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        """
+        `LlamaIndexLlamaCPP` is a subclass of `LlamaIndex` that uses
+        llama-cpp to implement the LLM.
+
+        Parameters
+        ----------
+        model_name : str
+            Either the path to the model or the URL to download the model from
+        path : bool, optional
+            If True, model_name is used as a path to the model file,
+            otherwise it should be the URL to download the model
+        n_gpu_layers : int, optional
+            Number of layers to offload to GPU.
+            If -1, all layers are offloaded, by default 0
+        """
+        super().__init__(*args, model_name=model_name, **kwargs)
+        self.path = path
+        self.n_gpu_layers = n_gpu_layers
+
+    def _prep_llm(self) -> LLM:
+        logging.info(
+            f"Setting up LlamaCPP LLM (model {self.model_name}) on {self.n_gpu_layers} GPU layers"
+        )
+        logging.info(
+            f"LlamaCPP-args: (context_window: {self.max_input_size}, num_output: {self.num_output})"
+        )
+
+        return LlamaCPP(
+            model_url=self.model_name if not self.path else None,
+            model_path=self.model_name if self.path else None,
+            temperature=0.1,
+            max_new_tokens=self.num_output,
+            context_window=self.max_input_size,
+            # kwargs to pass to __call__()
+            generate_kwargs={},
+            # kwargs to pass to __init__()
+            model_kwargs={"n_gpu_layers": self.n_gpu_layers},
+            # transform inputs into Llama2 format
+            messages_to_prompt=messages_to_prompt,
+            completion_to_prompt=completion_to_prompt,
+            verbose=False,
+        )
+
+
 class LlamaIndexHF(LlamaIndex):
     def __init__(
         self,
@@ -365,7 +419,7 @@ def _prep_llm(self) -> LLM:
             max_new_tokens=self.num_output,
             # TODO: allow user to specify the query wrapper prompt for their model
             query_wrapper_prompt=PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>"),
-            generate_kwargs={"temperature": 0.25, "do_sample": False},
+            generate_kwargs={"temperature": 0.1, "do_sample": False},
             tokenizer_name=self.model_name,
             model_name=self.model_name,
             device_map=self.device or "auto",

From 9504735c86a60a9363c68db25235a5f461d879ef Mon Sep 17 00:00:00 2001
From: rchan
Date: Tue, 12 Sep 2023 18:35:26 +0100
Subject: [PATCH 2/9] fix types in slack_bot/run.py

---
 slack_bot/run.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index 45a075c7..d9b6decb 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -26,7 +26,7 @@
     parser.add_argument(
         "--model_name",
         "-n",
-        type=str | None,
+        type=str,
         help=(
             "Select which LlamaCPP or HuggingFace model to use "
             "(ignored if not using llama-index-llama-cpp or llama-index-hf). "
@@ -59,6 +59,7 @@
     parser.add_argument(
         "--n_gpu_layers",
         "-ngl",
+        type=int,
         help=(
             "Select number of GPU layers for LlamaCPP model "
             "(ignored if not using llama-index-llama-cpp)"
@@ -68,6 +69,7 @@
     parser.add_argument(
         "--device",
         "-dev",
+        type=str,
         help=(
             "Select device for HuggingFace model "
             "(ignored if not using llama-index-hf model)"
@@ -83,12 +85,14 @@
     parser.add_argument(
         "--data-dir",
         "-d",
+        type=pathlib.Path,
         help="Location for data",
         default=(pathlib.Path(__file__).parent.parent / "data").resolve(),
     )
     parser.add_argument(
         "--which-index",
         "-w",
+        type=str,
         help=(
             "Specifies the directory name for looking up/writing indices. "
             "Currently supports 'all_data', 'public' and 'handbook'. 
" From 24138ce82e27493edc4929be7a31b0e4d1b33075 Mon Sep 17 00:00:00 2001 From: rchan Date: Tue, 12 Sep 2023 19:13:07 +0100 Subject: [PATCH 3/9] assign attributes before super in llama-cpp --- slack_bot/slack_bot/models/llama_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slack_bot/slack_bot/models/llama_index.py b/slack_bot/slack_bot/models/llama_index.py index 1b73636b..fb3c83e9 100644 --- a/slack_bot/slack_bot/models/llama_index.py +++ b/slack_bot/slack_bot/models/llama_index.py @@ -357,9 +357,9 @@ def __init__( Number of layers to offload to GPU. If -1, all layers are offloaded, by default 0 """ - super().__init__(*args, model_name=model_name, **kwargs) self.path = path self.n_gpu_layers = n_gpu_layers + super().__init__(*args, model_name=model_name, **kwargs) def _prep_llm(self) -> LLM: logging.info( From 5e08b1d16c3b1a3d4628b83f258507f518d9fcae Mon Sep 17 00:00:00 2001 From: rchan Date: Wed, 13 Sep 2023 08:25:05 +0100 Subject: [PATCH 4/9] update llama-index, llama-cpp-python, llama-hub --- poetry.lock | 61 +++++++++++++++++++++++++++++++++++++++++--------- pyproject.toml | 6 ++--- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/poetry.lock b/poetry.lock index bee4578e..be0a76f1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1315,6 +1315,7 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, + {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1323,6 +1324,7 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, + {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1352,6 +1354,7 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", 
hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, + {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1360,6 +1363,7 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, + {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -1994,12 +1998,12 @@ requests = ">=2,<3" [[package]] name = "llama-cpp-python" -version = "0.1.84" -description = "A Python wrapper for llama.cpp" +version = "0.2.2" +description = "Python bindings for the llama.cpp library" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "llama_cpp_python-0.1.84.tar.gz", hash = "sha256:8840bfa90acfdd80486e3c11393fe6ff6841598f03278bdf3502e2d901978f13"}, + {file = "llama_cpp_python-0.2.2.tar.gz", hash = "sha256:b8646919839730241109be9b2016577d5df85466d72f6db0690bc9529cf8e446"}, ] [package.dependencies] @@ -2008,17 +2012,20 @@ numpy = ">=1.20.0" typing-extensions = ">=4.5.0" [package.extras] +all = ["llama_cpp_python[dev,server,test]"] +dev = ["black (>=23.3.0)", "httpx (>=0.24.1)", "mkdocs (>=1.4.3)", "mkdocs-material (>=9.1.18)", "mkdocstrings[python] (>=0.22.0)", "pytest (>=7.4.0)", "twine (>=4.0.2)"] server = ["fastapi (>=0.100.0)", "pydantic-settings (>=2.0.1)", "sse-starlette (>=1.6.1)", "uvicorn (>=0.22.0)"] +test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"] [[package]] name = "llama-hub" -version = "0.0.26" +version = "0.0.30" description = "A library of community-driven data loaders for LLMs. Use with LlamaIndex and/or LangChain. 
" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "llama_hub-0.0.26-py3-none-any.whl", hash = "sha256:7cc4e4ac44ff4d4a57ed8fe9cbfcc968503a296a151d89c5f19cfd1bd7f9d4aa"}, - {file = "llama_hub-0.0.26.tar.gz", hash = "sha256:3b24d9396d977b60f1b3475896567140a7b52434c5bd94c981bc2f30732f3c7b"}, + {file = "llama_hub-0.0.30-py3-none-any.whl", hash = "sha256:e8878331db968af6210ee85de5577e39657b9beb1fe7f9bd617fa5e2c3b6b25a"}, + {file = "llama_hub-0.0.30.tar.gz", hash = "sha256:77a82458fdb3e491f75e40a16dbdc1eb8066387922aa7c5854fb93025060e8c1"}, ] [package.dependencies] @@ -2030,13 +2037,13 @@ retrying = "*" [[package]] name = "llama-index" -version = "0.8.24.post1" +version = "0.8.25" description = "Interface between LLMs and your data" optional = false python-versions = "*" files = [ - {file = "llama_index-0.8.24.post1-py3-none-any.whl", hash = "sha256:4b7645a445d394640bad8c66a67483df29f7f0af25c53360cb382075be0c6c34"}, - {file = "llama_index-0.8.24.post1.tar.gz", hash = "sha256:7cd47cf6ba64d24dbc6db712bcd4834767e0d35890559feee139bd4fa90ad916"}, + {file = "llama_index-0.8.25-py3-none-any.whl", hash = "sha256:bb887d66fd92be21dec3c881249af36f92519cc74fc516868fede66a5b065ea7"}, + {file = "llama_index-0.8.25.tar.gz", hash = "sha256:03b02cb04c9930ecce0a04dc58779cd95030aeb708aceeac3440e846501d2a2a"}, ] [package.dependencies] @@ -3403,6 +3410,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3410,8 +3418,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = 
"PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3428,6 +3443,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3435,6 +3451,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3829,35 +3846,57 @@ files = [ {file = "safetensors-0.3.3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:17f41344d9a075f2f21b289a49a62e98baff54b5754240ba896063bce31626bf"}, {file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_arm64.whl", hash = 
"sha256:f1045f798e1a16a6ced98d6a42ec72936d367a2eec81dc5fade6ed54638cd7d2"}, {file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:eaf0e4bc91da13f21ac846a39429eb3f3b7ed06295a32321fa3eb1a59b5c70f3"}, + {file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25149180d4dc8ca48bac2ac3852a9424b466e36336a39659b35b21b2116f96fc"}, + {file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9e943bf78c39de8865398a71818315e7d5d1af93c7b30d4da3fc852e62ad9bc"}, + {file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cccfcac04a010354e87c7a2fe16a1ff004fc4f6e7ef8efc966ed30122ce00bc7"}, {file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a07121f427e646a50d18c1be0fa1a2cbf6398624c31149cd7e6b35486d72189e"}, {file = "safetensors-0.3.3-cp310-cp310-win32.whl", hash = "sha256:a85e29cbfddfea86453cc0f4889b4bcc6b9c155be9a60e27be479a34e199e7ef"}, + {file = "safetensors-0.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:e13adad4a3e591378f71068d14e92343e626cf698ff805f61cdb946e684a218e"}, {file = "safetensors-0.3.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:cbc3312f134baf07334dd517341a4b470b2931f090bd9284888acb7dfaf4606f"}, {file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d15030af39d5d30c22bcbc6d180c65405b7ea4c05b7bab14a570eac7d7d43722"}, {file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:f84a74cbe9859b28e3d6d7715ac1dd3097bebf8d772694098f6d42435245860c"}, {file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:10d637423d98ab2e6a4ad96abf4534eb26fcaf8ca3115623e64c00759374e90d"}, {file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:3b46f5de8b44084aff2e480874c550c399c730c84b2e8ad1bddb062c94aa14e9"}, + {file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e76da691a82dfaf752854fa6d17c8eba0c8466370c5ad8cf1bfdf832d3c7ee17"}, + {file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4e342fd54e66aa9512dd13e410f791e47aa4feeb5f4c9a20882c72f3d272f29"}, + {file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:178fd30b5dc73bce14a39187d948cedd0e5698e2f055b7ea16b5a96c9b17438e"}, {file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e8fdf7407dba44587ed5e79d5de3533d242648e1f2041760b21474bd5ea5c8c"}, {file = "safetensors-0.3.3-cp311-cp311-win32.whl", hash = "sha256:7d3b744cee8d7a46ffa68db1a2ff1a1a432488e3f7a5a97856fe69e22139d50c"}, + {file = "safetensors-0.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f579877d30feec9b6ba409d05fa174633a4fc095675a4a82971d831a8bb60b97"}, {file = "safetensors-0.3.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:2fff5b19a1b462c17322998b2f4b8bce43c16fe208968174d2f3a1446284ceed"}, {file = "safetensors-0.3.3-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:41adb1d39e8aad04b16879e3e0cbcb849315999fad73bc992091a01e379cb058"}, {file = "safetensors-0.3.3-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:0f2b404250b3b877b11d34afcc30d80e7035714a1116a3df56acaca6b6c00096"}, {file = "safetensors-0.3.3-cp37-cp37m-macosx_13_0_x86_64.whl", hash = "sha256:b43956ef20e9f4f2e648818a9e7b3499edd6b753a0f5526d4f6a6826fbee8446"}, + {file = 
"safetensors-0.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d61a99b34169981f088ccfbb2c91170843efc869a0a0532f422db7211bf4f474"}, + {file = "safetensors-0.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c0008aab36cd20e9a051a68563c6f80d40f238c2611811d7faa5a18bf3fd3984"}, + {file = "safetensors-0.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:93d54166072b143084fdcd214a080a088050c1bb1651016b55942701b31334e4"}, {file = "safetensors-0.3.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c32ee08f61cea56a5d62bbf94af95df6040c8ab574afffaeb7b44ae5da1e9e3"}, {file = "safetensors-0.3.3-cp37-cp37m-win32.whl", hash = "sha256:351600f367badd59f7bfe86d317bb768dd8c59c1561c6fac43cafbd9c1af7827"}, + {file = "safetensors-0.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:034717e297849dae1af0a7027a14b8647bd2e272c24106dced64d83e10d468d1"}, {file = "safetensors-0.3.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8530399666748634bc0b301a6a5523756931b0c2680d188e743d16304afe917a"}, {file = "safetensors-0.3.3-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:9d741c1f1621e489ba10aa3d135b54202684f6e205df52e219d5eecd673a80c9"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:0c345fd85b4d2093a5109596ff4cd9dfc2e84992e881b4857fbc4a93a3b89ddb"}, {file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:69ccee8d05f55cdf76f7e6c87d2bdfb648c16778ef8acfd2ecc495e273e9233e"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:c08a9a4b7a4ca389232fa8d097aebc20bbd4f61e477abc7065b5c18b8202dede"}, {file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:a002868d2e3f49bbe81bee2655a411c24fa1f8e68b703dec6629cb989d6ae42e"}, + {file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3bd2704cb41faa44d3ec23e8b97330346da0395aec87f8eaf9c9e2c086cdbf13"}, + {file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b2951bf3f0ad63df5e6a95263652bd6c194a6eb36fd4f2d29421cd63424c883"}, + {file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:07114cec116253ca2e7230fdea30acf76828f21614afd596d7b5438a2f719bd8"}, {file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ab43aeeb9eadbb6b460df3568a662e6f1911ecc39387f8752afcb6a7d96c087"}, {file = "safetensors-0.3.3-cp38-cp38-win32.whl", hash = "sha256:f2f59fce31dd3429daca7269a6b06f65e6547a0c248f5116976c3f1e9b73f251"}, + {file = "safetensors-0.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:c31ca0d8610f57799925bf08616856b39518ab772c65093ef1516762e796fde4"}, {file = "safetensors-0.3.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:59a596b3225c96d59af412385981f17dd95314e3fffdf359c7e3f5bb97730a19"}, {file = "safetensors-0.3.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:82a16e92210a6221edd75ab17acdd468dd958ef5023d9c6c1289606cc30d1479"}, {file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:98a929e763a581f516373ef31983ed1257d2d0da912a8e05d5cd12e9e441c93a"}, {file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:12b83f1986cd16ea0454c636c37b11e819d60dd952c26978310a0835133480b7"}, {file = "safetensors-0.3.3-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:f439175c827c2f1bbd54df42789c5204a10983a30bc4242bc7deaf854a24f3f0"}, {file = 
"safetensors-0.3.3-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:0085be33b8cbcb13079b3a8e131656e05b0bc5e6970530d4c24150f7afd76d70"}, + {file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3ec70c87b1e910769034206ad5efc051069b105aac1687f6edcd02526767f4"}, + {file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f490132383e5e490e710608f4acffcb98ed37f91b885c7217d3f9f10aaff9048"}, + {file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79d1b6c7ed5596baf79c80fbce5198c3cdcc521ae6a157699f427aba1a90082d"}, {file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad3cc8006e7a86ee7c88bd2813ec59cd7cc75b03e6fa4af89b9c7b235b438d68"}, {file = "safetensors-0.3.3-cp39-cp39-win32.whl", hash = "sha256:ab29f54c6b8c301ca05fa014728996bd83aac6e21528f893aaf8945c71f42b6d"}, + {file = "safetensors-0.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:0fa82004eae1a71e2aa29843ef99de9350e459a0fc2f65fc6ee0da9690933d2d"}, {file = "safetensors-0.3.3.tar.gz", hash = "sha256:edb7072d788c4f929d0f5735d3a2fb51e5a27f833587828583b7f5747af1a2b8"}, ] @@ -5163,4 +5202,4 @@ huggingface-llm = ["accelerate"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "9981e743e250cbb1ef7eb21016b9885b9dedeab9681d56c80d93535927bf1c44" +content-hash = "dcfffa248872d8a7dcf5f4b2142556ba54613005790e59a95d205c9b53889e1e" diff --git a/pyproject.toml b/pyproject.toml index 826c83e9..79b1c08c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ einops = { version="^0.6.1", optional=true } faiss-cpu = { version="^1.7.4", optional=true } gradio = { version="^3.34.0", optional=true } langchain = "^0.0.278" -llama-index = "^0.8.24" -llama-cpp-python = "^0.1.83" -llama-hub = "^0.0.26" +llama-index = "^0.8.25" +llama-cpp-python = "^0.2.2" +llama-hub = "^0.0.30" nbconvert = { version="^7.5.0", optional=true } openai = { version="^0.27.8", optional=true } pandas = "^2.0.2" From aa452acdfdee8c23e050b7be6abbb55e0e3740f1 Mon Sep 17 00:00:00 2001 From: rchan Date: Wed, 13 Sep 2023 08:25:21 +0100 Subject: [PATCH 5/9] llama-cpp use gpu --- slack_bot/run.py | 3 ++- slack_bot/slack_bot/models/llama_index.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/slack_bot/run.py b/slack_bot/run.py index d9b6decb..0499b08e 100755 --- a/slack_bot/run.py +++ b/slack_bot/run.py @@ -160,6 +160,7 @@ model_args = { "model_name": args.model_name, "path": args.path, + "n_gpu_layers": args.n_gpu_layers, "max_input_size": args.max_input_size, } elif model_name == "llama-index-hf": @@ -168,8 +169,8 @@ model_args = { "model_name": args.model_name, - "max_input_size": args.max_input_size, "device": args.device, + "max_input_size": args.max_input_size, } else: model_args = {} diff --git a/slack_bot/slack_bot/models/llama_index.py b/slack_bot/slack_bot/models/llama_index.py index fb3c83e9..3fc2bf2f 100644 --- a/slack_bot/slack_bot/models/llama_index.py +++ b/slack_bot/slack_bot/models/llama_index.py @@ -37,7 +37,6 @@ def __init__( max_input_size: int, data_dir: pathlib.Path, which_index: str, - device: str | None = None, chunk_size: Optional[int] = None, k: int = 3, chunk_overlap_ratio: float = 0.1, @@ -60,9 +59,6 @@ def __init__( which_index : str Which index to construct (if force_new_index is True) or use. Options are "handbook", "public", or "all_data". - device : str, optional - Device to use for the LLM, by default None. 
         This is ignored if the LLM is model from OpenAI or Azure.
         chunk_size : Optional[int], optional
             Maximum size of chunks to use, by default None.
             If None, this is computed as `ceil(max_input_size / k)`.
@@ -81,7 +77,6 @@ def __init__(
         self.max_input_size = max_input_size
         self.model_name = model_name
         self.num_output = num_output
-        self.device = device
         if chunk_size is None:
             chunk_size = math.ceil(max_input_size / k)
         self.chunk_size = chunk_size
@@ -382,7 +377,7 @@ def _prep_llm(self) -> LLM:
             # transform inputs into Llama2 format
             messages_to_prompt=messages_to_prompt,
             completion_to_prompt=completion_to_prompt,
-            verbose=False,
+            verbose=True,
         )
 
 
@@ -390,6 +385,7 @@ class LlamaIndexHF(LlamaIndex):
     def __init__(
         self,
         model_name: str = "StabilityAI/stablelm-tuned-alpha-3b",
+        device: str = "auto",
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -402,13 +398,15 @@ def __init__(
         model_name : str, optional
             Model name from Huggingface's model hub,
             by default "StabilityAI/stablelm-tuned-alpha-3b".
+        device : str, optional
+            Device map to use for the LLM, by default "auto".
         """
+        self.device = device
         super().__init__(*args, model_name=model_name, **kwargs)
 
     def _prep_llm(self) -> LLM:
-        dev = self.device or "auto"
         logging.info(
-            f"Setting up Huggingface LLM (model {self.model_name}) on device {dev}"
+            f"Setting up Huggingface LLM (model {self.model_name}) on device {self.device}"
         )
         logging.info(
             f"HF-args: (context_window: {self.max_input_size}, num_output: {self.num_output})"

From dbb01fb93c006212245936567119453163640026 Mon Sep 17 00:00:00 2001
From: rchan
Date: Wed, 13 Sep 2023 09:22:11 +0100
Subject: [PATCH 6/9] replace cli args _ with -

---
 slack_bot/run.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index 0499b08e..aab32e6f 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -47,7 +47,7 @@
         action="store_true",
     )
     parser.add_argument(
-        "--max_input_size",
+        "--max-input-size",
         "-max",
         type=int,
         help=(
@@ -57,7 +57,7 @@
         default=4096,
     )
     parser.add_argument(
-        "--n_gpu_layers",
+        "--n-gpu-layers",
         "-ngl",
         type=int,
         help=(

From 35d7868875c7f3d1d977d96b9138ab214a9c7835 Mon Sep 17 00:00:00 2001
From: rchan
Date: Wed, 13 Sep 2023 09:39:01 +0100
Subject: [PATCH 7/9] replace - with _ in model_name

---
 slack_bot/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index aab32e6f..b8b6ef7c 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -24,7 +24,7 @@
         "--model", "-m", help="Select which model to use", default=None, choices=MODELS
     )
     parser.add_argument(
-        "--model_name",
+        "--model-name",
         "-n",
         type=str,
         help=(

From 45b1849b861026338864f1d111ccd2e5cd2822a1 Mon Sep 17 00:00:00 2001
From: rchan
Date: Wed, 13 Sep 2023 09:45:22 +0100
Subject: [PATCH 8/9] :memo: Add notebooks note in readme

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index ca1d9a25..bd26f882 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ The Reginald project consists of:
 ├── docker
 │   └── Scripts for building a Docker image
 ├── models
-│   └── REGinald models
+│   └── REGinald models (in notebooks)
 └── slack_bot
     └── Python Slack bot
 ```

From 66cf7ac6228d24c6e7d191bf727ae44602068292 Mon Sep 17 00:00:00 2001
From: Rosie Wood
Date: Thu, 14 Sep 2023 13:21:11 +0100
Subject: [PATCH 9/9] ensure 'hello' model still works

---
 slack_bot/run.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index b8b6ef7c..c2cb6cce 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -175,12 +175,15 @@
     else:
         model_args = {}
 
-    response_model = model(
-        force_new_index=force_new_index,
-        data_dir=data_dir,
-        which_index=which_index,
-        **model_args,
-    )
+    if model_name == "hello":
+        response_model = model()
+    else:
+        response_model = model(
+            force_new_index=force_new_index,
+            data_dir=data_dir,
+            which_index=which_index,
+            **model_args,
+        )
 
     # Initialise Bot with response model
     logging.info(f"Initalising bot with model: {response_model}")
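
Note (not part of the patch series itself): a minimal sketch of what the new llama-index-llama-cpp backend boils down to, assuming the llama-index 0.8.x LlamaCPP API used in PATCH 1/9. The model URL is the DEFAULT_LLAMA_CPP_GGUF_MODEL constant from slack_bot/run.py; the max_new_tokens and context_window values below are illustrative placeholders rather than the repository's own defaults.

# Illustrative sketch only -- mirrors the LlamaCPP setup that
# LlamaIndexLlamaCPP._prep_llm() performs in the patches above.
from llama_index.llms import LlamaCPP
from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt

DEFAULT_LLAMA_CPP_GGUF_MODEL = (
    "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve"
    "/main/llama-2-13b-chat.Q6_K.gguf"
)

llm = LlamaCPP(
    model_url=DEFAULT_LLAMA_CPP_GGUF_MODEL,  # or model_path=... when --path is set
    model_path=None,
    temperature=0.1,
    max_new_tokens=256,          # illustrative; the bot passes its num_output setting
    context_window=4096,         # matches the --max-input-size default
    generate_kwargs={},
    model_kwargs={"n_gpu_layers": 0},  # --n-gpu-layers; -1 offloads all layers to GPU
    messages_to_prompt=messages_to_prompt,      # convert chat messages to Llama 2 format
    completion_to_prompt=completion_to_prompt,  # convert plain completions likewise
    verbose=True,
)

print(llm.complete("Who maintains the REG handbook?"))

End to end, the same backend is selected through the CLI flags added above, e.g. python slack_bot/run.py --model llama-index-llama-cpp --n-gpu-layers -1 --force-new-index.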