Move to LangChain APIs with llama3-8b (#13)
* Move to LangChain APIs with llama3-8b
* Add comments
1 parent 60d1049; commit d660d3f
Showing 5 changed files with 162 additions and 78 deletions.
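In effect, this commit swaps direct llama-cpp-python chat-completion calls for LangChain's wrappers around the same GGUF model. A minimal sketch of the new flow, assembled from the pieces visible in the diff below (an illustration under those imports and defaults, not the exact committed code):

    from huggingface_hub import hf_hub_download
    from langchain_community.llms.llamacpp import LlamaCpp
    from langchain.chains.conversation.base import ConversationChain
    from langchain.memory.buffer import ConversationBufferMemory
    from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

    # Download the GGUF weights once; LlamaCpp then loads them from the local path
    model_path = hf_hub_download(repo_id='PawanKrd/Meta-Llama-3-8B-Instruct-GGUF',
                                 filename='llama-3-8b-instruct.Q3_K_M.gguf')
    llm = LlamaCpp(model_path=model_path, n_ctx=2048, n_gpu_layers=-1)

    # ConversationChain expects a prompt with {history} and {input} variables
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{history}<|eot_id|>\n\n{input}<|eot_id|>")
    ])
    chain = ConversationChain(llm=llm, prompt=prompt, memory=ConversationBufferMemory())
    print(chain.predict(input="Hello!"))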
@@ -0,0 +1,12 @@
import importlib.metadata
import argparse

def parse_arguments():
    """ parse input arguments """
    version = importlib.metadata.version('Llama4U')
    parser = argparse.ArgumentParser(description=f'Llama4U v{version}')
    parser.add_argument('-r', '--repo_id', type=str, required=False, help='Repository ID')
    parser.add_argument('-f', '--filename', type=str, required=False, help='Filename')
    parser.add_argument('-q', '--query', type=str, required=False, help='Single Query')
    parser.add_argument('-v', '--verbose', type=int, required=False, help='Enable verbose output')
    return parser.parse_args()
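A quick usage sketch for the new flags (note that importlib.metadata.version('Llama4U') only resolves once the package itself is installed, e.g. via pip install -e .):

    from input.input import parse_arguments

    args = parse_arguments()  # e.g. invoked as: <prog> -r <repo_id> -q "Hi" -v 1
    print(args.repo_id, args.filename, args.query, args.verbose)
    # flags that were not supplied come back as None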
@@ -1,111 +1,158 @@
""" Llama4U """ | ||
import importlib.metadata | ||
import sys | ||
import argparse | ||
from math import exp | ||
from statistics import median | ||
from os import devnull | ||
from contextlib import contextmanager,redirect_stderr | ||
from termcolor import colored | ||
from huggingface_hub import hf_hub_download | ||
from llama_cpp import Llama | ||
import llama_cpp | ||
from langchain_community.llms.llamacpp import LlamaCpp | ||
from langchain.chains.conversation.base import ConversationChain | ||
from langchain.memory.buffer import ConversationBufferMemory | ||
from langchain_core.prompts import ( | ||
ChatPromptTemplate, HumanMessagePromptTemplate | ||
) | ||
from input.input import parse_arguments | ||
|
||
LLAMA4U_STR = 'Llama4U' | ||
|
||
class Llama4U(): | ||
""" Llama4U """ | ||
|
||
# Model config parameters | ||
model_kwargs = { | ||
"n_gpu_layers": -1, | ||
"logits_all": True, | ||
'split_mode':llama_cpp.LLAMA_SPLIT_MODE_LAYER, | ||
'vocab_only': False, | ||
'use_mmap': True, | ||
'use_mlock': False, | ||
'kv_overrides': None, | ||
'seed': llama_cpp.LLAMA_DEFAULT_SEED, | ||
'n_ctx': 2048, | ||
'n_batch': 512, | ||
'rope_scaling_type': llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, | ||
'pooling_type': llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, | ||
'rope_freq_base': 0.0, | ||
'rope_freq_scale': 0.0, | ||
'yarn_ext_factor':-1.0, | ||
'yarn_attn_factor': 1.0, | ||
'yarn_beta_fast': 32.0, | ||
'yarn_beta_slow': 1.0, | ||
'yarn_orig_ctx': 0, | ||
'embedding': False, | ||
'offload_kqv': True, | ||
'flash_attn': False, | ||
'last_n_tokens_size': 64, | ||
'lora_scale': 1.0, | ||
'numa': False, | ||
'chat_format': 'llama-2', | ||
'chat_handler': None, | ||
'verbose':True, | ||
} | ||
|
||
# Chat config parameters | ||
chat_kwargs = { | ||
'temperature': 0.2, | ||
'top_p': 0.95, | ||
'top_k': 40, | ||
'min_p': 0.05, | ||
'typical_p': 1.0, | ||
'max_tokens': None, | ||
'echo': False, | ||
'presence_penalty':0.0, | ||
'frequency_penalty':0.0, | ||
'repeat_penalty':1.1, | ||
'tfs_z':1.0, | ||
'mirostat_mode': 0, | ||
'mirostat_tau': 5.0, | ||
'mirostat_eta': 0.1, | ||
'logprobs': True, | ||
#'top_logprobs': 1, | ||
} | ||
|
||
# Define the human message template | ||
human_template = HumanMessagePromptTemplate.from_template( | ||
"{history}<|eot_id|>\n\n{input}<|eot_id|>" | ||
) | ||
|
||
# Combine the templates into a chat prompt template | ||
chat_template = ChatPromptTemplate.from_messages([human_template]) | ||
|
||
    def __init__(self,
                 hf_repo_id,
                 model_filename
                 ):
        if hf_repo_id is None:
            hf_repo_id = "PawanKrd/Meta-Llama-3-8B-Instruct-GGUF"
        self.hf_repo_id = 'PawanKrd/Meta-Llama-3-8B-Instruct-GGUF'
        if model_filename is None:
            model_filename = "llama-3-8b-instruct.Q3_K_M.gguf"
        model_path = hf_hub_download(repo_id=hf_repo_id, filename=model_filename)

        # Instantiate model from downloaded file
        self.llm = Llama(
            n_gpu_layers=-1,
            max_new_tokens=2048,
            model_path=model_path,
            logits_all=True,
        model_filename = 'llama-3-8b-instruct.Q3_K_M.gguf'
        self.model_path = hf_hub_download(repo_id=self.hf_repo_id, filename=model_filename)

        # Initialize LLM
        self.llm = LlamaCpp(
            model_path=self.model_path,
            **self.model_kwargs,
        )

    def start_chat_session(self):
        """ Chat session loop """
        my_messages = [
            {"role": "system",
             "content": "A chat between a curious user and an artificial intelligence assistant. \
                The assistant gives helpful, and polite answers to the user's questions."},
        ]
        # Initialize Conversation "Chain"
        # using our LLM, chat template and config params
        self.conversation_chain = ConversationChain(
            llm=self.llm,
            prompt=self.chat_template,
            memory=ConversationBufferMemory(),
            llm_kwargs=self.chat_kwargs,
        )

    def process_user_input(self):
        """ Get input from stdin """
        print(colored('>>> ', 'yellow'), end="")
        user_prompt = input()
        if user_prompt.lower() in ["exit", "quit", "bye"]:
            print(colored(f'{LLAMA4U_STR}: =====', 'yellow'))
            print("Chat session ended. Goodbye!")
            sys.exit(0)
        return user_prompt

    def start_chat_session(self, query=""):
        """ Chat session loop """
        my_messages = ""
        stop_next_iter = False
        for _ in range(50):
            if stop_next_iter:
                break

            # User's turn
            print(colored('You: =====', 'yellow'))
            user_prompt = input()
            if user_prompt.lower() in ["exit", "quit", "bye"]:
                print(colored('Assistant(Median Prob:1.0): =====', 'yellow'))
                print("Chat session ended. Goodbye!")
                break
            my_messages.append({"role": "user", "content": user_prompt})
            if not query:
                my_messages = self.process_user_input()
            else:
                my_messages = query
                stop_next_iter = True

            # AI's turn
            response = self.llm.create_chat_completion(messages=my_messages,
                                                       logprobs=True,
                                                       top_logprobs=1,
                                                       )
            logprobs = response["choices"][0]["logprobs"]["token_logprobs"]
            # Convert logprobs to probabilities
            probabilities = [exp(logprob) for logprob in logprobs]
            print(colored(f'Assistant(Median Prob:{median(probabilities)}): =====', 'yellow'))
            print(response["choices"][0]["message"]["content"])

    def single_query(self, query):
        """ Single Query Mode """
        response = self.llm.create_chat_completion([{"role": "user", "content": query}],
                                                   logprobs=True,
                                                   top_logprobs=1,
                                                   )
        if response:
            logprobs = response["choices"][0]["logprobs"]["token_logprobs"]
            # Convert logprobs to probabilities
            probabilities = [exp(logprob) for logprob in logprobs]
            print(f'Assistant(Median Prob:{median(probabilities)}): =====')
            print(response["choices"][0]["message"]["content"])
            sys.exit(0)
        else:
            print("Query failed")
            sys.exit(1)
        response = self.conversation_chain.predict(input=my_messages)
        print(response.strip())

@contextmanager
def suppress_stderr():
    """A context manager that redirects stderr to devnull"""
    with open(devnull, 'w', encoding='utf-8') as fnull:
        with redirect_stderr(fnull) as err:
            yield err

def parse_arguments():
    """ parse input arguments """
    version = importlib.metadata.version('Llama4U')
    parser = argparse.ArgumentParser(description=f'Llama4U v{version}')
    parser.add_argument('-r', '--repo_id', type=str, required=False, help='Repository ID')
    parser.add_argument('-f', '--filename', type=str, required=False, help='Filename')
    parser.add_argument('-q', '--query', type=str, required=False, help='Single Query')
    return parser.parse_args()
def suppress_stderr(verbose):
    """A context manager that redirects stderr to devnull based on verbose selection """
    if verbose <= 0:
        with open(devnull, 'w', encoding='utf-8') as fnull:
            with redirect_stderr(fnull) as err:
                yield err
    else:
        yield ()

def main():
    """ Pip Package entrypoint """
    args = parse_arguments()
    repo_id = args.repo_id
    filename = args.filename

    with suppress_stderr():
        llama4u = Llama4U(repo_id, filename)
    if args.verbose:
        verbose = args.verbose
    else:
        verbose = 0

    if args.query:
        llama4u.single_query(args.query)
    else:
        llama4u.start_chat_session()
    with suppress_stderr(verbose):
        llama4u = Llama4U(args.repo_id, args.filename)
        llama4u.start_chat_session(args.query)

if __name__ == '__main__':
    main()
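ConversationChain fills {history} and {input} itself, so the chat template defined in this file renders roughly as follows on the first turn, when the history buffer is still empty (a small illustrative check, not part of the commit):

    from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

    human = HumanMessagePromptTemplate.from_template("{history}<|eot_id|>\n\n{input}<|eot_id|>")
    prompt = ChatPromptTemplate.from_messages([human])
    print(prompt.format(history="", input="Hello!"))
    # roughly: Human: <|eot_id|>\n\nHello!<|eot_id|>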
@@ -0,0 +1,22 @@
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index.core.readers import SimpleDirectoryReader
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import Settings

class DocReader():
    """ Build a vector index over the documents in a local directory """

    def __init__(self, main_model, st_model='mixedbread-ai/mxbai-embed-large-v1', directory_path='/mnt/c/Users/viraj/Documents/ai_db/'):
        # Embedding model used to vectorize the documents
        self.embed_model = HuggingFaceEmbeddings(model_name=st_model)

        # Load every document found in the target directory
        self.directory_path = directory_path
        reader = SimpleDirectoryReader(directory_path)
        docs = reader.load_data()

        # Register the main model as the LLM and index the documents
        Settings.llm = main_model
        Settings.context_window = 8000
        service_context = ServiceContext.from_defaults(embed_model=self.embed_model)
        self.index = VectorStoreIndex.from_documents(docs, service_context=service_context)

    def get_query_engine(self, model):
        """ Return a query engine backed by the index """
        query_engine = self.index.as_query_engine(model)
        return query_engine
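A hypothetical usage sketch for the new reader (the directory path and query are illustrative, and the class is not yet wired into the chat loop in this commit; depending on the llama_index version, a LangChain LLM may need the LangChainLLM adapter before being assigned to Settings.llm):

    # `llm` is assumed to be an LLM object llama_index accepts, e.g. the model
    # built in Llama4U.__init__ (possibly wrapped for llama_index compatibility)
    reader = DocReader(main_model=llm, directory_path='/path/to/docs')
    engine = reader.get_query_engine(llm)
    print(engine.query("Summarize these documents."))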
File renamed without changes.