Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : expose endpoint to upload multiple documents #46

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,44 +1,14 @@
package com.learning.ai.llmragwithspringai.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;

/**
 * Bean definitions for the application: the shared {@link TokenTextSplitter} and a startup
 * {@link ApplicationRunner} that indexes the bundled PDF into the {@link VectorStore}.
 */
@Configuration(proxyBeanMethods = false)
public class AppConfig {

    private static final Logger LOGGER = LoggerFactory.getLogger(AppConfig.class);

    /** Classpath PDF that the startup runner reads and indexes. */
    @Value("classpath:Rohit_Gurunath_Sharma.pdf")
    private Resource resource;

    /** Exposes a default-configured splitter so other beans can inject it. */
    @Bean
    TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }

    /**
     * Creates the runner that performs the indexing once the application has started.
     *
     * @param vectorStore store that receives the split documents
     * @param tokenTextSplitter splitter applied to the documents read from the PDF
     * @return runner delegating to {@link #indexBundledPdf(VectorStore, TokenTextSplitter)}
     */
    @Bean
    ApplicationRunner runner(VectorStore vectorStore, TokenTextSplitter tokenTextSplitter) {
        return args -> indexBundledPdf(vectorStore, tokenTextSplitter);
    }

    /** Reads the configured PDF, splits what was read and hands the result to the vector store. */
    private void indexBundledPdf(VectorStore vectorStore, TokenTextSplitter splitter) {
        LOGGER.info("Loading file(s) as Documents");

        ExtractedTextFormatter pageTextFormatter = new ExtractedTextFormatter.Builder()
                .withNumberOfBottomTextLinesToDelete(3)
                .withNumberOfTopPagesToSkipBeforeDelete(1)
                .build();
        PdfDocumentReaderConfig readerConfig = PdfDocumentReaderConfig.builder()
                .withPageExtractedTextFormatter(pageTextFormatter)
                .withPagesPerDocument(1)
                .build();

        PagePdfDocumentReader pdfReader = new PagePdfDocumentReader(resource, readerConfig);
        vectorStore.accept(splitter.apply(pdfReader.get()));

        LOGGER.info("Loaded document to database.");
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.learning.ai.llmragwithspringai.controller;

import com.learning.ai.llmragwithspringai.service.DataIndexerService;
import java.util.Map;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

/**
 * REST endpoints below {@code /api/data/v1/}: one to upload a single document for indexing and
 * one to read the count reported by {@link DataIndexerService}.
 */
@RestController
@RequestMapping("/api/data/v1/")
public class DataIndexController {

    private final DataIndexerService dataIndexerService;

    public DataIndexController(DataIndexerService dataIndexerService) {
        this.dataIndexerService = dataIndexerService;
    }

    /**
     * Passes the uploaded multipart part named {@code file} to the indexer service.
     *
     * @param multipartFile the uploaded file part
     * @return 200 with a confirmation text, or 500 with the exception message when the service
     *     call throws
     */
    @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
    public ResponseEntity<String> load(@RequestPart("file") MultipartFile multipartFile) {
        ResponseEntity<String> response;
        try {
            dataIndexerService.loadData(multipartFile.getResource());
            response = ResponseEntity.ok("Data indexed successfully!");
        } catch (Exception ex) {
            // Any failure is translated into a plain-text 500 instead of being rethrown.
            String failureText = "An error occurred while indexing data: " + ex.getMessage();
            response = ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(failureText);
        }
        return response;
    }

    /**
     * Returns the service's count wrapped in a single-entry map.
     *
     * @return map with the key {@code count}
     */
    @GetMapping("count")
    public Map<String, Long> count() {
        long currentCount = dataIndexerService.count();
        return Map.of("count", currentCount);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

Expand Down Expand Up @@ -43,9 +42,8 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {

public String chat(String query) {
// Querying the VectorStore using natural language looking for the information about info asked.
LOGGER.info("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(
SearchRequest.query(query).withTopK(2).withSimilarityThreshold(0.8d));
LOGGER.debug("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(query);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.learning.ai.llmragwithspringai.service;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

/**
 * Reads a document resource with a reader chosen from its file extension, splits the result
 * with the injected {@link TokenTextSplitter} and stores it in the injected {@link VectorStore}.
 */
@Service
public class DataIndexerService {

    private static final Logger LOGGER = LoggerFactory.getLogger(DataIndexerService.class);

    private final TokenTextSplitter tokenTextSplitter;
    private final VectorStore vectorStore;

    public DataIndexerService(TokenTextSplitter tokenTextSplitter, VectorStore vectorStore) {
        this.tokenTextSplitter = tokenTextSplitter;
        this.vectorStore = vectorStore;
    }

    /**
     * Indexes the given resource when its file name ends with {@code .pdf}, {@code .txt} or
     * {@code .json} (compared case-insensitively). Any other resource, including one without a
     * file name, is skipped and a warning is logged.
     *
     * @param documentResource the document to read and index
     */
    public void loadData(Resource documentResource) {
        DocumentReader documentReader = createReader(documentResource);
        if (documentReader == null) {
            // Previously this case was dropped silently; the warning makes skipped uploads visible.
            LOGGER.warn(
                    "Skipping resource with unsupported or missing file name: {}",
                    documentResource.getFilename());
            return;
        }
        LOGGER.info("Loading text document to redis vector database");
        vectorStore.accept(tokenTextSplitter.apply(documentReader.get()));
        LOGGER.info("Loaded document to redis vector database.");
    }

    /** Picks the reader matching the resource's file extension, or {@code null} if none matches. */
    private DocumentReader createReader(Resource documentResource) {
        String fileName = documentResource.getFilename();
        if (fileName == null) {
            return null;
        }
        // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
        String normalizedName = fileName.toLowerCase(java.util.Locale.ROOT);
        if (normalizedName.endsWith(".pdf")) {
            LOGGER.info("Loading PDF document");
            PdfDocumentReaderConfig pdfDocumentReaderConfig = PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(ExtractedTextFormatter.builder()
                            .withNumberOfBottomTextLinesToDelete(3)
                            .withNumberOfTopPagesToSkipBeforeDelete(1)
                            .build())
                    .withPagesPerDocument(1)
                    .build();
            return new PagePdfDocumentReader(documentResource, pdfDocumentReaderConfig);
        }
        if (normalizedName.endsWith(".txt")) {
            return new TextReader(documentResource);
        }
        if (normalizedName.endsWith(".json")) {
            return new JsonReader(documentResource);
        }
        return null;
    }

    /**
     * Returns the number of documents the vector store yields for the query {@code "*"}.
     *
     * @return size of that similarity-search result
     */
    public long count() {
        // NOTE(review): this is the size of one search result, which is presumably capped by the
        // store's default result limit rather than being a true total - confirm.
        return this.vectorStore.similaritySearch("*").size();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ spring.application.name=rag-springai-ollama-llm
spring.threads.virtual.enabled=true
spring.mvc.problemdetails.enabled=true

spring.ai.ollama.chat.options.model=orca-mini
spring.ai.ollama.chat.options.model=llama2
spring.ai.ollama.chat.options.temperature=0.3
spring.ai.ollama.chat.options.top-k=2
spring.ai.ollama.chat.options.top-p=0.2

spring.ai.ollama.embedding.options.model=orca-mini
spring.ai.ollama.embedding.options.model=llama2

spring.ai.vectorstore.redis.index=vector_store
spring.ai.vectorstore.redis.prefix=ai
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@
import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestMethodOrder;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@TestMethodOrder(value = MethodOrderer.OrderAnnotation.class)
class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest {

@LocalServerPort
Expand All @@ -26,9 +38,34 @@ void setUp() {
}

// Uploads the sample PDF from the test classpath and expects the endpoint to answer 200.
@Test
@Order(1)
void uploadPdfContent() throws IOException, URISyntaxException {
    var samplePdf = getPath("/Rohit_Gurunath_Sharma.pdf").toFile();

    given().multiPart("file", samplePdf)
            .when()
            .post("/api/data/v1/upload")
            .then()
            .statusCode(200)
            .log()
            .all();
}
Comment on lines +42 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tip

Codebase Verification

The verification process has shown that the file "/Rohit_Gurunath_Sharma.pdf" does not exist in the 'src/test/resources' directory, which was assumed to be the correct location based on the context of a typical Java project structure. This suggests that there might be an issue with how the file path is handled or documented in the test setup. It is recommended to review the implementation of the getPath method to ensure it points to the correct location and that the necessary test files are properly placed in the project structure.

  • Review the implementation of the getPath method to confirm the intended file path.
  • Ensure that the file "/Rohit_Gurunath_Sharma.pdf" is correctly placed in the project structure as expected by the test.
Analysis chain

Ensure the file path in the getPath method is correctly handled.


The verification script intended to check the existence of the file specified in the getPath method failed because the Java command could not be executed in the script's environment. This means we haven't yet verified whether the file path is correctly handled in the code.

To proceed, I recommend running a revised script that doesn't rely on executing Java, but instead directly checks the file system based on expected outcomes from the getPath method, if such expected outcomes can be inferred from the codebase or documentation.

Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify that the file path provided in the `getPath` method is correctly handled and accessible.

# Test: Check if the file exists at the specified path.
if [ -f "$(java -jar rag/rag-springai-ollama-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java getPath /Rohit_Gurunath_Sharma.pdf)" ]; then
  echo "File exists."
else
  echo "File does not exist."
fi

Length of output: 305



Script:

#!/bin/bash
# Description: Check if the file path "/Rohit_Gurunath_Sharma.pdf" exists in the expected directory.

# Assuming the file should be located in the 'resources' directory under 'src/test'
if [ -f "src/test/resources/Rohit_Gurunath_Sharma.pdf" ]; then
  echo "File exists."
else
  echo "File does not exist."
fi

Length of output: 108


// Ordered after the upload test: the count endpoint must answer 200 and report a count of 1.
@Test
@Order(2)
void uploadPdfContentCount() {
    int expectedCount = 1;

    given().when()
            .get("/api/data/v1/count")
            .then()
            .statusCode(200)
            .body("count", is(expectedCount))
            .log()
            .all();
}

@Test
@Order(101)
void testRag() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2007 ?"))
.body(new AIChatRequest("Is Rohit Sharma batsman?"))
.when()
.post("/api/ai/chat")
.then()
Expand All @@ -39,14 +76,29 @@ void testRag() {
}

// Posts a question to the chat endpoint and expects the returned queryResponse to contain "not".
@Test
@Order(102)
void testRag2() {
    AIChatRequest chatRequest = new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2016 ?");

    given().contentType(ContentType.JSON)
            .body(chatRequest)
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(200)
            .body("queryResponse", containsString("not"))
            .log()
            .all();
}

@Test
@Order(111)
void testEmptyQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(""))
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -56,6 +108,7 @@ void testEmptyQuery() {
}

@Test
@Order(112)
void testLongQueryString() {
String longQuery = "a".repeat(1000); // Example of a very long query string
given().contentType(ContentType.JSON)
Expand All @@ -64,7 +117,7 @@ void testLongQueryString() {
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -74,14 +127,15 @@ void testLongQueryString() {
}

@Test
@Order(113)
void testSpecialCharactersInQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("@#$%^&*()"))
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
Expand All @@ -90,4 +144,75 @@ void testSpecialCharactersInQuery() {
.body("violations[0].message", containsString("Invalid characters in query"))
.log();
}

// Sends an empty Optional as the body and expects a 400 problem-details response
// titled "Bad Request".
@Test
@Order(114)
void testNullRequestBody() {
    given().contentType(ContentType.JSON)
            .body(Optional.empty())
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Failed to read request"))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Bad Request"))
            // A bare log() only returns the log specification; all() actually prints the response.
            .log()
            .all();
}
Comment on lines +148 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks for null request bodies in AI chat. Consider adding a more descriptive error message.

- .body("detail", is("Failed to read request"))
+ .body("detail", is("Request body cannot be null or empty"))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
@Test
@Order(114)
void testNullRequestBody() {
given().contentType(ContentType.JSON)
.body(Optional.empty())
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Failed to read request"))
.body("instance", is("/api/ai/chat"))
.body("title", is("Bad Request"))
.log();
}
@Test
@Order(114)
void testNullRequestBody() {
given().contentType(ContentType.JSON)
.body(Optional.empty())
.when()
.post("/api/ai/chat")
.then()
.statusCode(400)
.header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
.body("detail", is("Request body cannot be null or empty"))
.body("instance", is("/api/ai/chat"))
.body("title", is("Bad Request"))
.log();
}


// Posts a text/plain body to the JSON-only chat endpoint and expects a 415 problem-details
// response.
@Test
@Order(115)
void testUnsupportedContentType() {
    given().contentType("text/plain")
            .body("Is Rohit Sharma a batsman?")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(415)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Content-Type 'text/plain;charset=ISO-8859-1' is not supported."))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Unsupported Media Type"))
            // A bare log() only returns the log specification; all() actually prints the response.
            .log()
            .all();
}

// Posts an empty JSON object and expects a 400 constraint-violation response with a single
// violation on the "question" field.
@Test
@Order(116)
void testMissingQuestionField() {
    given().contentType(ContentType.JSON)
            .body("{}")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Invalid request content."))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Constraint Violation"))
            .body("violations", hasSize(1))
            .body("violations[0].field", is("question"))
            .body("violations[0].message", containsString("Query cannot be empty"))
            // A bare log() only returns the log specification; all() actually prints the response.
            .log()
            .all();
}

// Posts malformed JSON and expects a 400 problem-details response titled "Bad Request".
@Test
@Order(117)
void testInvalidJsonStructure() {
    given().contentType(ContentType.JSON)
            .body("{invalid json}")
            .when()
            .post("/api/ai/chat")
            .then()
            .statusCode(400)
            .header(HttpHeaders.CONTENT_TYPE, is(MediaType.APPLICATION_PROBLEM_JSON_VALUE))
            .body("detail", is("Failed to read request"))
            .body("instance", is("/api/ai/chat"))
            .body("title", is("Bad Request"))
            // A bare log() only returns the log specification; all() actually prints the response.
            .log()
            .all();
}

/**
 * Resolves a classpath resource name to a file-system {@link Path}.
 *
 * @param fileName classpath location of the resource
 * @return path built from the resource's URL
 * @throws URISyntaxException if the resource URL cannot be converted to a URI
 * @throws IOException if the resource URL cannot be obtained
 */
private Path getPath(String fileName) throws URISyntaxException, IOException {
    var classPathResource = new ClassPathResource(fileName);
    var resourceUri = classPathResource.getURL().toURI();
    return Paths.get(resourceUri);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class TestLlmRagWithSpringAiApplication {
OllamaContainer ollama(DynamicPropertyRegistry properties) {
// The model name to use (e.g., "orca-mini", "mistral", "llama2", "codellama", "phi", or
// "tinyllama")
OllamaContainer ollama = new OllamaContainer(DockerImageName.parse("langchain4j/ollama-orca-mini:latest")
.asCompatibleSubstituteFor("ollama/ollama"));
OllamaContainer ollama = new OllamaContainer(
DockerImageName.parse("langchain4j/ollama-llama2:latest").asCompatibleSubstituteFor("ollama/ollama"));
properties.add("spring.ai.ollama.base-url", ollama::getEndpoint);
return ollama;
}
Expand Down