diff --git a/example/agent/iot_agent_usage.py b/example/agent/iot_agent_usage.py index 175f552d..3d926b6f 100644 --- a/example/agent/iot_agent_usage.py +++ b/example/agent/iot_agent_usage.py @@ -14,11 +14,11 @@ def main(): # MQTT broker address and port - broker_address = "XXX" + broker_address = "xxx" broker_port = 1883 # username and password - username = "XXX" - password = "XXX" + username = "xxx" + password = "xxx" client = mqtt.Client() client.username_pw_set(username, password) client.connect(broker_address, broker_port) @@ -41,7 +41,7 @@ def main(): ), ] agent = ToolAgent(tools) - prompt = """现在你是一个智能音箱,我现在感觉好暗""" + prompt = """现在你是一个智能音箱,你可以控制冷气,加热器和灯的开关,在开关之前请尽量询问人类,我现在感觉好暗。""" agent.run(prompt) diff --git a/promptulate/tools/human_feedback/tools.py b/promptulate/tools/human_feedback/tools.py index 2dee1ee7..4e1ca9b0 100644 --- a/promptulate/tools/human_feedback/tools.py +++ b/promptulate/tools/human_feedback/tools.py @@ -1,8 +1,14 @@ -from promptulate.tools import BaseTool +from typing import Callable + +from promptulate.tools import Tool from promptulate.utils.color_print import print_text -class HumanFeedBackTool(BaseTool): +def _print_func(content) -> None: + print_text(f"[Agent ask] {content}", "blue") + + +class HumanFeedBackTool(Tool): """A tool for running python code in a REPL.""" name: str = "human_feedback" @@ -14,6 +20,16 @@ class HumanFeedBackTool(BaseTool): "humans reason. " ) + def __init__( + self, + prompt_func: Callable[[str], None] = _print_func, + input_func: Callable = input, + **kwargs, + ): + super().__init__(**kwargs) + self.prompt_func = prompt_func + self.input_func = input_func + def _run(self, content: str, *args, **kwargs) -> str: - print_text(f"[Agent ask] {content}", "blue") - return input() + self.prompt_func(content) + return self.input_func() diff --git a/promptulate/tools/iot_swith_mqtt/tools.py b/promptulate/tools/iot_swith_mqtt/tools.py index 81ad5aec..936a6ce1 100644 --- a/promptulate/tools/iot_swith_mqtt/tools.py +++ b/promptulate/tools/iot_swith_mqtt/tools.py @@ -23,12 +23,20 @@ class IotSwitchTool(Tool): "If the operation of the device is successful, an OK will be returned, otherwise a failure will be returned." ) llm_prompt_template: StringTemplate = prompt_template - llm: BaseLLM = ChatOpenAI(temperature=0.1) client: mqtt.Client rule_table: List[Dict] api_wrapper: IotSwitchAPIWrapper = IotSwitchAPIWrapper() - def __init__(self, client: mqtt.Client, rule_table: List[Dict], **kwargs): + def __init__( + self, + llm: BaseLLM = None, + client: mqtt.Client = None, + rule_table: List[Dict] = None, + **kwargs + ): + self.llm: BaseLLM = llm or ChatOpenAI( + temperature=0.1, enable_preset_description=False + ) self.client = client self.rule_table = rule_table diff --git a/promptulate/tools/math/tools.py b/promptulate/tools/math/tools.py index 62360224..83ab7c6c 100644 --- a/promptulate/tools/math/tools.py +++ b/promptulate/tools/math/tools.py @@ -19,7 +19,12 @@ class Calculator(Tool): "language of math expression." ) llm_prompt_template: StringTemplate = prompt_template - llm: BaseLLM = ChatOpenAI(temperature=0, enable_preset_description=False) + + def __init__(self, llm: BaseLLM = None, **kwargs): + self.llm: BaseLLM = llm or ChatOpenAI( + temperature=0, enable_preset_description=False + ) + super().__init__(**kwargs) def _run(self, prompt: str) -> str: prompt = self.llm_prompt_template.format(question=prompt) diff --git a/tests/tools/test_human_feedback_toos.py b/tests/tools/test_human_feedback_toos.py index 70375436..bd358b2c 100644 --- a/tests/tools/test_human_feedback_toos.py +++ b/tests/tools/test_human_feedback_toos.py @@ -4,7 +4,15 @@ class TestHumanFeedBackTool(TestCase): + def prompt_func(self, content: str) -> None: + print(content) + + def input_func(self): + return input() + def test_run(self): - tool = HumanFeedBackTool() + tool = HumanFeedBackTool( + prompt_func=self.prompt_func, input_func=self.input_func + ) result = tool.run("我好冷") print(result) diff --git a/tests/tools/test_sleep_tools.py b/tests/tools/test_sleep_tools.py index 2ae213b1..9af64701 100644 --- a/tests/tools/test_sleep_tools.py +++ b/tests/tools/test_sleep_tools.py @@ -12,7 +12,7 @@ class TestSleepTool(TestCase): def test_run(self): tool = SleepTool() - seconds = 1 + seconds = "1s" start_time = time.time() result = tool.run(seconds) duration = time.time() - start_time