From a2d689d70331bc19d429af598484e69b9010ecb6 Mon Sep 17 00:00:00 2001 From: Raja Kolli <rajadilipkolli@gmail.com> Date: Sun, 7 Apr 2024 15:01:32 +0530 Subject: [PATCH] feat : implement simple RAG (#38) * feat : implement simple RAG * update assertion --- .../example/ai/controller/ChatController.java | 59 ++------ .../com/example/ai/service/ChatService.java | 128 ++++++++++++++++++ .../src/main/resources/data/restaurants.json | 10 ++ .../src/main/resources/rag-prompt-template.st | 9 ++ .../ai/controller/ChatControllerTest.java | 21 ++- 5 files changed, 179 insertions(+), 48 deletions(-) create mode 100644 chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java create mode 100644 chatmodel-springai/src/main/resources/data/restaurants.json create mode 100644 chatmodel-springai/src/main/resources/rag-prompt-template.st diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index 77b131f..a9ef21d 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -3,17 +3,8 @@ import com.example.ai.model.request.AIChatRequest; import com.example.ai.model.response.AIChatResponse; import com.example.ai.model.response.ActorsFilms; -import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.ChatClient; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.Generation; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.parser.BeanOutputParser; +import com.example.ai.service.ChatService; +import java.io.IOException; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -25,61 +16,39 @@ @RequestMapping("/api/ai") public class ChatController { - private final ChatClient chatClient; + private final ChatService chatService; - private final EmbeddingClient embeddingClient; - - ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) { - this.chatClient = chatClient; - this.embeddingClient = embeddingClient; + ChatController(ChatService chatService) { + this.chatService = chatService; } @PostMapping("/chat") AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) { - var answer = chatClient.call(aiChatRequest.query()); - return new AIChatResponse(answer); + return chatService.chat(aiChatRequest.query()); } @PostMapping("/chat-with-prompt") AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) { - PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}"); - Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query())); - ChatResponse response = chatClient.call(prompt); - Generation generation = response.getResult(); - String answer = (generation != null) ? generation.getOutput().getContent() : ""; - return new AIChatResponse(answer); + return chatService.chatWithPrompt(aiChatRequest.query()); } @PostMapping("/chat-with-system-prompt") AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) { - SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot"); - UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query()); - Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatClient.call(prompt); - String answer = response.getResult().getOutput().getContent(); - return new AIChatResponse(answer); + return chatService.chatWithSystemPrompt(aiChatRequest.query()); } @PostMapping("/emebedding-client-conversion") AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) { - List<Double> embed = embeddingClient.embed(aiChatRequest.query()); - return new AIChatResponse(embed.toString()); + return chatService.getEmbeddings(aiChatRequest.query()); } @GetMapping("/output") public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NTR") String actor) { - BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class); - - String format = outputParser.getFormat(); - String template = """ - Generate the filmography for the actor {actor}. - {format} - """; - PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - ChatResponse response = chatClient.call(prompt); - Generation generation = response.getResult(); + return chatService.generateAsBean(actor); + } - return outputParser.parse(generation.getOutput().getContent()); + @PostMapping("/rag") + AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) throws IOException { + return chatService.ragGenerate(aiChatRequest.query()); } } diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java new file mode 100644 index 0000000..7691f52 --- /dev/null +++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java @@ -0,0 +1,128 @@ +package com.example.ai.service; + +import com.example.ai.model.response.AIChatResponse; +import com.example.ai.model.response.ActorsFilms; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.parser.BeanOutputParser; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; +import org.springframework.stereotype.Service; + +@Service +public class ChatService { + + private static final Logger logger = LoggerFactory.getLogger(ChatService.class); + + @Value("classpath:/data/restaurants.json") + private Resource restaurantsResource; + + @Value("classpath:/rag-prompt-template.st") + private Resource ragPromptTemplate; + + private final EmbeddingClient embeddingClient; + private final ChatClient chatClient; + + public ChatService(EmbeddingClient embeddingClient, ChatClient chatClient) { + this.embeddingClient = embeddingClient; + this.chatClient = chatClient; + } + + public AIChatResponse chat(String query) { + String answer = chatClient.call(query); + return new AIChatResponse(answer); + } + + public AIChatResponse chatWithPrompt(String query) { + PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}"); + Prompt prompt = promptTemplate.create(Map.of("subject", query)); + ChatResponse response = chatClient.call(prompt); + Generation generation = response.getResult(); + String answer = (generation != null) ? generation.getOutput().getContent() : ""; + return new AIChatResponse(answer); + } + + public AIChatResponse chatWithSystemPrompt(String query) { + SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot"); + UserMessage userMessage = new UserMessage("Tell me a joke about " + query); + Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + ChatResponse response = chatClient.call(prompt); + String answer = response.getResult().getOutput().getContent(); + return new AIChatResponse(answer); + } + + public AIChatResponse getEmbeddings(String query) { + List<Double> embed = embeddingClient.embed(query); + return new AIChatResponse(embed.toString()); + } + + public ActorsFilms generateAsBean(String actor) { + BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography for the actor {actor}. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + ChatResponse response = chatClient.call(prompt); + Generation generation = response.getResult(); + + return outputParser.parse(generation.getOutput().getContent()); + } + + public AIChatResponse ragGenerate(String query) throws IOException { + + // Step 1 - Load JSON document as Documents and save + logger.info("Loading JSON as Documents and save"); + SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient); + List<Document> linesDocuments = new ArrayList<>(); + + if (restaurantsResource.exists()) { // load existing vector store if exists + String contentAsString = restaurantsResource.getContentAsString(StandardCharsets.UTF_8); + // Convert lines to Documents in parallel + linesDocuments = + contentAsString.lines().parallel().map(Document::new).toList(); + simpleVectorStore.accept(linesDocuments); + } + + // Step 2 retrieve related documents to query + logger.info("Retrieving relevant documents"); + List<Document> similarDocuments = + simpleVectorStore.similaritySearch(SearchRequest.query(query).withTopK(2)); + logger.info(String.format("Found %s relevant documents.", similarDocuments.size())); + + List<String> contentList = + similarDocuments.stream().map(Document::getContent).toList(); + PromptTemplate promptTemplate = new PromptTemplate(ragPromptTemplate); + Map<String, Object> promptParameters = new HashMap<>(); + + promptParameters.put("input", query); + promptParameters.put("documents", String.join("\n", contentList)); + Prompt prompt = promptTemplate.create(promptParameters); + + ChatResponse response = chatClient.call(prompt); + Generation generation = response.getResult(); + String answer = (generation != null) ? generation.getOutput().getContent() : ""; + simpleVectorStore.delete(linesDocuments.stream().map(Document::getId).toList()); + return new AIChatResponse(answer); + } +} diff --git a/chatmodel-springai/src/main/resources/data/restaurants.json b/chatmodel-springai/src/main/resources/data/restaurants.json new file mode 100644 index 0000000..eb6ecdf --- /dev/null +++ b/chatmodel-springai/src/main/resources/data/restaurants.json @@ -0,0 +1,10 @@ +{"address": {"building": "1007", "coord": [-73.856077, 40.848447], "street": "Morris Park Ave", "zipcode": "10462"}, "borough": "Bronx", "cuisine": "Bakery", "grades": [{"date": {"$date": 1393804800000}, "grade": "A", "score": 2}, {"date": {"$date": 1378857600000}, "grade": "A", "score": 6}, {"date": {"$date": 1358985600000}, "grade": "A", "score": 10}, {"date": {"$date": 1322006400000}, "grade": "A", "score": 9}, {"date": {"$date": 1299715200000}, "grade": "B", "score": 14}], "name": "Morris Park Bake Shop", "restaurant_id": "30075445"} +{"address": {"building": "469", "coord": [-73.961704, 40.662942], "street": "Flatbush Avenue", "zipcode": "11225"}, "borough": "Brooklyn", "cuisine": "Hamburgers", "grades": [{"date": {"$date": 1419897600000}, "grade": "A", "score": 8}, {"date": {"$date": 1404172800000}, "grade": "B", "score": 23}, {"date": {"$date": 1367280000000}, "grade": "A", "score": 12}, {"date": {"$date": 1336435200000}, "grade": "A", "score": 12}], "name": "Wendy'S", "restaurant_id": "30112340"} +{"address": {"building": "351", "coord": [-73.98513559999999, 40.7676919], "street": "West 57 Street", "zipcode": "10019"}, "borough": "Manhattan", "cuisine": "Irish", "grades": [{"date": {"$date": 1409961600000}, "grade": "A", "score": 2}, {"date": {"$date": 1374451200000}, "grade": "A", "score": 11}, {"date": {"$date": 1343692800000}, "grade": "A", "score": 12}, {"date": {"$date": 1325116800000}, "grade": "A", "score": 12}], "name": "Dj Reynolds Pub And Restaurant", "restaurant_id": "30191841"} +{"address": {"building": "2780", "coord": [-73.98241999999999, 40.579505], "street": "Stillwell Avenue", "zipcode": "11224"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1402358400000}, "grade": "A", "score": 5}, {"date": {"$date": 1370390400000}, "grade": "A", "score": 7}, {"date": {"$date": 1334275200000}, "grade": "A", "score": 12}, {"date": {"$date": 1318377600000}, "grade": "A", "score": 12}], "name": "Riviera Caterer", "restaurant_id": "40356018"} +{"address": {"building": "97-22", "coord": [-73.8601152, 40.7311739], "street": "63 Road", "zipcode": "11374"}, "borough": "Queens", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1416787200000}, "grade": "Z", "score": 20}, {"date": {"$date": 1358380800000}, "grade": "A", "score": 13}, {"date": {"$date": 1343865600000}, "grade": "A", "score": 13}, {"date": {"$date": 1323907200000}, "grade": "B", "score": 25}], "name": "Tov Kosher Kitchen", "restaurant_id": "40356068"} +{"address": {"building": "8825", "coord": [-73.8803827, 40.7643124], "street": "Astoria Boulevard", "zipcode": "11369"}, "borough": "Queens", "cuisine": "American ", "grades": [{"date": {"$date": 1416009600000}, "grade": "Z", "score": 38}, {"date": {"$date": 1398988800000}, "grade": "A", "score": 10}, {"date": {"$date": 1362182400000}, "grade": "A", "score": 7}, {"date": {"$date": 1328832000000}, "grade": "A", "score": 13}], "name": "Brunos On The Boulevard", "restaurant_id": "40356151"} +{"address": {"building": "2206", "coord": [-74.1377286, 40.6119572], "street": "Victory Boulevard", "zipcode": "10314"}, "borough": "Staten Island", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1412553600000}, "grade": "A", "score": 9}, {"date": {"$date": 1400544000000}, "grade": "A", "score": 12}, {"date": {"$date": 1365033600000}, "grade": "A", "score": 12}, {"date": {"$date": 1327363200000}, "grade": "A", "score": 9}], "name": "Kosher Island", "restaurant_id": "40356442"} +{"address": {"building": "7114", "coord": [-73.9068506, 40.6199034], "street": "Avenue U", "zipcode": "11234"}, "borough": "Brooklyn", "cuisine": "Delicatessen", "grades": [{"date": {"$date": 1401321600000}, "grade": "A", "score": 10}, {"date": {"$date": 1389657600000}, "grade": "A", "score": 10}, {"date": {"$date": 1375488000000}, "grade": "A", "score": 8}, {"date": {"$date": 1342569600000}, "grade": "A", "score": 10}, {"date": {"$date": 1331251200000}, "grade": "A", "score": 13}, {"date": {"$date": 1318550400000}, "grade": "A", "score": 9}], "name": "Wilken'S Fine Food", "restaurant_id": "40356483"} +{"address": {"building": "6409", "coord": [-74.00528899999999, 40.628886], "street": "11 Avenue", "zipcode": "11219"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1405641600000}, "grade": "A", "score": 12}, {"date": {"$date": 1375142400000}, "grade": "A", "score": 12}, {"date": {"$date": 1360713600000}, "grade": "A", "score": 11}, {"date": {"$date": 1345075200000}, "grade": "A", "score": 2}, {"date": {"$date": 1313539200000}, "grade": "A", "score": 11}], "name": "Regina Caterers", "restaurant_id": "40356649"} +{"address": {"building": "1839", "coord": [-73.9482609, 40.6408271], "street": "Nostrand Avenue", "zipcode": "11226"}, "borough": "Brooklyn", "cuisine": "Ice Cream, Gelato, Yogurt, Ices", "grades": [{"date": {"$date": 1405296000000}, "grade": "A", "score": 12}, {"date": {"$date": 1373414400000}, "grade": "A", "score": 8}, {"date": {"$date": 1341964800000}, "grade": "A", "score": 5}, {"date": {"$date": 1329955200000}, "grade": "A", "score": 8}], "name": "Taste The Tropics Ice Cream", "restaurant_id": "40356731"} diff --git a/chatmodel-springai/src/main/resources/rag-prompt-template.st b/chatmodel-springai/src/main/resources/rag-prompt-template.st new file mode 100644 index 0000000..f46a20b --- /dev/null +++ b/chatmodel-springai/src/main/resources/rag-prompt-template.st @@ -0,0 +1,9 @@ +You are a helpful assistant, conversing with a user about the subjects contained in a set of documents. +Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer +isn't found in the DOCUMENTS section, simply state that you don't know the answer. + +QUESTION: +{input} + +DOCUMENTS: +{documents} diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java index f1afeca..a58e818 100644 --- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java +++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java @@ -35,7 +35,8 @@ void testChat() { .when() .post("/api/ai/chat") .then() - .statusCode(200) + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) .body("answer", containsString("Hello!")); } @@ -46,7 +47,8 @@ void chatWithPrompt() { .when() .post("/api/ai/chat-with-prompt") .then() - .statusCode(200) + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) .body("answer", containsString("Java")); } @@ -57,7 +59,8 @@ void chatWithSystemPrompt() { .when() .post("/api/ai/chat-with-system-prompt") .then() - .statusCode(200) + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) .body("answer", containsString("cricket")); } @@ -72,4 +75,16 @@ void outputParser() { .body("actor", is("Jr NTR")) .body("movies", hasSize(greaterThanOrEqualTo(25))); } + + @Test + void ragWithSimpleStore() { + given().contentType(ContentType.JSON) + .body(new AIChatRequest("which is the restaurant with highest grade that has cuisine as American ?")) + .when() + .post("/api/ai/rag") + .then() + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) + .body("answer", containsString("American cuisine is \"Regina Caterers\"")); + } }