diff --git a/chatmodel-springai/src/main/java/com/example/ai/Application.java b/chatmodel-springai/src/main/java/com/example/ai/Application.java index 073fdc6..d9efcdb 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/Application.java +++ b/chatmodel-springai/src/main/java/com/example/ai/Application.java @@ -9,4 +9,4 @@ public class Application { public static void main(String[] args) { SpringApplication.run(Application.class, args); } -} \ No newline at end of file +} diff --git a/chatmodel-springai/src/main/java/com/example/ai/config/CustomClientHttpResponse.java b/chatmodel-springai/src/main/java/com/example/ai/config/CustomClientHttpResponse.java index 805cbf5..f050d4a 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/config/CustomClientHttpResponse.java +++ b/chatmodel-springai/src/main/java/com/example/ai/config/CustomClientHttpResponse.java @@ -1,49 +1,47 @@ -package com.example.ai.config; - -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatusCode; -import org.springframework.http.MediaType; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; - -import java.io.IOException; -import java.io.InputStream; -import java.util.Collections; - -public class CustomClientHttpResponse implements ClientHttpResponse { - - private final ClientHttpResponse originalResponse; - private final HttpHeaders headers; - public CustomClientHttpResponse(ClientHttpResponse originalResponse) { - this.originalResponse = originalResponse; - MultiValueMap modifiedHeaders = new LinkedMultiValueMap<>(originalResponse.getHeaders()); - modifiedHeaders.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(MediaType.APPLICATION_JSON_VALUE)); - this.headers = new HttpHeaders(modifiedHeaders); - } - - @Override - public HttpStatusCode getStatusCode() throws IOException { - return originalResponse.getStatusCode(); - } - - @Override - public String getStatusText() throws IOException { - return originalResponse.getStatusText(); - } - - @Override - public void close() { - - } - - @Override - public InputStream getBody() throws IOException { - return originalResponse.getBody(); - } - - @Override - public HttpHeaders getHeaders() { - return headers; - } -} +package com.example.ai.config; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Collections; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +public class CustomClientHttpResponse implements ClientHttpResponse { + + private final ClientHttpResponse originalResponse; + private final HttpHeaders headers; + + public CustomClientHttpResponse(ClientHttpResponse originalResponse) { + this.originalResponse = originalResponse; + MultiValueMap modifiedHeaders = new LinkedMultiValueMap<>(originalResponse.getHeaders()); + modifiedHeaders.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(MediaType.APPLICATION_JSON_VALUE)); + this.headers = new HttpHeaders(modifiedHeaders); + } + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return originalResponse.getStatusCode(); + } + + @Override + public String getStatusText() throws IOException { + return originalResponse.getStatusText(); + } + + @Override + public void close() {} + + @Override + public InputStream getBody() throws IOException { + return originalResponse.getBody(); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } +} diff --git a/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java b/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java index d5116c3..e3d18f7 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java +++ b/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java @@ -1,60 +1,60 @@ -package com.example.ai.config; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.http.HttpRequest; -import org.springframework.http.MediaType; -import org.springframework.http.client.BufferingClientHttpRequestFactory; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.util.StreamUtils; -import org.springframework.web.client.RestClient; - -import java.io.IOException; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.List; - -@Configuration(proxyBeanMethods = false) -@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo") -public class LoggingConfig { - - private final Logger log = LoggerFactory.getLogger(LoggingConfig.class); - - @Bean - RestClient.Builder restClientBuilder() { - return RestClient.builder().requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory())) - .requestInterceptor((request, body, execution) -> { - logRequest(request, body); - ClientHttpResponse response = execution.execute(request, body); - logResponse(response); - return new CustomClientHttpResponse(response); - }).defaultHeaders(httpHeaders -> { - httpHeaders.setContentType(MediaType.APPLICATION_JSON); - httpHeaders.setAccept(List.of(MediaType.ALL)); - }); - } - - private void logResponse(ClientHttpResponse response) throws IOException { - log.info("============================response begin=========================================="); - log.info("Status code : {}", response.getStatusCode()); - log.info("Status text : {}", response.getStatusText()); - log.info("Headers : {}", response.getHeaders()); - log.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset())); - log.info("=======================response end================================================="); - } - - private void logRequest(HttpRequest request, byte[] body) { - - log.info("===========================request begin================================================"); - log.info("URI : {}", request.getURI()); - log.info("Method : {}", request.getMethod()); - log.info("Headers : {}", request.getHeaders()); - log.info("Request body: {}", new String(body, StandardCharsets.UTF_8)); - log.info("==========================request end================================================"); - - } -} +package com.example.ai.config; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpRequest; +import org.springframework.http.MediaType; +import org.springframework.http.client.BufferingClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.util.StreamUtils; +import org.springframework.web.client.RestClient; + +@Configuration(proxyBeanMethods = false) +@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo") +public class LoggingConfig { + + private final Logger log = LoggerFactory.getLogger(LoggingConfig.class); + + @Bean + RestClient.Builder restClientBuilder() { + return RestClient.builder() + .requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory())) + .requestInterceptor((request, body, execution) -> { + logRequest(request, body); + ClientHttpResponse response = execution.execute(request, body); + logResponse(response); + return new CustomClientHttpResponse(response); + }) + .defaultHeaders(httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setAccept(List.of(MediaType.ALL)); + }); + } + + private void logResponse(ClientHttpResponse response) throws IOException { + log.info("============================response begin=========================================="); + log.info("Status code : {}", response.getStatusCode()); + log.info("Status text : {}", response.getStatusText()); + log.info("Headers : {}", response.getHeaders()); + log.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset())); + log.info("=======================response end================================================="); + } + + private void logRequest(HttpRequest request, byte[] body) { + + log.info("===========================request begin================================================"); + log.info("URI : {}", request.getURI()); + log.info("Method : {}", request.getMethod()); + log.info("Headers : {}", request.getHeaders()); + log.info("Request body: {}", new String(body, StandardCharsets.UTF_8)); + log.info("==========================request end================================================"); + } +} diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index eb30a21..15b1d52 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -1,6 +1,8 @@ package com.example.ai.controller; import com.example.ai.model.response.AIChatResponse; +import java.util.List; +import java.util.Map; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.SystemMessage; @@ -12,9 +14,6 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; -import java.util.List; -import java.util.Map; - @RestController @RequestMapping("/api/ai") public class ChatController { @@ -26,7 +25,7 @@ public class ChatController { } @GetMapping("/chat") - Map chat(@RequestParam String question) { + Map chat(@RequestParam String question) { var response = chatClient.call(question); return Map.of("question", question, "answer", response); } diff --git a/chatmodel-springai/src/main/java/com/example/ai/model/response/AIChatResponse.java b/chatmodel-springai/src/main/java/com/example/ai/model/response/AIChatResponse.java index e2f0fde..c3a9031 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/model/response/AIChatResponse.java +++ b/chatmodel-springai/src/main/java/com/example/ai/model/response/AIChatResponse.java @@ -1,4 +1,3 @@ -package com.example.ai.model.response; - -public record AIChatResponse(String answer) { -} +package com.example.ai.model.response; + +public record AIChatResponse(String answer) {} diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java index aa1ee79..1ef927a 100644 --- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java +++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java @@ -1,21 +1,56 @@ package com.example.ai.controller; -import org.hamcrest.Matchers; -import org.junit.jupiter.api.Test; +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.containsStringIgnoringCase; import io.restassured.RestAssured; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ChatControllerTest { -public class ChatControllerTest { + @LocalServerPort + private int localServerPort; + + @BeforeAll + void setUp() { + RestAssured.port = localServerPort; + } @Test void testChat() { - RestAssured.given() - .param("question", "Hello?") - .when() - .get("http://localhost:8080/api/ai/chat") - .then() - .statusCode(200) - .body("question", Matchers.equalTo("Hello?")) - .body("answer", Matchers.equalTo("Hi!")); + given().param("question", "Hello?") + .when() + .get("/api/ai/chat") + .then() + .statusCode(200) + .body("question", containsStringIgnoringCase("Hello?")) + .body("answer", containsString("Hello!")); + } + + @Test + void chatWithPrompt() { + given().param("subject", "java") + .when() + .get("/api/ai/chat-with-prompt") + .then() + .statusCode(200) + .body("answer", containsString("Java")); + } + + @Test + void chatWithSystemPrompt() { + given().param("subject", "cricket") + .when() + .get("/api/ai/chat-with-system-prompt") + .then() + .statusCode(200) + .body("answer", containsString("cricket")); } }