Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : polish spring chat model #36

Merged
merged 4 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/rag-springai-ollama-llm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
matrix:
distribution: [ 'temurin' ]
java: [ '21' ]
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v4
with:
Expand Down
24 changes: 24 additions & 0 deletions chatmodel-springai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<properties>
<java.version>17</java.version>
<spring-ai.version>0.8.1</spring-ai.version>
<spotless.version>2.43.0</spotless.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -72,6 +73,29 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
<version>${spotless.version}</version>
<configuration>
<java>
<palantirJavaFormat>
<version>2.40.0</version>
</palantirJavaFormat>
<importOrder />
<removeUnusedImports />
<formatAnnotations />
</java>
</configuration>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
Expand All @@ -15,7 +16,6 @@
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
Expand All @@ -24,8 +24,8 @@ public class LoggingConfig {
private static final Logger LOGGER = LoggerFactory.getLogger(LoggingConfig.class);

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder()
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder
.requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
package com.example.ai.controller;

import com.example.ai.model.request.AIChatRequest;
import com.example.ai.model.response.AIChatResponse;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand All @@ -20,32 +23,42 @@ public class ChatController {

private final ChatClient chatClient;

ChatController(ChatClient chatClient) {
private final EmbeddingClient embeddingClient;

ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) {
this.chatClient = chatClient;
this.embeddingClient = embeddingClient;
}

@GetMapping("/chat")
Map<String, String> chat(@RequestParam String question) {
var response = chatClient.call(question);
return Map.of("question", question, "answer", response);
@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
var answer = chatClient.call(aiChatRequest.query());
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", subject));
Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query()));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + subject);
UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query());
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return new AIChatResponse(answer);
}

@PostMapping("/emebedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
List<Double> embed = embeddingClient.embed(aiChatRequest.query());
return new AIChatResponse(embed.toString());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.example.ai.model.request;

public record AIChatRequest(String query) {}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.2
spring.ai.openai.chat.options.responseFormat=json_object

spring.ai.openai.embedding.enabled=false
spring.ai.openai.embedding.enabled=true

##logging
logging.level.org.apache.hc.client5.http=INFO
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.containsStringIgnoringCase;

import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
Expand All @@ -25,30 +26,32 @@ void setUp() {

@Test
void testChat() {
given().param("question", "Hello?")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Hello?"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(200)
.body("question", containsStringIgnoringCase("Hello?"))
.body("answer", containsString("Hello!"));
}

@Test
void chatWithPrompt() {
given().param("subject", "java")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("java"))
.when()
.get("/api/ai/chat-with-prompt")
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(200)
.body("answer", containsString("Java"));
}

@Test
void chatWithSystemPrompt() {
given().param("subject", "cricket")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("cricket"))
.when()
.get("/api/ai/chat-with-system-prompt")
.post("/api/ai/chat-with-system-prompt")
.then()
.statusCode(200)
.body("answer", containsString("cricket"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package com.learning.ai.llmragwithspringai.config;

import java.time.Duration;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
public class RestClientBuilderConfig {

@Bean
RestClient.Builder restClientBuilder(JdkClientHttpRequestFactory jdkClientHttpRequestFactory) {
return RestClient.builder().requestFactory(jdkClientHttpRequestFactory);
}

@Bean
JdkClientHttpRequestFactory jdkClientHttpRequestFactory() {
JdkClientHttpRequestFactory jdkClientHttpRequestFactory = new JdkClientHttpRequestFactory();
jdkClientHttpRequestFactory.setReadTimeout(Duration.ofMinutes(5));
return jdkClientHttpRequestFactory;
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder.requestFactory(
ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS
.withConnectTimeout(Duration.ofSeconds(60))
.withReadTimeout(Duration.ofMinutes(5))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -64,6 +65,7 @@ public String chat(String query) {
LOGGER.info("Calling ai with prompt :{}", prompt);
ChatResponse aiResponse = aiClient.call(prompt);
LOGGER.info("Response received from call :{}", aiResponse);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.InputStream;
import java.util.Collections;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
Expand All @@ -12,15 +13,14 @@
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class ResponseHeadersModification {

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder().requestInterceptor((request, body, execution) -> {
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder.requestInterceptor((request, body, execution) -> {
ClientHttpResponse response = execution.execute(request, body);
return new CustomClientHttpResponse(response);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.stream.Collectors;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -55,6 +56,7 @@ public String chat(String searchQuery) {
UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}