Skip to content

Commit

Permalink
feat : expose endpoint to upload multiple documents (#46)
Browse files Browse the repository at this point in the history
* feat : expose endpoint to upload multiple documents

* feat : adds more test cases

* fix : add more assertions

* feat : fix spotless
  • Loading branch information
rajadilipkolli authored May 6, 2024
1 parent d6f4e17 commit fefbb7f
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,44 +1,14 @@
package com.learning.ai.llmragwithspringai.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;

/**
 * Spring configuration that exposes the shared {@link TokenTextSplitter} bean and an
 * {@link ApplicationRunner} which reads the bundled PDF and hands its chunks to the
 * {@link VectorStore}.
 */
@Configuration(proxyBeanMethods = false)
public class AppConfig {

    private static final Logger LOGGER = LoggerFactory.getLogger(AppConfig.class);

    /** PDF resolved from the classpath; this is the document the runner below reads. */
    @Value("classpath:Rohit_Gurunath_Sharma.pdf")
    private Resource resource;

    /** Provides a {@link TokenTextSplitter} created with its default settings. */
    @Bean
    TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }

    /**
     * Builds the runner that reads the configured PDF, splits it with the given splitter and
     * passes the result to the vector store.
     *
     * @param vectorStore destination for the split documents
     * @param tokenTextSplitter splitter applied to the documents read from the PDF
     * @return runner performing the load when invoked by Spring Boot
     */
    @Bean
    ApplicationRunner runner(VectorStore vectorStore, TokenTextSplitter tokenTextSplitter) {
        return args -> {
            LOGGER.info("Loading file(s) as Documents");
            PagePdfDocumentReader reader = new PagePdfDocumentReader(resource, singlePageReaderConfig());
            vectorStore.accept(tokenTextSplitter.apply(reader.get()));
            LOGGER.info("Loaded document to database.");
        };
    }

    /**
     * Creates the PDF reader configuration: one page per document, with a text formatter set to
     * 3 bottom text lines to delete and 1 top page to skip before deleting.
     */
    private static PdfDocumentReaderConfig singlePageReaderConfig() {
        ExtractedTextFormatter pageFormatter = new ExtractedTextFormatter.Builder()
                .withNumberOfBottomTextLinesToDelete(3)
                .withNumberOfTopPagesToSkipBeforeDelete(1)
                .build();
        return PdfDocumentReaderConfig.builder()
                .withPageExtractedTextFormatter(pageFormatter)
                .withPagesPerDocument(1)
                .build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.learning.ai.llmragwithspringai.controller;

import com.learning.ai.llmragwithspringai.service.DataIndexerService;
import java.util.Map;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

/**
 * REST endpoints under {@code /api/data/v1/} for uploading a document to be indexed and for
 * querying the count reported by {@link DataIndexerService}.
 */
@RestController
@RequestMapping("/api/data/v1/")
public class DataIndexController {

    private final DataIndexerService dataIndexerService;

    public DataIndexController(DataIndexerService dataIndexerService) {
        this.dataIndexerService = dataIndexerService;
    }

    /**
     * Accepts a multipart upload (part name {@code file}) and passes its resource to the indexer.
     *
     * @param multipartFile the uploaded file part
     * @return 200 with a success message, or 500 with the exception message if indexing throws
     */
    @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
    public ResponseEntity<String> load(@RequestPart("file") MultipartFile multipartFile) {
        try {
            dataIndexerService.loadData(multipartFile.getResource());
        } catch (Exception ex) {
            // NOTE(review): the exception is not logged and its message is echoed to the client.
            String failureMessage = "An error occurred while indexing data: " + ex.getMessage();
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(failureMessage);
        }
        return ResponseEntity.ok("Data indexed successfully!");
    }

    /**
     * Returns the indexer's count wrapped in a single-entry map.
     *
     * @return map with key {@code count} holding the value from {@link DataIndexerService#count()}
     */
    @GetMapping("count")
    public Map<String, Long> count() {
        long total = dataIndexerService.count();
        return Map.of("count", total);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

Expand Down Expand Up @@ -43,9 +42,8 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {

public String chat(String query) {
// Querying the VectorStore using natural language looking for the information about info asked.
LOGGER.info("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(
SearchRequest.query(query).withTopK(2).withSimilarityThreshold(0.8d));
LOGGER.debug("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(query);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.learning.ai.llmragwithspringai.service;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

/**
 * Reads an uploaded resource with a reader chosen from its file extension, splits the result
 * with the {@link TokenTextSplitter} and stores it in the {@link VectorStore}.
 */
@Service
public class DataIndexerService {

    private static final Logger LOGGER = LoggerFactory.getLogger(DataIndexerService.class);

    private final TokenTextSplitter tokenTextSplitter;
    private final VectorStore vectorStore;

    public DataIndexerService(TokenTextSplitter tokenTextSplitter, VectorStore vectorStore) {
        this.tokenTextSplitter = tokenTextSplitter;
        this.vectorStore = vectorStore;
    }

    /**
     * Indexes the given resource when its filename ends in {@code .pdf}, {@code .txt} or
     * {@code .json} (matched case-insensitively). Resources with a missing filename or any
     * other extension are skipped, and the skip is logged at WARN level.
     *
     * @param documentResource the document to read and index
     */
    public void loadData(Resource documentResource) {
        // Resolve the filename once instead of re-querying the resource for every check.
        String filename = documentResource.getFilename();
        DocumentReader documentReader = createReader(documentResource, filename);
        if (documentReader == null) {
            // Still a no-op for callers, but no longer a silent one.
            LOGGER.warn("Skipping resource with missing or unsupported file extension: {}", filename);
            return;
        }
        LOGGER.info("Loading text document to redis vector database");
        vectorStore.accept(tokenTextSplitter.apply(documentReader.get()));
        LOGGER.info("Loaded document to redis vector database.");
    }

    /**
     * Returns the number of documents produced by a similarity search for {@code "*"}.
     *
     * <p>NOTE(review): this is the size of a search result, not a true count of stored
     * documents; the single-argument search presumably applies a default result limit, so the
     * value is likely capped. Confirm against the VectorStore implementation in use.
     *
     * @return size of the search result list
     */
    public long count() {
        return this.vectorStore.similaritySearch("*").size();
    }

    /** Picks the reader matching the filename's extension, or {@code null} when none applies. */
    private DocumentReader createReader(Resource documentResource, String filename) {
        if (hasExtension(filename, ".pdf")) {
            LOGGER.info("Loading PDF document");
            PdfDocumentReaderConfig pdfDocumentReaderConfig = PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(ExtractedTextFormatter.builder()
                            .withNumberOfBottomTextLinesToDelete(3)
                            .withNumberOfTopPagesToSkipBeforeDelete(1)
                            .build())
                    .withPagesPerDocument(1)
                    .build();
            return new PagePdfDocumentReader(documentResource, pdfDocumentReaderConfig);
        }
        if (hasExtension(filename, ".txt")) {
            return new TextReader(documentResource);
        }
        if (hasExtension(filename, ".json")) {
            return new JsonReader(documentResource);
        }
        return null;
    }

    /** Case-insensitive {@code endsWith}; a {@code null} or too-short filename never matches. */
    private static boolean hasExtension(String filename, String extension) {
        // regionMatches returns false for a negative offset, covering names shorter than the extension.
        return filename != null
                && filename.regionMatches(
                        true, filename.length() - extension.length(), extension, 0, extension.length());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ spring.application.name=rag-springai-ollama-llm
spring.threads.virtual.enabled=true
spring.mvc.problemdetails.enabled=true

spring.ai.ollama.chat.options.model=orca-mini
spring.ai.ollama.chat.options.model=llama2
spring.ai.ollama.chat.options.temperature=0.3
spring.ai.ollama.chat.options.top-k=2
spring.ai.ollama.chat.options.top-p=0.2

spring.ai.ollama.embedding.options.model=orca-mini
spring.ai.ollama.embedding.options.model=llama2

spring.ai.vectorstore.redis.index=vector_store
spring.ai.vectorstore.redis.prefix=ai
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@
import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestMethodOrder;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@TestMethodOrder(value = MethodOrderer.OrderAnnotation.class)
class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest {

@LocalServerPort
Expand All @@ -26,9 +38,34 @@ void setUp() {
}

/**
 * Uploads the sample PDF from the test classpath as multipart part {@code file} and expects
 * HTTP 200 from the upload endpoint.
 */
@Test
@Order(1)
void uploadPdfContent() throws IOException, URISyntaxException {
    Path samplePdf = getPath("/Rohit_Gurunath_Sharma.pdf");
    given().multiPart("file", samplePdf.toFile())
            .when()
            .post("/api/data/v1/upload")
            .then()
            .statusCode(200)
            .log()
            .all();
}

/**
 * Calls the count endpoint and expects HTTP 200 with {@code count} equal to 1. Ordered to run
 * after the upload test.
 */
@Test
@Order(2)
void uploadPdfContentCount() {
    var countResponse = given().when().get("/api/data/v1/count");
    countResponse
            .then()
            .statusCode(200)
            .body("count", is(1))
            .log()
            .all();
}

@Test
@Order(101)
void testRag() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2007 ?"))
.body(new AIChatRequest("Is Rohit Sharma batsman?"))
.when()
.post("/api/ai/chat")
.then()
Expand All @@ -39,14 +76,29 @@ void testRag() {
}

/**
 * Asks the chat endpoint about the 2016 T20 World Cup and expects HTTP 200 with a
 * {@code queryResponse} that contains the word "not".
 */
@Test
@Order(102)
void testRag2() {
    var chatRequest = new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2016 ?");
    given().contentType(ContentType.JSON)
            .body(chatRequest)
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(200)
            .body("queryResponse", containsString("not"))
            .log()
            .all();
}

@Test
@Order(111)
void testEmptyQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(""))
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -56,6 +108,7 @@ void testEmptyQuery() {
}

@Test
@Order(112)
void testLongQueryString() {
String longQuery = "a".repeat(1000); // Example of a very long query string
given().contentType(ContentType.JSON)
Expand All @@ -64,7 +117,7 @@ void testLongQueryString() {
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -74,14 +127,15 @@ void testLongQueryString() {
}

@Test
@Order(113)
void testSpecialCharactersInQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("@#$%^&*()"))
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -90,4 +144,75 @@ void testSpecialCharactersInQuery() {
.body("violations[0].message", containsString("Invalid characters in query"))
.log();
}

/**
 * Sends {@code Optional.empty()} as the JSON body to {@code /api/ai/chat} and expects a 400
 * problem-details response stating that the request could not be read.
 */
@Test
@Order(114)
void testNullRequestBody() {
    given().contentType(ContentType.JSON)
            .body(Optional.empty())
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Failed to read request"))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Bad Request"))
            // log() alone only returns a log specification; all() is what prints the response.
            .log()
            .all();
}

/**
 * Posts a {@code text/plain} body to {@code /api/ai/chat} and expects a 415 problem-details
 * response naming the unsupported content type.
 */
@Test
@Order(115)
void testUnsupportedContentType() {
    given().contentType("text/plain")
            .body("Is Rohit Sharma a batsman?")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(415)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Content-Type 'text/plain;charset=ISO-8859-1' is not supported."))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Unsupported Media Type"))
            // log() alone only returns a log specification; all() is what prints the response.
            .log()
            .all();
}

/**
 * Posts an empty JSON object to {@code /api/ai/chat} and expects a 400 constraint-violation
 * response with a single violation on the {@code question} field.
 */
@Test
@Order(116)
void testMissingQuestionField() {
    given().contentType(ContentType.JSON)
            .body("{}")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Invalid request content."))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Constraint Violation"))
            .body("violations", hasSize(1))
            .body("violations[0].field", is("question"))
            .body("violations[0].message", containsString("Query cannot be empty"))
            // log() alone only returns a log specification; all() is what prints the response.
            .log()
            .all();
}

/**
 * Posts malformed JSON to {@code /api/ai/chat} and expects a 400 problem-details response
 * stating that the request could not be read.
 */
@Test
@Order(117)
void testInvalidJsonStructure() {
    given().contentType(ContentType.JSON)
            .body("{invalid json}")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Failed to read request"))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Bad Request"))
            // log() alone only returns a log specification; all() is what prints the response.
            .log()
            .all();
}

/**
 * Resolves a classpath resource to a filesystem {@link Path}.
 *
 * @param fileName classpath location of the resource
 * @return path built from the resource's URL converted to a URI
 * @throws URISyntaxException if the resource URL cannot be converted to a URI
 * @throws IOException if the resource URL cannot be resolved
 */
private Path getPath(String fileName) throws URISyntaxException, IOException {
    ClassPathResource classPathResource = new ClassPathResource(fileName);
    return Paths.get(classPathResource.getURL().toURI());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class TestLlmRagWithSpringAiApplication {
OllamaContainer ollama(DynamicPropertyRegistry properties) {
// The model name to use (e.g., "orca-mini", "mistral", "llama2", "codellama", "phi", or
// "tinyllama")
OllamaContainer ollama = new OllamaContainer(DockerImageName.parse("langchain4j/ollama-orca-mini:latest")
.asCompatibleSubstituteFor("ollama/ollama"));
OllamaContainer ollama = new OllamaContainer(
DockerImageName.parse("langchain4j/ollama-llama2:latest").asCompatibleSubstituteFor("ollama/ollama"));
properties.add("spring.ai.ollama.base-url", ollama::getEndpoint);
return ollama;
}
Expand Down

0 comments on commit fefbb7f

Please sign in to comment.