Skip to content

Commit

Permalink
Merge pull request #85 from alan-turing-institute/llama2-chat
Browse files Browse the repository at this point in the history
Add chat engine mode
  • Loading branch information
rchan26 authored Sep 14, 2023
2 parents 49fef87 + ee4106f commit c28f5dc
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 20 deletions.
23 changes: 21 additions & 2 deletions slack_bot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@
),
default=None,
)
# Engine mode flag: only honored by the llama-index backends; argparse
# validates against the two supported values, while the runtime default
# ("chat") is resolved later so an env var can also supply it.
parser.add_argument(
    "--mode",
    type=str,
    choices=["chat", "query"],
    default=None,
    help=(
        "Select which mode to use "
        "(ignored if not using llama-index-llama-cpp or llama-index-hf). "
        "Default is 'chat'."
    ),
)
parser.add_argument(
"--path",
"-p",
Expand Down Expand Up @@ -87,7 +98,7 @@
"-d",
type=pathlib.Path,
help="Location for data",
default=(pathlib.Path(__file__).parent.parent / "data").resolve(),
default=None,
)
parser.add_argument(
"--which-index",
Expand All @@ -100,7 +111,7 @@
"files in the data directory, 'handbook' will "
"only use 'handbook.csv' file."
),
default="all_data",
default=None,
choices=["all_data", "public", "handbook"],
)

Expand Down Expand Up @@ -142,6 +153,13 @@
# Fall back to indexing everything when no index was chosen.
if not which_index:
    which_index = "all_data"

# Resolve the engine mode with CLI > environment > default precedence:
# an explicit --mode wins, then the LLAMA_MODE env var, then "chat".
mode = args.mode or os.environ.get("LLAMA_MODE") or "chat"

# Initialise a new Slack bot with the requested model
try:
model = MODELS[model_name.lower()]
Expand Down Expand Up @@ -182,6 +200,7 @@
force_new_index=force_new_index,
data_dir=data_dir,
which_index=which_index,
mode=mode,
**model_args,
)

Expand Down
13 changes: 5 additions & 8 deletions slack_bot/slack_bot/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:
try:
# Extract event from payload
event = req.payload["event"]
sender_is_bot = "bot_id" in event

# Ignore messages from bots
if sender_is_bot:
logging.info(f"Ignoring an event triggered by a bot.")
if event.get("bot_id") is not None:
logging.info("Ignoring an event triggered by a bot.")
return None
if event.get("hidden") is not None:
logging.info("Ignoring hidden message.")
return None

# Extract user and message information
Expand All @@ -38,11 +40,6 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:
event_type = event["type"]
event_subtype = event.get("subtype", None)

# Ignore changes to messages.
if event_type == "message" and event_subtype == "message_changed":
logging.info(f"Ignoring a change to a message.")
return None

# Start processing the message
logging.info(f"Processing message '{message}' from user '{user_id}'.")

Expand Down
44 changes: 34 additions & 10 deletions slack_bot/slack_bot/models/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pathlib
import re
import sys
from typing import Any, List, Optional

import pandas as pd
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
data_dir: pathlib.Path,
which_index: str,
chunk_size: Optional[int] = None,
mode: str = "chat",
k: int = 3,
chunk_overlap_ratio: float = 0.1,
force_new_index: bool = False,
Expand All @@ -62,6 +64,9 @@ def __init__(
chunk_size : Optional[int], optional
Maximum size of chunks to use, by default None.
If None, this is computed as `ceil(max_input_size / k)`.
mode : Optional[str], optional
The type of engine to use when interacting with the data, options of "chat" or "query".
Default is "chat".
k : int, optional
`similarity_top_k` to use in query engine, by default 3
chunk_overlap_ratio : float, optional
Expand All @@ -79,6 +84,7 @@ def __init__(
self.num_output = num_output
if chunk_size is None:
chunk_size = math.ceil(max_input_size / k)
self.mode = mode
self.chunk_size = chunk_size
self.chunk_overlap_ratio = chunk_overlap_ratio
self.data_dir = data_dir
Expand Down Expand Up @@ -132,8 +138,17 @@ def __init__(
storage_context=storage_context, service_context=service_context
)

self.query_engine = self.index.as_query_engine(similarity_top_k=k)
logging.info("Done setting up Huggingface backend for query engine.")
if self.mode == "query":
self.query_engine = self.index.as_query_engine(similarity_top_k=k)
logging.info("Done setting up Huggingface backend for query engine.")
elif self.mode == "chat":
self.chat_engine = self.index.as_chat_engine(
chat_mode="context", similarity_top_k=k
)
logging.info("Done setting up Huggingface backend for chat engine.")
else:
logging.error("Mode must either be 'query' or 'chat'.")
sys.exit(1)

self.error_response_template = (
"Oh no! When I tried to get a response to your prompt, "
Expand Down Expand Up @@ -170,7 +185,7 @@ def _format_sources(response: RESPONSE_TYPE) -> str:

def _get_response(self, msg_in: str, user_id: str) -> str:
"""
Method to obtain a response from the query engine given
Method to obtain a response from the query/chat engine given
a message and a user id.
Parameters
Expand All @@ -186,13 +201,22 @@ def _get_response(self, msg_in: str, user_id: str) -> str:
String containing the response from the query engine.
"""
try:
query_response = self.query_engine.query(msg_in)
# concatenate the response with the resources that it used
response = (
query_response.response
+ "\n\n\n"
+ self._format_sources(query_response)
)
if self.mode == "query":
query_response = self.query_engine.query(msg_in)
# concatenate the response with the resources that it used
response = (
query_response.response
+ "\n\n\n"
+ self._format_sources(query_response)
)
elif self.mode == "chat":
chat_response = self.chat_engine.chat(msg_in)
# concatenate the response with the resources that it used
response = (
chat_response.response
+ "\n\n\n"
+ self._format_sources(chat_response)
)
except Exception as e: # ignore: broad-except
response = self.error_response_template.format(repr(e))
pattern = (
Expand Down

0 comments on commit c28f5dc

Please sign in to comment.