From 4cde9fc6f0e2e77a895af59fe25ae4ade11f7e94 Mon Sep 17 00:00:00 2001
From: Zeeland
Date: Thu, 26 Oct 2023 05:04:22 +0800
Subject: [PATCH] test: add tests for OutputFormatter

---
 .../output_formatter_with_agent_usage.py  |  6 +-
 setup.py                                  |  2 +-
 tests/output_formatter/test_formatter.py  | 62 +++++++++++++++++++
 3 files changed, 66 insertions(+), 4 deletions(-)
 create mode 100644 tests/output_formatter/test_formatter.py

diff --git a/example/output_formatter/output_formatter_with_agent_usage.py b/example/output_formatter/output_formatter_with_agent_usage.py
index 3306b7dd..10fa285a 100644
--- a/example/output_formatter/output_formatter_with_agent_usage.py
+++ b/example/output_formatter/output_formatter_with_agent_usage.py
@@ -4,7 +4,7 @@ from promptulate.agents import WebAgent
 
 
 
-class JSONResponse(BaseModel):
+class Response(BaseModel):
     city: str = Field(description="City name")
     temperature: float = Field(description="Temperature in Celsius")
 
@@ -12,8 +12,8 @@ class JSONResponse(BaseModel):
 def main():
     agent = WebAgent()
     prompt = f"What is the temperature in Shanghai tomorrow?"
-    response: JSONResponse = agent.run(prompt=prompt, output_schema=JSONResponse)
-    print(response)
+    response: Response = agent.run(prompt=prompt, output_schema=Response)
+    print(response.city, response.temperature)
 
 
 if __name__ == "__main__":
diff --git a/setup.py b/setup.py
index 75ecbcc9..a8462b55 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
 
 setuptools.setup(
     name="promptulate",
-    version="1.8.1",
+    version="1.8.2",
     author="Zeeland",
     author_email="zeeland@foxmail.com",
     description="A powerful LLM Application development framework.",
diff --git a/tests/output_formatter/test_formatter.py b/tests/output_formatter/test_formatter.py
new file mode 100644
index 00000000..4cc92b85
--- /dev/null
+++ b/tests/output_formatter/test_formatter.py
@@ -0,0 +1,62 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from promptulate.agents import BaseAgent
+from promptulate.llms import BaseLLM
+from promptulate.output_formatter import OutputFormatter
+from promptulate.schema import MessageSet, BaseMessage
+
+
+class LLMForTest(BaseLLM):
+    llm_type: str = "custom_llm"
+
+    def _predict(self, prompts: MessageSet, *args, **kwargs) -> Optional[BaseMessage]:
+        pass
+
+    def __call__(self, *args, **kwargs):
+        return """## Output
+        ```json
+        {
+            "city": "Shanghai",
+            "temperature": 25
+        }
+        ```"""
+
+
+class AgentForTest(BaseAgent):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.llm = LLMForTest()
+
+    def get_llm(self) -> BaseLLM:
+        return self.llm
+
+    def _run(self, prompt: str, *args, **kwargs) -> str:
+        return ""
+
+
+class Response(BaseModel):
+    city: str = Field(description="City name")
+    temperature: float = Field(description="Temperature in Celsius")
+
+
+def test_formatter_with_llm():
+    llm = LLMForTest()
+    formatter = OutputFormatter(Response)
+
+    prompt = f"What is the temperature in Shanghai tomorrow? \n{formatter.get_formatted_instructions()}"
+    llm_output = llm(prompt)
+    response: Response = formatter.formatting_result(llm_output)
+    assert isinstance(response, Response)
+    assert isinstance(response.city, str)
+    assert isinstance(response.temperature, float)
+
+
+def test_formatter_with_agent():
+    agent = AgentForTest()
+    prompt = "What is the temperature in Shanghai tomorrow?"
+    response: Response = agent.run(prompt=prompt, output_schema=Response)
+    assert isinstance(response, Response)
+    assert isinstance(response.city, str)
+    assert isinstance(response.temperature, float)