diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index a9ef21d..8812787 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -37,6 +37,11 @@ AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) { return chatService.chatWithSystemPrompt(aiChatRequest.query()); } + @PostMapping("/sentiment/analyze") + AIChatResponse sentimentAnalyzer(@RequestBody AIChatRequest aiChatRequest) { + return chatService.analyzeSentiment(aiChatRequest.query()); + } + @PostMapping("/emebedding-client-conversion") AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) { return chatService.getEmbeddings(aiChatRequest.query()); diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java index 7691f52..d659a8a 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java +++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java @@ -15,6 +15,7 @@ import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.AssistantPromptTemplate; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; @@ -30,6 +31,8 @@ public class ChatService { private static final Logger logger = LoggerFactory.getLogger(ChatService.class); + private static final String SENTIMENT_ANALYSIS_TEMPLATE = + "{query}, You must answer strictly in the following format: one of [POSITIVE, NEGATIVE, SARCASTIC]"; @Value("classpath:/data/restaurants.json") private Resource restaurantsResource; @@ -68,6 +71,15 @@ public AIChatResponse chatWithSystemPrompt(String query) { return new AIChatResponse(answer); } + public AIChatResponse analyzeSentiment(String query) { + AssistantPromptTemplate promptTemplate = new AssistantPromptTemplate(SENTIMENT_ANALYSIS_TEMPLATE); + Prompt prompt = promptTemplate.create(Map.of("query", query)); + ChatResponse response = chatClient.call(prompt); + Generation generation = response.getResult(); + String answer = (generation != null) ? generation.getOutput().getContent() : ""; + return new AIChatResponse(answer); + } + public AIChatResponse getEmbeddings(String query) { List embed = embeddingClient.embed(query); return new AIChatResponse(embed.toString()); diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java index a58e818..194f9c7 100644 --- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java +++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java @@ -64,6 +64,18 @@ void chatWithSystemPrompt() { .body("answer", containsString("cricket")); } + @Test + void sentimentAnalyzer() { + given().contentType(ContentType.JSON) + .body(new AIChatRequest("Why did the Python programmer go broke? Because he couldn't C#")) + .when() + .post("/api/ai/sentiment/analyze") + .then() + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) + .body("answer", is("SARCASTIC")); + } + @Test void outputParser() { given().param("actor", "Jr NTR")