
Commit

adding gguf real time streaming method and examples
DARREN OBERST authored and DARREN OBERST committed May 11, 2024
1 parent 73f8e8d commit ff4fa6f
Showing 4 changed files with 249 additions and 0 deletions.
53 changes: 53 additions & 0 deletions examples/Models/gguf_streaming.py
@@ -0,0 +1,53 @@

""" This example illustrates how to use the stream method for GGUF models for fast streaming of inference,
especially for real-time chat interactions.
Please note that the stream method has been implemented for GGUF models starting in llmware-0.2.13. This will be
any model with GGUFGenerativeModel class, and generally includes models with names that end in "gguf".
See also the chat UI example in the UI examples folder.
We would recommend using a chat optimized model, and have included a representative list below.
"""


from llmware.models import ModelCatalog
from llmware.gguf_configs import GGUFConfigs

# sets an absolute output maximum for the GGUF engine - normally set by default at 256
GGUFConfigs().set_config("max_output_tokens", 1000)

chat_models = ["phi-3-gguf",
"llama-2-7b-chat-gguf",
"llama-3-instruct-bartowski-gguf",
"openhermes-mistral-7b-gguf",
"zephyr-7b-gguf",
"tiny-llama-chat-gguf"]

model_name = chat_models[0]

# maximum output can optionally be set to any number up to the "max_output_tokens" configured above
model = ModelCatalog().load_model(model_name, max_output=200)

text_out = ""

token_count = 0

prompt = "I am interested in gaining an understanding of the banking industry. What topics should I research?"

# model.stream returns a generator, so consume it token-by-token as follows

for streamed_token in model.stream(prompt):

    text_out += streamed_token
    if text_out.strip():
        print(streamed_token, end="")

    token_count += 1

# final output text and token count

print("\n\n***total text out***: ", text_out)
print("\n***total tokens***: ", token_count)
76 changes: 76 additions & 0 deletions examples/UI/gguf_streaming_chatbot.py
@@ -0,0 +1,76 @@

""" This example shows how to build a local chatbot prototype using llmware and Streamlit. The example shows
how to use several GGUF chat models in the LLMWare catalog, along with using the model.stream method which
provides a real time generator for displaying the bot response in real-time.
This is purposefully super simple script (but surprisingly fun) to provide the core of the recipe.
The Streamlit code below is derived from Streamlit tutorials available at:
https://docs.streamlit.io/develop/tutorials/llms/build-conversational-apps
If you are new to using Steamlit, to run this example:
1. pip3 install streamlit
2. to run, go to the command line: streamlit run "path/to/gguf_streaming_chatbot.py"
"""

import streamlit as st
from llmware.models import ModelCatalog
from llmware.gguf_configs import GGUFConfigs

GGUFConfigs().set_config("max_output_tokens", 500)


def simple_chat_ui_app (model_name):

    st.title(f"Simple Chat with {model_name}")

    model = ModelCatalog().load_model(model_name, temperature=0.3, sample=True, max_output=450)

    # initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # accept user input
    prompt = st.chat_input("Say something")
    if prompt:

        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):

            # note that the st.write_stream method consumes a generator - so pass model.stream(prompt) directly
            bot_response = st.write_stream(model.stream(prompt))

        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": bot_response})

    return 0


if __name__ == "__main__":

    # a few representative good chat models that can run locally
    # note: the first time a model is used, it will take a minute to download and cache locally

    chat_models = ["phi-3-gguf",
                   "llama-2-7b-chat-gguf",
                   "llama-3-instruct-bartowski-gguf",
                   "openhermes-mistral-7b-gguf",
                   "zephyr-7b-gguf",
                   "tiny-llama-chat-gguf"]

    model_name = chat_models[0]

    simple_chat_ui_app(model_name)
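
One thing to note about the chatbot above: each call to model.stream(prompt) sees only the latest user message, and st.session_state.messages is used only for display. A minimal sketch of folding recent turns into the prompt before streaming is shown below - the build_history_prompt helper and its simple "role: content" formatting are illustrative assumptions, not llmware APIs.

def build_history_prompt(messages, new_prompt, max_turns=4):

    """ Concatenate the last few chat turns ahead of the new user prompt. """

    history = messages[-2 * max_turns:]
    lines = [f"{m['role']}: {m['content']}" for m in history]
    lines.append(f"user: {new_prompt}")
    return "\n".join(lines)

# inside the chatbot, the assistant block would then call:
#   bot_response = st.write_stream(model.stream(build_history_prompt(st.session_state.messages, prompt)))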



9 changes: 9 additions & 0 deletions llmware/model_configs.py
@@ -554,6 +554,15 @@
"link": "https://huggingface.co/bartowski/Meta-Llama-3-8B-Instruct-GGUF",
"custom_model_files": [], "custom_model_repo": ""},

{"model_name": "tiny-llama-chat-gguf", "display_name": "tiny-llama-chat-gguf",
"model_family": "GGUFGenerativeModel", "model_category": "generative_local", "model_location": "llmware_repo",
"context_window": 2048, "instruction_following": False, "prompt_wrapper": "hf_chat",
"temperature": 0.3, "sample_default": True, "trailing_space": "",
"gguf_file": "tiny-llama-chat.gguf",
"gguf_repo": "llmware/bonchon",
"link": "https://huggingface.co/llmware/bonchon",
"custom_model_files": [], "custom_model_repo": ""},

# end - new llama-3 quantized models

# whisper-cpp models
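
The new catalog entry above is what ModelCatalog resolves when "tiny-llama-chat-gguf" is requested. A quick way to inspect the registered configuration at runtime is sketched below - this assumes the lookup_model_card helper returns the raw model card dictionary shown in the diff.

from llmware.models import ModelCatalog

model_card = ModelCatalog().lookup_model_card("tiny-llama-chat-gguf")

# print the key fields pulled from model_configs.py
for key in ["model_family", "context_window", "prompt_wrapper", "gguf_file", "gguf_repo"]:
    print(key, "->", model_card.get(key))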
111 changes: 111 additions & 0 deletions llmware/models.py
@@ -6031,6 +6031,117 @@ def function_call(self, context, function=None, params=None, get_logits=True,

        return output_response

    def stream(self, prompt, add_context=None, add_prompt_engineering=None, api_key=None, inference_dict=None,
               get_logits=False, disable_eos=False):

        """ Main method for streaming text generation. Returns a generator that yields one decoded token
        at a time for real-time display in a console or UI. """

        # first prepare the prompt

        if add_context:
            self.add_context = add_context

        if add_prompt_engineering:
            self.add_prompt_engineering = add_prompt_engineering

        # update default handling for no add_prompt_engineering

        if not self.add_prompt_engineering:
            if self.add_context:
                self.add_prompt_engineering = "default_with_context"
            else:
                self.add_prompt_engineering = "default_no_context"

        # end - update

        # show warning if function calling model
        if self.fc_supported:
            logging.warning("warning: this is a function calling model - using .stream may lead to unexpected "
                            "results. Recommended to use the .function_call method to ensure correct prompt "
                            "template packaging.")

        # start with clean logits_record and output_tokens for each call
        self.logits_record = []
        self.output_tokens = []

        if get_logits:
            self.get_logits = get_logits

        if inference_dict:

            if "temperature" in inference_dict:
                self.temperature = inference_dict["temperature"]

            if "max_tokens" in inference_dict:
                self.target_requested_output_tokens = inference_dict["max_tokens"]

        if self.add_prompt_engineering:
            prompt_enriched = self.prompt_engineer(prompt, self.add_context, inference_dict=inference_dict)
            prompt_final = prompt_enriched

            # most models perform better with no trailing space or line-break at the end of prompt
            # -- in most cases, the trailing space will be ""
            # -- yi model prefers a trailing "\n"
            # -- keep as parameterized option to maximize generation performance
            # -- can be passed either thru model_card or model config from HF

            prompt = prompt_final + self.trailing_space

        # generation starts here
        completion_tokens = [] if len(prompt) > 0 else [self.token_bos()]

        prompt_tokens = (
            (
                self.tokenize(prompt.encode("utf-8"), special=True)
                if prompt != ""
                else [self.token_bos()]
            )
            if isinstance(prompt, str)
            else prompt
        )

        # confirm that input is smaller than context_window
        input_len = len(prompt_tokens)
        context_window = self.n_ctx()

        if input_len > context_window:
            logging.warning("update: GGUFGenerativeModel - input is too long for model context window - truncating")
            min_output_len = 10
            prompt_tokens = prompt_tokens[0:context_window-min_output_len]
            input_len = len(prompt_tokens)

        text = b""

        for token in self.generate(prompt_tokens):

            completion_tokens.append(token)

            if not disable_eos:
                if token == self._token_eos:
                    break

            if len(completion_tokens) > self.max_output_len:
                break

            # stop if combined input + output at context window size
            if (input_len + len(completion_tokens)) >= context_window:
                break

            # decode the new token and accumulate the full output so the final return value is complete
            token_bytes = self.detokenize([token])
            text += token_bytes

            yield token_bytes.decode('utf-8', errors='ignore')

        text_str = text.decode("utf-8", errors="ignore")

        return text_str


class WhisperCPPModel:

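
Since stream() is a generator that finishes with "return text_str", the complete output is also attached to the StopIteration raised when the generator is exhausted. A caller that wants both token-by-token streaming and the final string could retrieve it as sketched below - the consume_stream helper and the sample prompt are illustrative, not llmware APIs.

from llmware.models import ModelCatalog

def consume_stream(model, prompt):

    """ Stream tokens to the console and also capture the generator's return value. """

    gen = model.stream(prompt)
    while True:
        try:
            print(next(gen), end="")
        except StopIteration as stop:
            # the value passed to 'return' inside the generator rides on StopIteration
            return stop.value

model = ModelCatalog().load_model("phi-3-gguf", max_output=200)
full_text = consume_stream(model, "What are the key drivers of net interest margin for a bank?")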
