Skip to content

Commit

Permalink
test: add test of OutputFormatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Undertone0809 committed Oct 25, 2023
1 parent c000ec6 commit 4cde9fc
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
6 changes: 3 additions & 3 deletions example/output_formatter/output_formatter_with_agent_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from promptulate.agents import WebAgent


class Response(BaseModel):
    """Structured weather answer the agent is asked to return.

    The diff view showed both the old ``JSONResponse`` header and the new
    ``Response`` header; only the renamed class survives in the commit.
    """

    city: str = Field(description="City name")
    temperature: float = Field(description="Temperature in Celsius")


def main():
    """Run the web agent and print the structured weather answer."""
    agent = WebAgent()
    # Plain string: the original used an f-string with no placeholders (ruff F541).
    prompt = "What is the temperature in Shanghai tomorrow?"
    response: Response = agent.run(prompt=prompt, output_schema=Response)
    print(response.city, response.temperature)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setuptools.setup(
name="promptulate",
version="1.8.1",
version="1.8.2",
author="Zeeland",
author_email="[email protected]",
description="A powerful LLM Application development framework.",
Expand Down
62 changes: 62 additions & 0 deletions tests/output_formatter/test_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Optional

from pydantic import BaseModel, Field

from promptulate.agents import BaseAgent
from promptulate.llms import BaseLLM
from promptulate.output_formatter import OutputFormatter
from promptulate.schema import MessageSet, BaseMessage


class LLMForTest(BaseLLM):
    """Stub LLM that ignores its prompt and always answers with a fixed JSON payload."""

    llm_type: str = "custom_llm"

    def _predict(self, prompts: MessageSet, *args, **kwargs) -> Optional[BaseMessage]:
        # Unused here: the formatter tests go through __call__, not _predict.
        pass

    def __call__(self, *args, **kwargs):
        # Canned reply shaped like a model that followed JSON format instructions.
        canned_reply = """## Output
```json
{
  "city": "Shanghai",
  "temperature": 25
}
```"""
        return canned_reply


class AgentForTest(BaseAgent):
    """Minimal agent whose backing model is the stub LLM above."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.llm = LLMForTest()

    def _run(self, prompt: str, *args, **kwargs) -> str:
        # The agent itself produces nothing; output parsing is the formatter's job.
        return ""

    def get_llm(self) -> BaseLLM:
        """Expose the stub LLM so the framework can route output formatting through it."""
        return self.llm


class Response(BaseModel):
    """Schema the formatter should extract from the model's JSON reply."""

    city: str = Field(description="City name")
    temperature: float = Field(description="Temperature in Celsius")


def test_formatter_with_llm():
    """Formatter should parse the stub LLM's canned JSON reply into a Response."""
    formatter = OutputFormatter(Response)
    llm = LLMForTest()

    instructions = formatter.get_formatted_instructions()
    prompt = f"What is the temperature in Shanghai tomorrow? \n{instructions}"
    parsed: Response = formatter.formatting_result(llm(prompt))

    assert isinstance(parsed, Response)
    assert isinstance(parsed.city, str)
    assert isinstance(parsed.temperature, float)


def test_formatter_with_agent():
    """agent.run with output_schema should hand back a parsed Response instance."""
    agent = AgentForTest()
    # Plain string: the original f-string had no placeholders (ruff F541).
    prompt = "What is the temperature in Shanghai tomorrow?"
    response: Response = agent.run(prompt=prompt, output_schema=Response)
    assert isinstance(response, Response)
    assert isinstance(response.city, str)
    assert isinstance(response.temperature, float)

0 comments on commit 4cde9fc

Please sign in to comment.