Skip to content

Commit

Permalink
Merge pull request #108 from Undertone0809/feat-v1.9.0
Browse files Browse the repository at this point in the history
fix: output_formatter package error
  • Loading branch information
Undertone0809 authored Oct 25, 2023
2 parents d98f1b0 + 4cde9fc commit 775dba6
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 15 deletions.
6 changes: 3 additions & 3 deletions example/output_formatter/output_formatter_with_agent_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from promptulate.agents import WebAgent


class JSONResponse(BaseModel):
class Response(BaseModel):
city: str = Field(description="City name")
temperature: float = Field(description="Temperature in Celsius")


def main():
agent = WebAgent()
prompt = f"What is the temperature in Shanghai tomorrow?"
response: JSONResponse = agent.run(prompt=prompt, output_schema=JSONResponse)
print(response)
response: Response = agent.run(prompt=prompt, output_schema=Response)
print(response.city, response.temperature)


if __name__ == "__main__":
Expand Down
30 changes: 25 additions & 5 deletions promptulate/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from typing import List, Callable
from typing import List, Callable, Any

from pydantic import BaseModel

from promptulate.hook import Hook, HookTable
from promptulate.llms import BaseLLM
from promptulate.output_formatter import OutputFormatter


class BaseAgent(ABC):
Expand All @@ -13,13 +17,29 @@ def __init__(self, hooks: List[Callable] = None, *args, **kwargs):
Hook.mount_instance_hook(hook, self)
Hook.call_hook(HookTable.ON_AGENT_CREATE, self, *args, **kwargs)

def run(self, *args, **kwargs):
def run(
self, prompt: str, output_schema: type(BaseModel) = None, *args, **kwargs
) -> Any:
"""run the tool including specified function and hooks"""
Hook.call_hook(HookTable.ON_AGENT_START, self, *args, **kwargs)
result: str = self._run(*args, **kwargs)
Hook.call_hook(HookTable.ON_AGENT_START, self, prompt, *args, **kwargs)
result: str = self._run(prompt, *args, **kwargs)

# Return Pydantic instance if output_schema is specified
if output_schema:
formatter = OutputFormatter(output_schema)
prompt = (
f"{formatter.get_formatted_instructions()}\n##User input:\n{result}"
)
json_response = self.get_llm()(prompt)
return formatter.formatting_result(json_response)

Hook.call_hook(HookTable.ON_AGENT_RESULT, self, result=result)
return result

@abstractmethod
def _run(self, *args, **kwargs) -> str:
def _run(self, prompt: str, *args, **kwargs) -> str:
"""Run the detail agent, implemented by subclass."""

@abstractmethod
def get_llm(self) -> BaseLLM:
"""Get the llm when necessary."""
2 changes: 1 addition & 1 deletion promptulate/agents/tool_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _build_system_prompt(self, prompt) -> str:
tool_names=self.tool_manager.tool_names,
)

def _run(self, prompt: str) -> str:
def _run(self, prompt: str, *args, **kwargs) -> str:
self.conversation_prompt = self._build_system_prompt(prompt)
logger.info(f"[pne] tool agent system prompt: {self.conversation_prompt}")

Expand Down
18 changes: 13 additions & 5 deletions promptulate/agents/web_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
logger = get_logger()


def _build_system_prompt(prompt) -> str:
    """Wrap the user's prompt into the web-agent system prompt template."""
    rendered = SYSTEM_PROMPT_TEMPLATE.format(prompt=prompt)
    return rendered


class WebAgent(BaseAgent):
def __init__(
self,
Expand All @@ -20,23 +25,26 @@ def __init__(
):
super().__init__(hooks, *args, **kwargs)
self.llm: BaseLLM = llm or ChatOpenAI(
model="gpt-3.5-turbo-16k", temperature=0.0, enable_default_system_prompt=False
model="gpt-3.5-turbo-16k",
temperature=0.0,
enable_default_system_prompt=False,
)
self.stop_sequences: List[str] = ["Observation"]
self.websearch = DuckDuckGoTool()
self.conversation_prompt: str = ""

def _build_system_prompt(self, prompt) -> str:
"""Build the system prompt."""
return SYSTEM_PROMPT_TEMPLATE.format(prompt=prompt)
def get_llm(self) -> BaseLLM:
return self.llm

def _run(self, prompt: str, *args, **kwargs) -> str:
# ErnieBot built-in network search
if self.llm.llm_type == "ErnieBot":
return self.llm(prompt)

self.conversation_prompt = self._build_system_prompt(prompt)
self.conversation_prompt = _build_system_prompt(prompt)
iterations = 0

# Loop search until find the answer
while True:
answer: str = self.llm(self.conversation_prompt, stop=self.stop_sequences)
logger.info(
Expand Down
4 changes: 4 additions & 0 deletions promptulate/output_formatter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from promptulate.output_formatter.formatter import OutputFormatter


__all__ = ["OutputFormatter"]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setuptools.setup(
name="promptulate",
version="1.8.1",
version="1.8.2",
author="Zeeland",
author_email="[email protected]",
description="A powerful LLM Application development framework.",
Expand Down
62 changes: 62 additions & 0 deletions tests/output_formatter/test_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Optional

from pydantic import BaseModel, Field

from promptulate.agents import BaseAgent
from promptulate.llms import BaseLLM
from promptulate.output_formatter import OutputFormatter
from promptulate.schema import MessageSet, BaseMessage


class LLMForTest(BaseLLM):
    """Stub LLM that ignores its input and replies with a canned JSON answer."""

    llm_type: str = "custom_llm"

    def _predict(self, prompts: MessageSet, *args, **kwargs) -> Optional[BaseMessage]:
        """Unused by these tests; the stub answers only through __call__."""
        return None

    def __call__(self, *args, **kwargs):
        """Return a fixed markdown block containing the JSON payload."""
        return """## Output
```json
{
"city": "Shanghai",
"temperature": 25
}
```"""


class AgentForTest(BaseAgent):
    """Minimal agent backed by the canned LLMForTest stub."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.llm = LLMForTest()

    def _run(self, prompt: str, *args, **kwargs) -> str:
        """Produce no intermediate answer; formatting re-queries the LLM."""
        return ""

    def get_llm(self) -> BaseLLM:
        """Expose the stub LLM used by BaseAgent.run's formatting path."""
        return self.llm


class Response(BaseModel):
    """Expected structured answer: a city name and its forecast temperature."""

    city: str = Field(description="City name")
    temperature: float = Field(description="Temperature in Celsius")


def test_formatter_with_llm():
    """Formatter instructions plus the stub LLM's reply parse into Response."""
    formatter = OutputFormatter(Response)
    instructions = formatter.get_formatted_instructions()

    query = f"What is the temperature in Shanghai tomorrow? \n{instructions}"
    raw_output = LLMForTest()(query)
    parsed: Response = formatter.formatting_result(raw_output)

    assert isinstance(parsed, Response)
    assert isinstance(parsed.city, str)
    assert isinstance(parsed.temperature, float)


def test_formatter_with_agent():
    """Agent.run with output_schema returns a parsed Response instance."""
    agent = AgentForTest()
    # Plain literal: the original used an f-string with no placeholders (F541).
    prompt = "What is the temperature in Shanghai tomorrow?"
    response: Response = agent.run(prompt=prompt, output_schema=Response)
    assert isinstance(response, Response)
    assert isinstance(response.city, str)
    assert isinstance(response.temperature, float)

0 comments on commit 775dba6

Please sign in to comment.