Skip to content

Commit

Permalink
Merge pull request #122 from Undertone0809/feat-v1.10.0
Browse files Browse the repository at this point in the history
test: optimize tests
  • Loading branch information
Undertone0809 authored Nov 7, 2023
2 parents 3e5cffa + 8a80e9c commit da34d31
Show file tree
Hide file tree
Showing 27 changed files with 630 additions and 286 deletions.
14 changes: 8 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ OS := $(shell python -c "import sys; print(sys.platform)")
ifeq ($(OS),win32)
PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
TEST_PROD_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests
else
PYTHONPATH := `pwd`
TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate ./tests/test_chat.py ./tests/output_formatter
TEST_PROD_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests
endif

#* Poetry
.PHONY: poetry-download
poetry-download:
pip install poetry

#* Installation
.PHONY: install
Expand All @@ -33,14 +31,18 @@ polish-codestyle:
.PHONY: formatting
formatting: polish-codestyle



#* Linting
.PHONY: test
test:
$(TEST_COMMAND)
poetry run coverage-badge -o docs/images/coverage.svg -f

#* Linting
.PHONY: test-prod
test-prod:
$(TEST_PROD_COMMAND)
poetry run coverage-badge -o docs/images/coverage.svg -f

.PHONY: check-codestyle
check-codestyle:
poetry run isort --diff --check-only --settings-path pyproject.toml promptulate tests example
Expand Down
2 changes: 1 addition & 1 deletion docs/_coverpage.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

<div class="title">Promptulate</div>

> All you need is an elegant LLM development framework.
> All you need is an elegant LLM Agent development framework.

<p align="center">
Expand Down
11 changes: 5 additions & 6 deletions docs/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,10 @@ if __name__ == "__main__":
- Python >= 3.8
- make

> make 不是必须的,但是利用 makefile 的能力轻松集成运行 test、lint 等模块。
> 本项目使用 make 进行项目配套设施的构建,通过 makefile 的能力轻松集成运行 test、lint 等模块,请确保你的电脑已经安装了 make。
>
> [how to install and use make in windows?](https://stackoverflow.com/questions/32127524/how-to-install-and-use-make-in-windows)

运行以下命令:

Expand All @@ -208,12 +211,8 @@ pip install poetry
make install
```

如果你没有安装 make,也可以使用如下方式安装:

```shell
pip install poetry
poetry install
```
本项目配备了代码风格检查工具,如果你想提交 pr,则需要在 commit 之前运行 `make polish-codestyle` 进行代码规范格式化,并且运行 `make lint` 通过语法与单元测试的检查。

## 更多

Expand Down
6 changes: 3 additions & 3 deletions docs/images/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions example/llm/custom_conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from promptulate.llms import ChatOpenAI
from promptulate.schema import AssistantMessage, MessageSet, SystemMessage, UserMessage


def main():
    """Build a multi-message conversation with MessageSet and print the reply.

    Demonstrates passing a system + user message pair to ChatOpenAI via
    ``llm.predict`` and reading the assistant's answer content.
    """
    messages = MessageSet(
        messages=[
            # Fixed typo: "assitant" -> "assistant" so the system prompt is well-formed.
            SystemMessage(content="You are a helpful assistant"),
            UserMessage(content="Hello?"),
        ]
    )

    llm = ChatOpenAI()
    # predict() returns an AssistantMessage; its text lives in .content.
    answer: AssistantMessage = llm.predict(messages)
    print(answer.content)


main()
12 changes: 12 additions & 0 deletions example/llm/llm_private_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""This example will show how to use a specified key in OpenAI model."""
from promptulate.llms import ChatOpenAI


def main():
    """Create a ChatOpenAI model bound to an explicit API key and print a reply."""
    model = ChatOpenAI()
    # Attach a per-instance key instead of relying on the environment variable.
    model.set_private_api_key("your key here")
    reply = model("hello")
    print(reply)


# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()
371 changes: 361 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion promptulate/client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def chat():
if args.proxy_mode:
set_proxy_mode(args.proxy_mode)

print_text(f"Hi there, here is promptulate chat terminal.", "pink")
print_text("Hi there, here is promptulate chat terminal.", "pink")

terminal_mode = questionary.select(
"Choose a chat terminal:",
Expand Down
8 changes: 1 addition & 7 deletions promptulate/llms/erniebot/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@

from promptulate.config import Config
from promptulate.llms import BaseLLM
from promptulate.schema import (
AssistantMessage,
BaseMessage,
LLMType,
MessageSet,
UserMessage,
)
from promptulate.schema import AssistantMessage, LLMType, MessageSet, UserMessage
from promptulate.tips import LLMError
from promptulate.utils import get_logger

Expand Down
5 changes: 4 additions & 1 deletion promptulate/tools/arxiv/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Config:
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
try:
import arxiv
import arxiv # noqa
except ImportError:
raise ValueError(
"Could not import arxiv python package. "
Expand All @@ -83,10 +83,13 @@ def _query(
"""
if not keyword:
keyword = ""

if not id_list:
id_list = []

if not num_results:
num_results = self.max_num_of_result

if isinstance(id_list, str):
id_list = [id_list]

Expand Down
2 changes: 1 addition & 1 deletion promptulate/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Tool(ABC):
description: str
"""Tool description"""

def __init__(self, **kwargs):
def __init__(self, *args, **kwargs):
self.check_params()
if "hooks" in kwargs and kwargs["hooks"]:
for hook in kwargs["hooks"]:
Expand Down
2 changes: 1 addition & 1 deletion promptulate/tools/duckduckgo/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Config:
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
try:
from duckduckgo_search import DDGS
from duckduckgo_search import DDGS # noqa
except ImportError:
raise ValueError(
"Could not import duckduckgo-search python package. "
Expand Down
21 changes: 11 additions & 10 deletions promptulate/tools/human_feedback/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from promptulate.utils.color_print import print_text


def _print_func(content) -> None:
print_text(f"[Agent ask] {content}", "blue")
def _print_func(llm_question: str) -> None:
    """Default output handler: print the LLM's question when HumanFeedBackTool runs."""
    print_text(f"[Agent ask] {llm_question}", "blue")


class HumanFeedBackTool(Tool):
Expand All @@ -14,22 +15,22 @@ class HumanFeedBackTool(Tool):
name: str = "human_feedback"
description: str = (
"Human feedback tools are used to collect human feedback information."
"Please only use this tool in situations where relevant contextual information is lacking or reasoning cannot "
"continue."
"Please enter the content you wish for human feedback and interaction, but do not ask for knowledge or let "
"humans reason. "
"Please only use this tool in situations where relevant contextual information"
"is lacking or reasoning cannot continue. Please enter the content you wish for "
"human feedback and interaction, but do not ask for knowledge or let humans reason."
)

def __init__(
self,
prompt_func: Callable[[str], None] = _print_func,
output_func: Callable[[str], None] = _print_func,
input_func: Callable = input,
*args,
**kwargs,
):
super().__init__(**kwargs)
self.prompt_func = prompt_func
super().__init__(*args, **kwargs)
self.output_func = output_func
self.input_func = input_func

def _run(self, content: str, *args, **kwargs) -> str:
self.prompt_func(content)
self.output_func(content)
return self.input_func()
2 changes: 1 addition & 1 deletion promptulate/tools/iot_swith_mqtt/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(

def _run(self, question: str, *args, **kwargs) -> str:
try:
import paho.mqtt.client as mqtt
import paho.mqtt.client as mqtt # noqa
except ImportError:
raise ImportError(
"Could not import paho python package. "
Expand Down
2 changes: 1 addition & 1 deletion promptulate/tools/paper/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_advice():
if paper_info:
paper_info = paper_info[0]

except NetWorkError as e:
except NetWorkError:
paper_info = self.arxiv_apiwrapper.query(
keyword=query, num_results=1, specified_fields=["title", "summary"]
)
Expand Down
2 changes: 1 addition & 1 deletion promptulate/tools/semantic_scholar/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_detail():
self.current_result = response.json()["matches"]

if len(self.current_result) == 0:
logger.debug(f"[pne] semantic scholar return none")
logger.debug("[pne] semantic scholar return none")
return []

for item in self.current_result:
Expand Down
4 changes: 2 additions & 2 deletions promptulate/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import os

from promptulate import utils
from promptulate.utils.core_utils import get_default_storage_path

logger = logging.getLogger(__name__)

Expand All @@ -31,7 +31,7 @@ def get_logger():


def get_default_log_path():
return utils.get_default_storage_path("log")
return get_default_storage_path("log")


def get_log_name() -> str:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pytest-cov = "^3.0.0"
coverage = "^6.1.2"
pre-commit = "^3.5.0"
coverage-badge = "^1.1.0"
langchain = "^0.0.324"

[tool.poetry.scripts]
pne-chat = "promptulate.client.chat:main"
Expand Down
78 changes: 39 additions & 39 deletions tests/framework/test_conversation.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
from unittest import TestCase
# from unittest import TestCase

from promptulate.frameworks.conversation import Conversation
from promptulate.memory import FileChatMemory
from promptulate.utils.logger import enable_log, get_logger
# from promptulate.frameworks.conversation import Conversation
# from promptulate.memory import FileChatMemory
# from promptulate.utils.logger import enable_log, get_logger

enable_log()
logger = get_logger()
# enable_log()
# logger = get_logger()


class TestConversation(TestCase):
def test_predict(self):
conversation = Conversation()
result = conversation.run("什么是大语言模型")
self.assertIsNotNone(result)
self.assertTrue("大语言模型" in result)
# class TestConversation(TestCase):
# def test_predict(self):
# conversation = Conversation()
# result = conversation.run("什么是大语言模型")
# self.assertIsNotNone(result)
# self.assertTrue("大语言模型" in result)

def test_predict_with_stop(self):
conversation = Conversation()
prompt = """
Please strictly output the following content.
```
[start] This is a test [end]
```
"""
result = conversation.run(prompt, stop=["test"])
self.assertTrue("test [end]" not in result)
self.assertIsNotNone(result)
# def test_predict_with_stop(self):
# conversation = Conversation()
# prompt = """
# Please strictly output the following content.
# ```
# [start] This is a test [end]
# ```
# """
# result = conversation.run(prompt, stop=["test"])
# self.assertTrue("test [end]" not in result)
# self.assertIsNotNone(result)

def test_memory_with_buffer(self):
conversation = Conversation()
prompt = """给我想5个公司的名字"""
conversation.run(prompt)
conversation_id = conversation.conversation_id
new_conversation = Conversation(conversation_id=conversation_id)
new_conversation.predict("再给我五个")
# def test_memory_with_buffer(self):
# conversation = Conversation()
# prompt = """give me 5 company names"""
# conversation.run(prompt)
# conversation_id = conversation.conversation_id
# new_conversation = Conversation(conversation_id=conversation_id)
# new_conversation.predict("give me 5 more")

def test_memory_with_file(self):
conversation = Conversation(memory=FileChatMemory())
prompt = """给我想5个公司的名字"""
conversation.run(prompt)
conversation_id = conversation.conversation_id
new_conversation = Conversation(
conversation_id=conversation_id, memory=FileChatMemory()
)
new_conversation.predict("再给我五个")
# def test_memory_with_file(self):
# conversation = Conversation(memory=FileChatMemory())
# prompt = """give me 5 company names"""
# conversation.run(prompt)
# conversation_id = conversation.conversation_id
# new_conversation = Conversation(
# conversation_id=conversation_id, memory=FileChatMemory()
# )
# new_conversation.predict("give me 5 more")
7 changes: 4 additions & 3 deletions tests/hook/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def handle_result(*args, **kwargs):
self.assertIsNotNone(result)
print(f"<instance> result: {result}")

hooks = [handle_result, handle_start, handle_result]
hooks = [handle_create, handle_start, handle_result]
tools = [DuckDuckGoTool(), Calculator()]
agent = ToolAgent(tools=tools, hooks=hooks)
agent.run("What is promptulate?")
Expand Down Expand Up @@ -71,8 +71,9 @@ def handle_result(*args, **kwargs):
self.assertIsNotNone(result)
print(f"<component> result: {result}")

tool = DuckDuckGoTool()
tool.run("What is promptulate?")
tools = [DuckDuckGoTool(), Calculator()]
agent = ToolAgent(tools=tools)
agent.run("What is promptulate?")

self.assertTrue(create_flag)
self.assertTrue(start_flag)
Expand Down
2 changes: 1 addition & 1 deletion tests/hook/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def handle_result(*args, **kwargs):
self.assertIsNotNone(result)
print(f"<instance> result: {result}")

hooks = [handle_result, handle_start, handle_result]
hooks = [handle_create, handle_start, handle_result]
tool = DuckDuckGoTool(hooks=hooks)
tool.run("What is LLM?")

Expand Down
Loading

0 comments on commit da34d31

Please sign in to comment.