Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : expose endpoint to upload multiple documents #46

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,44 +1,14 @@
package com.learning.ai.llmragwithspringai.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;

@Configuration(proxyBeanMethods = false)
public class AppConfig {
private static final Logger LOGGER = LoggerFactory.getLogger(AppConfig.class);

@Value("classpath:Rohit_Gurunath_Sharma.pdf")
private Resource resource;

@Bean
TokenTextSplitter tokenTextSplitter() {
return new TokenTextSplitter();
}

@Bean
ApplicationRunner runner(VectorStore vectorStore, TokenTextSplitter tokenTextSplitter) {
return args -> {
LOGGER.info("Loading file(s) as Documents");
PdfDocumentReaderConfig config = PdfDocumentReaderConfig.builder()
.withPageExtractedTextFormatter(new ExtractedTextFormatter.Builder()
.withNumberOfBottomTextLinesToDelete(3)
.withNumberOfTopPagesToSkipBeforeDelete(1)
.build())
.withPagesPerDocument(1)
.build();
PagePdfDocumentReader pagePdfDocumentReader = new PagePdfDocumentReader(resource, config);
vectorStore.accept(tokenTextSplitter.apply(pagePdfDocumentReader.get()));
LOGGER.info("Loaded document to database.");
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.learning.ai.llmragwithspringai.controller;

import com.learning.ai.llmragwithspringai.service.DataIndexerService;
import java.util.Map;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

@RestController
@RequestMapping("/api/data/v1/")
public class DataIndexController {

private final DataIndexerService dataIndexerService;

public DataIndexController(DataIndexerService dataIndexerService) {
this.dataIndexerService = dataIndexerService;
}

@PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public ResponseEntity<String> load(@RequestPart("file") MultipartFile multipartFile) {
try {
this.dataIndexerService.loadData(multipartFile.getResource());
return ResponseEntity.ok("Data indexed successfully!");
} catch (Exception e) {
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body("An error occurred while indexing data: " + e.getMessage());
}
}

@GetMapping("count")
public Map<String, Long> count() {
return Map.of("count", dataIndexerService.count());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

Expand Down Expand Up @@ -43,9 +42,8 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {

public String chat(String query) {
// Querying the VectorStore using natural language looking for the information about info asked.
LOGGER.info("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(
SearchRequest.query(query).withTopK(2).withSimilarityThreshold(0.8d));
LOGGER.debug("Querying vector store with query :{}", query);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(query);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.learning.ai.llmragwithspringai.service;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

@Service
public class DataIndexerService {

private static final Logger LOGGER = LoggerFactory.getLogger(DataIndexerService.class);

private final TokenTextSplitter tokenTextSplitter;
private final VectorStore vectorStore;

public DataIndexerService(TokenTextSplitter tokenTextSplitter, VectorStore vectorStore) {
this.tokenTextSplitter = tokenTextSplitter;
this.vectorStore = vectorStore;
}

public void loadData(Resource documentResource) {
DocumentReader documentReader = null;
if (documentResource.getFilename() != null
&& documentResource.getFilename().endsWith(".pdf")) {
LOGGER.info("Loading PDF document");
PdfDocumentReaderConfig pdfDocumentReaderConfig = PdfDocumentReaderConfig.builder()
.withPageExtractedTextFormatter(ExtractedTextFormatter.builder()
.withNumberOfBottomTextLinesToDelete(3)
.withNumberOfTopPagesToSkipBeforeDelete(1)
.build())
.withPagesPerDocument(1)
.build();
documentReader = new PagePdfDocumentReader(documentResource, pdfDocumentReaderConfig);
} else if (documentResource.getFilename() != null
&& documentResource.getFilename().endsWith(".txt")) {
documentReader = new TextReader(documentResource);
} else if (documentResource.getFilename() != null
&& documentResource.getFilename().endsWith(".json")) {
documentReader = new JsonReader(documentResource);
}
if (documentReader != null) {
LOGGER.info("Loading text document to redis vector database");
vectorStore.accept(tokenTextSplitter.apply(documentReader.get()));
LOGGER.info("Loaded document to redis vector database.");
}
}

public long count() {
return this.vectorStore.similaritySearch("*").size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ spring.application.name=rag-springai-ollama-llm
spring.threads.virtual.enabled=true
spring.mvc.problemdetails.enabled=true

spring.ai.ollama.chat.options.model=orca-mini
spring.ai.ollama.chat.options.model=llama2
spring.ai.ollama.chat.options.temperature=0.3
spring.ai.ollama.chat.options.top-k=2
spring.ai.ollama.chat.options.top-p=0.2

spring.ai.ollama.embedding.options.model=orca-mini
spring.ai.ollama.embedding.options.model=llama2

spring.ai.vectorstore.redis.index=vector_store
spring.ai.vectorstore.redis.prefix=ai
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import java.io.File;
import java.io.IOException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestMethodOrder;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.core.io.ClassPathResource;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@TestMethodOrder(value = MethodOrderer.OrderAnnotation.class)
class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest {

@LocalServerPort
Expand All @@ -26,7 +33,46 @@ void setUp() {
}

@Test
@Order(1)
void uploadPdfContent() throws IOException {
given().multiPart("file", getFile("/Rohit_Gurunath_Sharma.pdf"))
.when()
.post("/api/data/v1/upload")
.then()
.statusCode(200)
.log()
.all();
}

@Test
@Order(2)
void uploadPdfContentCount() {
given().when()
.get("/api/data/v1/count")
.then()
.statusCode(200)
.body("count", is(1))
.log()
.all();
}

@Test
@Order(101)
void testRag() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Is Rohit Sharma batsman?"))
.when()
.post("/api/ai/chat")
.then()
.statusCode(200)
.body("queryResponse", containsString("Yes"))
.log()
.all();
}
rajadilipkolli marked this conversation as resolved.
Show resolved Hide resolved

@Test
@Order(102)
void testRag2() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Did Rohit Sharma won ICC Mens T20 World Cup 2007 ?"))
.when()
Expand All @@ -39,6 +85,7 @@ void testRag() {
}

@Test
@Order(111)
void testEmptyQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(""))
Expand All @@ -56,6 +103,7 @@ void testEmptyQuery() {
}

@Test
@Order(112)
void testLongQueryString() {
String longQuery = "a".repeat(1000); // Example of a very long query string
given().contentType(ContentType.JSON)
Expand All @@ -74,6 +122,7 @@ void testLongQueryString() {
}

@Test
@Order(113)
void testSpecialCharactersInQuery() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("@#$%^&*()"))
Expand All @@ -90,4 +139,8 @@ void testSpecialCharactersInQuery() {
.body("violations[0].message", containsString("Invalid characters in query"))
.log();
}

private File getFile(String fileName) throws IOException {
return new ClassPathResource(fileName).getFile();
}
rajadilipkolli marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class TestLlmRagWithSpringAiApplication {
OllamaContainer ollama(DynamicPropertyRegistry properties) {
// The model name to use (e.g., "orca-mini", "mistral", "llama2", "codellama", "phi", or
// "tinyllama")
OllamaContainer ollama = new OllamaContainer(DockerImageName.parse("langchain4j/ollama-orca-mini:latest")
.asCompatibleSubstituteFor("ollama/ollama"));
OllamaContainer ollama = new OllamaContainer(
DockerImageName.parse("langchain4j/ollama-llama2:latest").asCompatibleSubstituteFor("ollama/ollama"));
properties.add("spring.ai.ollama.base-url", ollama::getEndpoint);
return ollama;
}
Expand Down