Skip to content

Commit

Permalink
feat : adds integration tests for all endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
rajadilipkolli committed Mar 27, 2024
1 parent 14358b4 commit 760af11
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ public class Application {
public static void main(String[] args) {
SpringApplication.run(Application.class, args);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,49 +1,47 @@
package com.example.ai.config;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;

public class CustomClientHttpResponse implements ClientHttpResponse {

private final ClientHttpResponse originalResponse;
private final HttpHeaders headers;
public CustomClientHttpResponse(ClientHttpResponse originalResponse) {
this.originalResponse = originalResponse;
MultiValueMap<String, String> modifiedHeaders = new LinkedMultiValueMap<>(originalResponse.getHeaders());
modifiedHeaders.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(MediaType.APPLICATION_JSON_VALUE));
this.headers = new HttpHeaders(modifiedHeaders);
}

@Override
public HttpStatusCode getStatusCode() throws IOException {
return originalResponse.getStatusCode();
}

@Override
public String getStatusText() throws IOException {
return originalResponse.getStatusText();
}

@Override
public void close() {

}

@Override
public InputStream getBody() throws IOException {
return originalResponse.getBody();
}

@Override
public HttpHeaders getHeaders() {
return headers;
}
}
package com.example.ai.config;

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

public class CustomClientHttpResponse implements ClientHttpResponse {

private final ClientHttpResponse originalResponse;
private final HttpHeaders headers;

public CustomClientHttpResponse(ClientHttpResponse originalResponse) {
this.originalResponse = originalResponse;
MultiValueMap<String, String> modifiedHeaders = new LinkedMultiValueMap<>(originalResponse.getHeaders());
modifiedHeaders.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(MediaType.APPLICATION_JSON_VALUE));
this.headers = new HttpHeaders(modifiedHeaders);
}

@Override
public HttpStatusCode getStatusCode() throws IOException {
return originalResponse.getStatusCode();
}

@Override
public String getStatusText() throws IOException {
return originalResponse.getStatusText();
}

@Override
public void close() {}

@Override
public InputStream getBody() throws IOException {
return originalResponse.getBody();
}

@Override
public HttpHeaders getHeaders() {
return headers;
}
}
Original file line number Diff line number Diff line change
@@ -1,60 +1,60 @@
package com.example.ai.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class LoggingConfig {

private final Logger log = LoggerFactory.getLogger(LoggingConfig.class);

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder().requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
ClientHttpResponse response = execution.execute(request, body);
logResponse(response);
return new CustomClientHttpResponse(response);
}).defaultHeaders(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setAccept(List.of(MediaType.ALL));
});
}

private void logResponse(ClientHttpResponse response) throws IOException {
log.info("============================response begin==========================================");
log.info("Status code : {}", response.getStatusCode());
log.info("Status text : {}", response.getStatusText());
log.info("Headers : {}", response.getHeaders());
log.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
log.info("=======================response end=================================================");
}

private void logRequest(HttpRequest request, byte[] body) {

log.info("===========================request begin================================================");
log.info("URI : {}", request.getURI());
log.info("Method : {}", request.getMethod());
log.info("Headers : {}", request.getHeaders());
log.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
log.info("==========================request end================================================");

}
}
package com.example.ai.config;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class LoggingConfig {

private final Logger log = LoggerFactory.getLogger(LoggingConfig.class);

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder()
.requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
ClientHttpResponse response = execution.execute(request, body);
logResponse(response);
return new CustomClientHttpResponse(response);
})
.defaultHeaders(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setAccept(List.of(MediaType.ALL));
});
}

private void logResponse(ClientHttpResponse response) throws IOException {
log.info("============================response begin==========================================");
log.info("Status code : {}", response.getStatusCode());
log.info("Status text : {}", response.getStatusText());
log.info("Headers : {}", response.getHeaders());
log.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
log.info("=======================response end=================================================");
}

private void logRequest(HttpRequest request, byte[] body) {

log.info("===========================request begin================================================");
log.info("URI : {}", request.getURI());
log.info("Method : {}", request.getMethod());
log.info("Headers : {}", request.getHeaders());
log.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
log.info("==========================request end================================================");
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.example.ai.controller;

import com.example.ai.model.response.AIChatResponse;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
Expand All @@ -12,9 +14,6 @@
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;
import java.util.Map;

@RestController
@RequestMapping("/api/ai")
public class ChatController {
Expand All @@ -26,7 +25,7 @@ public class ChatController {
}

@GetMapping("/chat")
Map<String,String> chat(@RequestParam String question) {
Map<String, String> chat(@RequestParam String question) {
var response = chatClient.call(question);
return Map.of("question", question, "answer", response);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
package com.example.ai.model.response;

public record AIChatResponse(String answer) {
}
package com.example.ai.model.response;

public record AIChatResponse(String answer) {}
Original file line number Diff line number Diff line change
@@ -1,21 +1,56 @@
package com.example.ai.controller;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.Test;
import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.containsStringIgnoringCase;

import io.restassured.RestAssured;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ChatControllerTest {

public class ChatControllerTest {
@LocalServerPort
private int localServerPort;

@BeforeAll
void setUp() {
RestAssured.port = localServerPort;
}

@Test
void testChat() {
RestAssured.given()
.param("question", "Hello?")
.when()
.get("http://localhost:8080/api/ai/chat")
.then()
.statusCode(200)
.body("question", Matchers.equalTo("Hello?"))
.body("answer", Matchers.equalTo("Hi!"));
given().param("question", "Hello?")
.when()
.get("/api/ai/chat")
.then()
.statusCode(200)
.body("question", containsStringIgnoringCase("Hello?"))
.body("answer", containsString("Hello!"));
}

@Test
void chatWithPrompt() {
given().param("subject", "java")
.when()
.get("/api/ai/chat-with-prompt")
.then()
.statusCode(200)
.body("answer", containsString("Java"));
}

@Test
void chatWithSystemPrompt() {
given().param("subject", "cricket")
.when()
.get("/api/ai/chat-with-system-prompt")
.then()
.statusCode(200)
.body("answer", containsString("cricket"));
}
}

0 comments on commit 760af11

Please sign in to comment.