Skip to content

Commit

Permalink
Merge pull request #71 from ruanrongman/main
Browse files Browse the repository at this point in the history
fix tool change llm bug and update humanfeedback tool
  • Loading branch information
Undertone0809 authored Sep 11, 2023
2 parents 9ea41aa + ee5f229 commit 5d0d02f
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 13 deletions.
8 changes: 4 additions & 4 deletions example/agent/iot_agent_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

def main():
# MQTT broker address and port
broker_address = "XXX"
broker_address = "xxx"
broker_port = 1883
# username and password
username = "XXX"
password = "XXX"
username = "xxx"
password = "xxx"
client = mqtt.Client()
client.username_pw_set(username, password)
client.connect(broker_address, broker_port)
Expand All @@ -41,7 +41,7 @@ def main():
),
]
agent = ToolAgent(tools)
prompt = """现在你是一个智能音箱,我现在感觉好暗"""
prompt = """现在你是一个智能音箱,你可以控制冷气,加热器和灯的开关,在开关之前请尽量询问人类,我现在感觉好暗"""
agent.run(prompt)


Expand Down
24 changes: 20 additions & 4 deletions promptulate/tools/human_feedback/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from promptulate.tools import BaseTool
from typing import Callable

from promptulate.tools import Tool
from promptulate.utils.color_print import print_text


class HumanFeedBackTool(BaseTool):
def _print_func(content) -> None:
print_text(f"[Agent ask] {content}", "blue")


class HumanFeedBackTool(Tool):
"""A tool for running python code in a REPL."""

name: str = "human_feedback"
Expand All @@ -14,6 +20,16 @@ class HumanFeedBackTool(BaseTool):
"humans reason. "
)

def __init__(
self,
prompt_func: Callable[[str], None] = _print_func,
input_func: Callable = input,
**kwargs,
):
super().__init__(**kwargs)
self.prompt_func = prompt_func
self.input_func = input_func

def _run(self, content: str, *args, **kwargs) -> str:
print_text(f"[Agent ask] {content}", "blue")
return input()
self.prompt_func(content)
return self.input_func()
12 changes: 10 additions & 2 deletions promptulate/tools/iot_swith_mqtt/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,20 @@ class IotSwitchTool(Tool):
"If the operation of the device is successful, an OK will be returned, otherwise a failure will be returned."
)
llm_prompt_template: StringTemplate = prompt_template
llm: BaseLLM = ChatOpenAI(temperature=0.1)
client: mqtt.Client
rule_table: List[Dict]
api_wrapper: IotSwitchAPIWrapper = IotSwitchAPIWrapper()

def __init__(self, client: mqtt.Client, rule_table: List[Dict], **kwargs):
def __init__(
self,
llm: BaseLLM = None,
client: mqtt.Client = None,
rule_table: List[Dict] = None,
**kwargs
):
self.llm: BaseLLM = llm or ChatOpenAI(
temperature=0.1, enable_preset_description=False
)
self.client = client
self.rule_table = rule_table

Expand Down
7 changes: 6 additions & 1 deletion promptulate/tools/math/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ class Calculator(Tool):
"language of math expression."
)
llm_prompt_template: StringTemplate = prompt_template
llm: BaseLLM = ChatOpenAI(temperature=0, enable_preset_description=False)

def __init__(self, llm: BaseLLM = None, **kwargs):
self.llm: BaseLLM = llm or ChatOpenAI(
temperature=0, enable_preset_description=False
)
super().__init__(**kwargs)

def _run(self, prompt: str) -> str:
prompt = self.llm_prompt_template.format(question=prompt)
Expand Down
10 changes: 9 additions & 1 deletion tests/tools/test_human_feedback_toos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@


class TestHumanFeedBackTool(TestCase):
def prompt_func(self, content: str) -> None:
print(content)

def input_func(self):
return input()

def test_run(self):
tool = HumanFeedBackTool()
tool = HumanFeedBackTool(
prompt_func=self.prompt_func, input_func=self.input_func
)
result = tool.run("我好冷")
print(result)
2 changes: 1 addition & 1 deletion tests/tools/test_sleep_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class TestSleepTool(TestCase):
def test_run(self):
tool = SleepTool()
seconds = 1
seconds = "1s"
start_time = time.time()
result = tool.run(seconds)
duration = time.time() - start_time
Expand Down

0 comments on commit 5d0d02f

Please sign in to comment.