Skip to content

Commit

Permalink
feat : adds test case for embeddingClientConversion (#136)
Browse files Browse the repository at this point in the history
* feat : adds test case for embeddingClientConversion

* adds reasoning

* adds more assertions

* adds validation and more test cases

* modulerize the request

* adds more test cases

* re define prompt

* rename test case

* adds -ve test case
  • Loading branch information
rajadilipkolli authored Jan 2, 2025
1 parent 2e8664e commit dbd65ad
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 22 deletions.
4 changes: 4 additions & 0 deletions chatmodel-springai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-validation</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import com.example.ai.model.response.AIStreamChatResponse;
import com.example.ai.model.response.ActorsFilms;
import com.example.ai.service.ChatService;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand All @@ -15,6 +17,7 @@

@RestController
@RequestMapping("/api/ai")
@Validated
public class ChatController {

private final ChatService chatService;
Expand All @@ -24,27 +27,27 @@ public class ChatController {
}

@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
AIChatResponse chat(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chat(aiChatRequest.query());
}

@PostMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
AIChatResponse chatWithPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chatWithPrompt(aiChatRequest.query());
}

@PostMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
AIChatResponse chatWithSystemPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chatWithSystemPrompt(aiChatRequest.query());
}

@PostMapping("/sentiment/analyze")
AIChatResponse sentimentAnalyzer(@RequestBody AIChatRequest aiChatRequest) {
AIChatResponse sentimentAnalyzer(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.analyzeSentiment(aiChatRequest.query());
}

@PostMapping("/emebedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
@PostMapping("/embedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.getEmbeddings(aiChatRequest.query());
}

Expand All @@ -54,12 +57,12 @@ public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NT
}

@PostMapping("/rag")
AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) {
AIChatResponse chatUsingRag(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.ragGenerate(aiChatRequest.query());
}

@PostMapping("/chat/stream")
AIStreamChatResponse streamChat(@RequestBody AIChatRequest aiChatRequest) {
AIStreamChatResponse streamChat(@RequestBody @Valid AIChatRequest aiChatRequest) {
Flux<String> streamChat = chatService.streamChat(aiChatRequest.query());
return new AIStreamChatResponse(streamChat);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package com.example.ai.model.request;

public record AIChatRequest(String query) {}
import jakarta.validation.constraints.NotBlank;

public record AIChatRequest(@NotBlank(message = "Query cant be Blank") String query) {}
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ public ActorsFilms generateAsBean(String actor) {
BeanOutputConverter<ActorsFilms> outputParser = new BeanOutputConverter<>(ActorsFilms.class);

String format = outputParser.getFormat();
String template = """
Generate the filmography for the actor {actor}.
String template =
"""
Generate the filmography for the Indian actor {actor} as of today.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
package com.example.ai.controller;

import static io.restassured.RestAssured.given;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.containsStringIgnoringCase;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import java.util.Arrays;
import java.util.stream.Stream;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ChatControllerTest {

private static final int OPENAI_EMBEDDING_DIMENSION = 1536;

@LocalServerPort
private int localServerPort;

Expand All @@ -31,7 +40,7 @@ void setUp() {
@Test
void testChat() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Hello?"))
.body(defaultChatRequest("Hello?"))
.when()
.post("/api/ai/chat")
.then()
Expand All @@ -41,21 +50,32 @@ void testChat() {
}

@Test
void chatWithPrompt() {
void shouldReturnBadRequestForMalformedChatRequest() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("java"))
.body("{}") // Empty or malformed request body
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_BAD_REQUEST);
}

@ParameterizedTest
@MethodSource("chatPrompts")
void shouldChatWithMultiplePrompts(String prompt) {
given().contentType(ContentType.JSON)
.body(defaultChatRequest(prompt))
.when()
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("Java"));
.body("answer", containsStringIgnoringCase(prompt));
}

@Test
void chatWithSystemPrompt() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("cricket"))
.body(defaultChatRequest("cricket"))
.when()
.post("/api/ai/chat-with-system-prompt")
.then()
Expand All @@ -65,9 +85,19 @@ void chatWithSystemPrompt() {
}

@Test
void sentimentAnalyzer() {
void shouldHandleErrorGracefullyForSystemPrompt() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Why did the Python programmer go broke? Because he couldn't C#"))
.body(defaultChatRequest(""))
.when()
.post("/api/ai/chat-with-system-prompt")
.then()
.statusCode(HttpStatus.SC_BAD_REQUEST);
}

@Test
void shouldAnalyzeSentimentAsSarcastic() {
given().contentType(ContentType.JSON)
.body(defaultChatRequest("Why did the Python programmer go broke? Because he couldn't C#"))
.when()
.post("/api/ai/sentiment/analyze")
.then()
Expand All @@ -76,10 +106,60 @@ void sentimentAnalyzer() {
.body("answer", is("SARCASTIC"));
}

@ParameterizedTest
@ValueSource(strings = {"This is a test sentence.", "Another different sentence.", "A third unique test case."})
void shouldGenerateValidEmbeddingsWithinExpectedRange(String input) {
String response = given().contentType(ContentType.JSON)
.body(defaultChatRequest(input))
.when()
.post("/api/ai/embedding-client-conversion")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.extract()
.jsonPath()
.get("answer");

assertThat(response).isNotNull().startsWith("[").endsWith("]");

double[] doubles = Arrays.stream(response.replaceAll("[\\[\\]]", "").split(","))
.mapToDouble(Double::parseDouble)
.toArray();

assertThat(doubles.length)
.isEqualTo(OPENAI_EMBEDDING_DIMENSION)
.as("Dimensions for openai model is %d", OPENAI_EMBEDDING_DIMENSION);

assertThat(Arrays.stream(doubles).allMatch(value -> value >= -1.0 && value <= 1.0))
.isTrue()
.as("All embedding values should be between -1.0 and 1.0");
}

@Test
void shouldHandleErrorCasesGracefully() {
given().contentType(ContentType.JSON)
.body(defaultChatRequest(""))
.when()
.post("/api/ai/embedding-client-conversion")
.then()
.statusCode(HttpStatus.SC_BAD_REQUEST);
}

@Test
void outputParser() {
given().param("actor", "Jr NTR")
void outputParserWithParam() {
given().param("actor", "BalaKrishna")
.when()
.get("/api/ai/output")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("actor", is("BalaKrishna"))
.body("movies", hasSize(greaterThanOrEqualTo(10)));
}

@Test
void outputParserDefaultParam() {
given().when()
.get("/api/ai/output")
.then()
.statusCode(HttpStatus.SC_OK)
Expand All @@ -89,9 +169,9 @@ void outputParser() {
}

@Test
void ragWithSimpleStore() {
void testRagWithSimpleStoreProvidesValidResponse() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(
.body(defaultChatRequest(
"Which is the restaurant with the highest grade that has a cuisine as American ?"))
.when()
.post("/api/ai/rag")
Expand All @@ -100,4 +180,12 @@ void ragWithSimpleStore() {
.contentType(ContentType.JSON)
.body("answer", containsString("Regina Caterers"));
}

static Stream<String> chatPrompts() {
return Stream.of("java", "spring boot", "ai");
}

private AIChatRequest defaultChatRequest(String message) {
return new AIChatRequest(message);
}
}

0 comments on commit dbd65ad

Please sign in to comment.