From fefbb7f6a342e24be3b2e29aca7d2f388f0c7592 Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Mon, 6 May 2024 16:48:48 +0530 Subject: [PATCH] feat : expose endpoint to upload multiple documents (#46) * feat : expose endpoint to upload multiple documents * feat : adds more test cases * fix : add more assertions * feat : fix spotless --- .../llmragwithspringai/config/AppConfig.java | 30 ---- .../controller/DataIndexController.java | 36 +++++ .../service/AIChatService.java | 6 +- .../service/DataIndexerService.java | 59 ++++++++ .../src/main/resources/application.properties | 7 +- .../LlmRagWithSpringAiApplicationIntTest.java | 133 +++++++++++++++++- .../TestLlmRagWithSpringAiApplication.java | 4 +- 7 files changed, 233 insertions(+), 42 deletions(-) create mode 100644 rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/DataIndexController.java create mode 100644 rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/DataIndexerService.java diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/AppConfig.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/AppConfig.java index 10f8547..1534063 100644 --- a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/AppConfig.java +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/AppConfig.java @@ -1,44 +1,14 @@ package com.learning.ai.llmragwithspringai.config; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.reader.ExtractedTextFormatter; -import org.springframework.ai.reader.pdf.PagePdfDocumentReader; -import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig; import org.springframework.ai.transformer.splitter.TokenTextSplitter; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.ApplicationRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.core.io.Resource; @Configuration(proxyBeanMethods = false) public class AppConfig { - private static final Logger LOGGER = LoggerFactory.getLogger(AppConfig.class); - - @Value("classpath:Rohit_Gurunath_Sharma.pdf") - private Resource resource; @Bean TokenTextSplitter tokenTextSplitter() { return new TokenTextSplitter(); } - - @Bean - ApplicationRunner runner(VectorStore vectorStore, TokenTextSplitter tokenTextSplitter) { - return args -> { - LOGGER.info("Loading file(s) as Documents"); - PdfDocumentReaderConfig config = PdfDocumentReaderConfig.builder() - .withPageExtractedTextFormatter(new ExtractedTextFormatter.Builder() - .withNumberOfBottomTextLinesToDelete(3) - .withNumberOfTopPagesToSkipBeforeDelete(1) - .build()) - .withPagesPerDocument(1) - .build(); - PagePdfDocumentReader pagePdfDocumentReader = new PagePdfDocumentReader(resource, config); - vectorStore.accept(tokenTextSplitter.apply(pagePdfDocumentReader.get())); - LOGGER.info("Loaded document to database."); - }; - } } diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/DataIndexController.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/DataIndexController.java new file mode 100644 index 0000000..e24099e --- /dev/null +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/DataIndexController.java @@ -0,0 +1,36 @@ +package com.learning.ai.llmragwithspringai.controller; + +import com.learning.ai.llmragwithspringai.service.DataIndexerService; +import java.util.Map; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +@RestController +@RequestMapping("/api/data/v1/") +public class DataIndexController { + + private final DataIndexerService dataIndexerService; + + public DataIndexController(DataIndexerService dataIndexerService) { + this.dataIndexerService = dataIndexerService; + } + + @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + public ResponseEntity load(@RequestPart("file") MultipartFile multipartFile) { + try { + this.dataIndexerService.loadData(multipartFile.getResource()); + return ResponseEntity.ok("Data indexed successfully!"); + } catch (Exception e) { + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body("An error occurred while indexing data: " + e.getMessage()); + } + } + + @GetMapping("count") + public Map count() { + return Map.of("count", dataIndexerService.count()); + } +} diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java index ad72c81..9047990 100644 --- a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java @@ -13,7 +13,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; @@ -43,9 +42,8 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) { public String chat(String query) { // Querying the VectorStore using natural language looking for the information about info asked. - LOGGER.info("Querying vector store with query :{}", query); - List listOfSimilarDocuments = this.vectorStore.similaritySearch( - SearchRequest.query(query).withTopK(2).withSimilarityThreshold(0.8d)); + LOGGER.debug("Querying vector store with query :{}", query); + List listOfSimilarDocuments = this.vectorStore.similaritySearch(query); String documents = listOfSimilarDocuments.stream() .map(Document::getContent) .collect(Collectors.joining(System.lineSeparator())); diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/DataIndexerService.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/DataIndexerService.java new file mode 100644 index 0000000..56f83b3 --- /dev/null +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/DataIndexerService.java @@ -0,0 +1,59 @@ +package com.learning.ai.llmragwithspringai.service; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentReader; +import org.springframework.ai.reader.ExtractedTextFormatter; +import org.springframework.ai.reader.JsonReader; +import org.springframework.ai.reader.TextReader; +import org.springframework.ai.reader.pdf.PagePdfDocumentReader; +import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.core.io.Resource; +import org.springframework.stereotype.Service; + +@Service +public class DataIndexerService { + + private static final Logger LOGGER = LoggerFactory.getLogger(DataIndexerService.class); + + private final TokenTextSplitter tokenTextSplitter; + private final VectorStore vectorStore; + + public DataIndexerService(TokenTextSplitter tokenTextSplitter, VectorStore vectorStore) { + this.tokenTextSplitter = tokenTextSplitter; + this.vectorStore = vectorStore; + } + + public void loadData(Resource documentResource) { + DocumentReader documentReader = null; + if (documentResource.getFilename() != null + && documentResource.getFilename().endsWith(".pdf")) { + LOGGER.info("Loading PDF document"); + PdfDocumentReaderConfig pdfDocumentReaderConfig = PdfDocumentReaderConfig.builder() + .withPageExtractedTextFormatter(ExtractedTextFormatter.builder() + .withNumberOfBottomTextLinesToDelete(3) + .withNumberOfTopPagesToSkipBeforeDelete(1) + .build()) + .withPagesPerDocument(1) + .build(); + documentReader = new PagePdfDocumentReader(documentResource, pdfDocumentReaderConfig); + } else if (documentResource.getFilename() != null + && documentResource.getFilename().endsWith(".txt")) { + documentReader = new TextReader(documentResource); + } else if (documentResource.getFilename() != null + && documentResource.getFilename().endsWith(".json")) { + documentReader = new JsonReader(documentResource); + } + if (documentReader != null) { + LOGGER.info("Loading text document to redis vector database"); + vectorStore.accept(tokenTextSplitter.apply(documentReader.get())); + LOGGER.info("Loaded document to redis vector database."); + } + } + + public long count() { + return this.vectorStore.similaritySearch("*").size(); + } +} diff --git a/rag/rag-springai-ollama-llm/src/main/resources/application.properties b/rag/rag-springai-ollama-llm/src/main/resources/application.properties index 68e85ba..79bad46 100644 --- a/rag/rag-springai-ollama-llm/src/main/resources/application.properties +++ b/rag/rag-springai-ollama-llm/src/main/resources/application.properties @@ -3,9 +3,12 @@ spring.application.name=rag-springai-ollama-llm spring.threads.virtual.enabled=true spring.mvc.problemdetails.enabled=true -spring.ai.ollama.chat.options.model=orca-mini +spring.ai.ollama.chat.options.model=llama2 +spring.ai.ollama.chat.options.temperature=0.3 +spring.ai.ollama.chat.options.top-k=2 +spring.ai.ollama.chat.options.top-p=0.2 -spring.ai.ollama.embedding.options.model=orca-mini +spring.ai.ollama.embedding.options.model=llama2 spring.ai.vectorstore.redis.index=vector_store spring.ai.vectorstore.redis.prefix=ai diff --git a/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java b/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java index 83b9e5b..2978914 100644 --- a/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java +++ b/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java @@ -9,12 +9,24 @@ import com.learning.ai.llmragwithspringai.model.request.AIChatRequest; import io.restassured.RestAssured; import io.restassured.http.ContentType; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Optional; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestMethodOrder; import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; @TestInstance(TestInstance.Lifecycle.PER_CLASS) +@TestMethodOrder(value = MethodOrderer.OrderAnnotation.class) class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest { @LocalServerPort @@ -26,9 +38,34 @@ void setUp() { } @Test + @Order(1) + void uploadPdfContent() throws IOException, URISyntaxException { + given().multiPart("file", getPath("/Rohit_Gurunath_Sharma.pdf").toFile()) + .when() + .post("/api/data/v1/upload") + .then() + .statusCode(200) + .log() + .all(); + } + + @Test + @Order(2) + void uploadPdfContentCount() { + given().when() + .get("/api/data/v1/count") + .then() + .statusCode(200) + .body("count", is(1)) + .log() + .all(); + } + + @Test + @Order(101) void testRag() { given().contentType(ContentType.JSON) - .body(new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2007 ?")) + .body(new AIChatRequest("Is Rohit Sharma batsman?")) .when() .post("/api/ai/chat") .then() @@ -39,6 +76,21 @@ void testRag() { } @Test + @Order(102) + void testRag2() { + given().contentType(ContentType.JSON) + .body(new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2016 ?")) + .when() + .post("/api/ai/chat") + .then() + .statusCode(200) + .body("queryResponse", containsString("not")) + .log() + .all(); + } + + @Test + @Order(111) void testEmptyQuery() { given().contentType(ContentType.JSON) .body(new AIChatRequest("")) @@ -46,7 +98,7 @@ void testEmptyQuery() { .post("/api/ai/chat") .then() .statusCode(400) - .header("Content-Type", is("application/problem+json")) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) .body("detail", is("Invalid request content.")) .body("instance", is("/api/ai/chat")) .body("title", is("Constraint Violation")) @@ -56,6 +108,7 @@ void testEmptyQuery() { } @Test + @Order(112) void testLongQueryString() { String longQuery = "a".repeat(1000); // Example of a very long query string given().contentType(ContentType.JSON) @@ -64,7 +117,7 @@ void testLongQueryString() { .post("/api/ai/chat") .then() .statusCode(400) - .header("Content-Type", is("application/problem+json")) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) .body("detail", is("Invalid request content.")) .body("instance", is("/api/ai/chat")) .body("title", is("Constraint Violation")) @@ -74,6 +127,7 @@ void testLongQueryString() { } @Test + @Order(113) void testSpecialCharactersInQuery() { given().contentType(ContentType.JSON) .body(new AIChatRequest("@#$%^&*()")) @@ -81,7 +135,7 @@ void testSpecialCharactersInQuery() { .post("/api/ai/chat") .then() .statusCode(400) - .header("Content-Type", is("application/problem+json")) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) .body("detail", is("Invalid request content.")) .body("instance", is("/api/ai/chat")) .body("title", is("Constraint Violation")) @@ -90,4 +144,75 @@ void testSpecialCharactersInQuery() { .body("violations[0].message", containsString("Invalid characters in query")) .log(); } + + @Test + @Order(114) + void testNullRequestBody() { + given().contentType(ContentType.JSON) + .body(Optional.empty()) + .when() + .post("/api/ai/chat") + .then() + .statusCode(400) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) + .body("detail", is("Failed to read request")) + .body("instance", is("/api/ai/chat")) + .body("title", is("Bad Request")) + .log(); + } + + @Test + @Order(115) + void testUnsupportedContentType() { + given().contentType("text/plain") + .body("Is Rohit Sharma a batsman?") + .when() + .post("/api/ai/chat") + .then() + .statusCode(415) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) + .body("detail", is("Content-Type 'text/plain;charset=ISO-8859-1' is not supported.")) + .body("instance", is("/api/ai/chat")) + .body("title", is("Unsupported Media Type")) + .log(); + } + + @Test + @Order(116) + void testMissingQuestionField() { + given().contentType(ContentType.JSON) + .body("{}") + .when() + .post("/api/ai/chat") + .then() + .statusCode(400) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) + .body("detail", is("Invalid request content.")) + .body("instance", is("/api/ai/chat")) + .body("title", is("Constraint Violation")) + .body("violations", hasSize(1)) + .body("violations[0].field", is("question")) + .body("violations[0].message", containsString("Query cannot be empty")) + .log(); + } + + @Test + @Order(117) + void testInvalidJsonStructure() { + given().contentType(ContentType.JSON) + .body("{invalid json}") + .when() + .post("/api/ai/chat") + .then() + .statusCode(400) + .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE)) + .body("detail", is("Failed to read request")) + .body("instance", is("/api/ai/chat")) + .body("title", is("Bad Request")) + .log(); + } + + private Path getPath(String fileName) throws URISyntaxException, IOException { + return Paths.get(new ClassPathResource(fileName).getURL().toURI()); + } } diff --git a/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/TestLlmRagWithSpringAiApplication.java b/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/TestLlmRagWithSpringAiApplication.java index 3e04ab2..2f6512d 100644 --- a/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/TestLlmRagWithSpringAiApplication.java +++ b/rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/TestLlmRagWithSpringAiApplication.java @@ -18,8 +18,8 @@ public class TestLlmRagWithSpringAiApplication { OllamaContainer ollama(DynamicPropertyRegistry properties) { // The model name to use (e.g., "orca-mini", "mistral", "llama2", "codellama", "phi", or // "tinyllama") - OllamaContainer ollama = new OllamaContainer(DockerImageName.parse("langchain4j/ollama-orca-mini:latest") - .asCompatibleSubstituteFor("ollama/ollama")); + OllamaContainer ollama = new OllamaContainer( + DockerImageName.parse("langchain4j/ollama-llama2:latest").asCompatibleSubstituteFor("ollama/ollama")); properties.add("spring.ai.ollama.base-url", ollama::getEndpoint); return ollama; }