Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Clipboard Interactions and User Input Suggestions #5

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
*.backup
*.history
*.log
*.json
*.pyc
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ requests==2.31.0
rich==13.7.0
urllib3==2.2.1
prompt-toolkit==3.0.43
pyperclip==1.8.2
Empty file added rich_chat/__init__.py
Empty file.
114 changes: 114 additions & 0 deletions rich_chat/chat_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
import os
import re
from pathlib import Path
from typing import Dict, List

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.clipboard.pyperclip import PyperclipClipboard
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings


class ChatHistory:
    """Persistent chat transcript plus a prompt-toolkit input session.

    Messages are dicts of the form
    ``{"role": "user" | "assistant" | "system", "content": "<text>"}``
    and are cached as JSON under ``~/.cache/rich-chat/<session_name>.json``.
    Line-editor history lives next to it in ``<session_name>.history``.
    """

    def __init__(self, session_name: str, system_message: str = None):
        # Cache directory for both the JSON transcript and the prompt history.
        home = os.environ.get("HOME", ".")  # fall back to cwd if HOME is unset
        cache = Path(f"{home}/.cache/rich-chat")
        cache.mkdir(parents=True, exist_ok=True)  # ensure the directory exists

        # JSON transcript of the conversation.
        self.file_path = cache / f"{session_name}.json"

        # Line-editor history for the interactive prompt (str() for older
        # prompt_toolkit versions that expect a plain path string).
        file_history_path = cache / f"{session_name}.history"
        self.session = PromptSession(history=FileHistory(str(file_history_path)))
        self.auto_suggest = AutoSuggestFromHistory()

        # Ordered chat messages, optionally seeded with a system message.
        self.messages: List[Dict[str, str]] = []
        if system_message is not None:
            self.messages.append({"role": "system", "content": system_message})

    @property
    def key_bindings(self) -> KeyBindings:
        """Key bindings for copying messages into the system clipboard.

        ``c-s a <n>`` copies the n-th most recent message in full;
        ``c-s s <n>`` copies only its triple-backtick code blocks.
        """
        kb = KeyBindings()
        clipboard = PyperclipClipboard()

        # Bind digits 1-9 so that digit n maps to self.messages[-n].
        # BUGFIX: the previous range(9) bound digits 0-8, and digit 0 made
        # messages[-0] == messages[0] select the FIRST message, not the last.
        for i in range(1, 10):

            @kb.add("c-s", "a", str(i))
            def _(event):
                """Copy the entire selected message to the system clipboard."""
                if not self.messages:
                    return
                # The pressed digit is the last key of the sequence.
                key = int(event.key_sequence[-1].key)
                try:
                    content = self.messages[-key]["content"].strip()
                except IndexError:
                    return  # fewer messages than the digit pressed
                # NOTE(review): this doesn't auto-update; the toolbar would
                # need a re-render to actually show the notice.
                self.bottom_toolbar = "Copied last message into clipboard!"
                clipboard.set_text(content)

            @kb.add("c-s", "s", str(i))
            def _(event):
                """Copy only code snippets from the selected message."""
                if not self.messages:
                    return
                key = int(event.key_sequence[-1].key)
                try:
                    content = self.messages[-key]["content"].strip()
                except IndexError:
                    return  # fewer messages than the digit pressed
                self.bottom_toolbar = (
                    "Copied code blocks from last message into clipboard!"
                )
                code_snippets = re.findall(r"```(.*?)```", content, re.DOTALL)
                clipboard.set_text("\n\n".join(code_snippets))

        return kb

    def load(self) -> List[Dict[str, str]]:
        """Load messages from the cache file, creating it when missing/corrupt.

        Always returns the message list (the previous version returned None
        on the error path despite its annotation).
        """
        try:
            with open(self.file_path, "r") as chat_session:
                self.messages = json.load(chat_session)
        except (FileNotFoundError, json.JSONDecodeError):
            self.save()  # create the missing (or replace the corrupt) file
            print(f"ChatHistoryLoad: Created new cache: {self.file_path}")
        return self.messages

    def save(self) -> None:
        """Write the messages to the cache file as pretty-printed JSON."""
        try:
            with open(self.file_path, "w") as chat_session:
                json.dump(self.messages, chat_session, indent=2)
        except TypeError as e:
            # A non-serializable message slipped in; report rather than crash.
            print(f"ChatHistoryWrite: {e}")

    def prompt(self) -> str:
        """Prompt the user for multiline input and return it stripped."""
        return self.session.prompt(
            "Prompt: (⌥ + ⏎) | Copy: ((⌘ + s) (a|s) (.[1-9])) | Exit: (⌘ + c):\n",
            key_bindings=self.key_bindings,
            auto_suggest=self.auto_suggest,
            multiline=True,
        ).strip()

    def append(self, message: Dict[str, str]) -> None:
        """Append a message to the end of the history."""
        self.messages.append(message)

    def insert(self, index: int, element: Dict[str, str]) -> None:
        """Insert a message at the given position."""
        self.messages.insert(index, element)

    def pop(self, index: int) -> Dict[str, str]:
        """Remove and return the message at the given position."""
        return self.messages.pop(index)

    def replace(self, index: int, content: str) -> None:
        """Replace the content of the message at the given position."""
        try:
            self.messages[index]["content"] = content
        except (IndexError, KeyError) as e:
            print(f"ChatHistoryReplace: Failed to substitute chat message: {e}")

    def reset(self) -> None:
        """Discard all messages."""
        self.messages = []
86 changes: 72 additions & 14 deletions source/rich-chat.py → rich_chat/rich_chat_cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#!/usr/bin/env python

import argparse
import json
import os

import requests
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from prompt_toolkit import prompt as input
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel

from rich_chat.chat_history import ChatHistory


def remove_lines_console(num_lines):
for _ in range(num_lines):
Expand All @@ -28,51 +31,68 @@ def estimate_lines(text):
return line_count


def handle_console_input(session: PromptSession) -> str:
return session.prompt("(Prompt: ⌥ + ⏎) | (Exit: ⌘ + c):\n", multiline=True).strip()


class conchat:
def __init__(
    self,
    server_addr,
    min_p: float,
    repeat_penalty: float,
    seed: int,
    top_k=10,
    top_p=0.95,
    temperature=0.12,
    n_predict=-1,
    stream: bool = True,
    cache_prompt: bool = True,
    model_frame_color: str = "red",
    chat_history: ChatHistory = None,
) -> None:
    """Console chat client talking to an OpenAI-style chat-completion server.

    server_addr: base URL of the server the requests are POSTed to.
    chat_history: ChatHistory handling prompting, persistence and replay.
    The remaining parameters are sampling settings forwarded verbatim in
    the request payload.
    """
    self.model_frame_color = model_frame_color
    self.serveraddr = server_addr
    self.topk = top_k
    self.top_p = top_p
    self.seed = seed
    self.min_p = min_p
    self.repeat_penalty = repeat_penalty
    self.temperature = temperature
    self.n_predict = n_predict
    self.stream = stream
    self.cache_prompt = cache_prompt
    self.headers = {"Content-Type": "application/json"}
    # Single assignment (a dead `self.chat_history = []` used to precede
    # this and was immediately overwritten).
    self.chat_history = chat_history
    self.model_name = ""

    self.console = Console()

    # Input history is owned by ChatHistory's own PromptSession now; the
    # old module-local PromptSession(".rich-chat.history") was dead code.
    self._render_messages_once_on_start()

def _render_messages_once_on_start(self) -> None:
    """Replay the persisted conversation as panels when the client starts."""
    self.chat_history.load()
    for msg in self.chat_history.messages:
        role = msg["role"]
        # The user role is displayed as "HUMAN"; other roles keep their name.
        if role == "user":
            role = "HUMAN"
        panel = Panel(
            Markdown(msg["content"]),
            title=role.upper(),
            title_align="left",
        )
        self.console.print(panel)

def chat_generator(self, prompt):
endpoint = self.serveraddr + "/v1/chat/completions"
self.chat_history.append({"role": "user", "content": prompt})

payload = {
"messages": self.chat_history,
"messages": self.chat_history.messages,
"temperature": self.temperature,
"top_k": self.topk,
"top_p": self.top_p,
"n_predict": self.n_predict,
"stream": self.stream,
"cache_prompt": self.cache_prompt,
"seed": self.seed,
"repeat_penalty": self.repeat_penalty,
"min_p": self.min_p,
}
try:
response = requests.post(
Expand Down Expand Up @@ -150,7 +170,7 @@ def chat(self):
self.model_name = self.get_model_name()
while True:
try:
user_m = handle_console_input(self.session)
user_m = self.chat_history.prompt()
remove_lines_console(estimate_lines(text=user_m))
self.console.print(
Panel(Markdown(user_m), title="HUMAN", title_align="left")
Expand All @@ -160,6 +180,7 @@ def chat(self):
# NOTE: Ctrl + c (keyboard) or Ctrl + d (eof) to exit
# Adding EOFError prevents an exception and gracefully exits.
except (KeyboardInterrupt, EOFError):
self.chat_history.save()
exit()


Expand Down Expand Up @@ -194,17 +215,54 @@ def main():
type=int,
help="The number defines how many tokens to be predict by the model. Default: infinity until [stop] token.",
)
parser.add_argument(
    "--minp",
    type=float,
    # BUGFIX: was 0.5, contradicting the help text below; 0.05 is the
    # conventional llama.cpp min-p default the help already documents.
    default=0.05,
    help="The minimum probability for a token to be considered, relative to the probability of the most likely token (default: 0.05).",
)
parser.add_argument(
"--repeat-penalty",
type=float,
default=1.1,
help="Control the repetition of token sequences in the generated text (default: 1.1).",
)
parser.add_argument(
"--seed",
type=int,
default=-1,
help="Set the random number generator (RNG) seed (default: -1, -1 = random seed).",
)
parser.add_argument(
"-m",
"--system-message",
type=str,
default=None, # empty by default; avoiding assumptions.
help="The system message used to orientate the model, if any.",
)
parser.add_argument(
"-n",
"--session-name",
type=str,
default="rich-chat",
help="The name of the chat session. Default is 'rich-chat'.",
)

args = parser.parse_args()
# print(args)
# print(f"ARG of server is {args.server}")
# print(f"argument of bot color is {args.model_frame_color}")

# Defaults to Path(".") if args.chat_history is ""
chat_history = ChatHistory(args.session_name, args.system_message)

chat = conchat(
server_addr=args.server,
top_k=args.topk,
top_p=args.topp,
temperature=args.temperature,
model_frame_color=args.model_frame_color,
min_p=args.minp,
seed=args.seed,
repeat_penalty=args.repeat_penalty,
chat_history=chat_history,
)
chat.chat()

Expand Down