From a2d689d70331bc19d429af598484e69b9010ecb6 Mon Sep 17 00:00:00 2001
From: Raja Kolli <rajadilipkolli@gmail.com>
Date: Sun, 7 Apr 2024 15:01:32 +0530
Subject: [PATCH] feat : implement simple RAG (#38)

* feat : implement simple RAG

* update assertion
---
 .../example/ai/controller/ChatController.java |  59 ++------
 .../com/example/ai/service/ChatService.java   | 128 ++++++++++++++++++
 .../src/main/resources/data/restaurants.json  |  10 ++
 .../src/main/resources/rag-prompt-template.st |   9 ++
 .../ai/controller/ChatControllerTest.java     |  21 ++-
 5 files changed, 179 insertions(+), 48 deletions(-)
 create mode 100644 chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
 create mode 100644 chatmodel-springai/src/main/resources/data/restaurants.json
 create mode 100644 chatmodel-springai/src/main/resources/rag-prompt-template.st

diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
index 77b131f..a9ef21d 100644
--- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
+++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
@@ -3,17 +3,8 @@
 import com.example.ai.model.request.AIChatRequest;
 import com.example.ai.model.response.AIChatResponse;
 import com.example.ai.model.response.ActorsFilms;
-import java.util.List;
-import java.util.Map;
-import org.springframework.ai.chat.ChatClient;
-import org.springframework.ai.chat.ChatResponse;
-import org.springframework.ai.chat.Generation;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.PromptTemplate;
-import org.springframework.ai.embedding.EmbeddingClient;
-import org.springframework.ai.parser.BeanOutputParser;
+import com.example.ai.service.ChatService;
+import java.io.IOException;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
@@ -25,61 +16,39 @@
 @RequestMapping("/api/ai")
 public class ChatController {
 
-    private final ChatClient chatClient;
+    private final ChatService chatService;
 
-    private final EmbeddingClient embeddingClient;
-
-    ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) {
-        this.chatClient = chatClient;
-        this.embeddingClient = embeddingClient;
+    ChatController(ChatService chatService) {
+        this.chatService = chatService;
     }
 
     @PostMapping("/chat")
     AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
-        var answer = chatClient.call(aiChatRequest.query());
-        return new AIChatResponse(answer);
+        return chatService.chat(aiChatRequest.query());
     }
 
     @PostMapping("/chat-with-prompt")
     AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
-        PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
-        Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query()));
-        ChatResponse response = chatClient.call(prompt);
-        Generation generation = response.getResult();
-        String answer = (generation != null) ? generation.getOutput().getContent() : "";
-        return new AIChatResponse(answer);
+        return chatService.chatWithPrompt(aiChatRequest.query());
     }
 
     @PostMapping("/chat-with-system-prompt")
     AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
-        SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
-        UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query());
-        Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
-        ChatResponse response = chatClient.call(prompt);
-        String answer = response.getResult().getOutput().getContent();
-        return new AIChatResponse(answer);
+        return chatService.chatWithSystemPrompt(aiChatRequest.query());
     }
 
     @PostMapping("/emebedding-client-conversion")
     AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
-        List<Double> embed = embeddingClient.embed(aiChatRequest.query());
-        return new AIChatResponse(embed.toString());
+        return chatService.getEmbeddings(aiChatRequest.query());
     }
 
     @GetMapping("/output")
     public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NTR") String actor) {
-        BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class);
-
-        String format = outputParser.getFormat();
-        String template = """
-				Generate the filmography for the actor {actor}.
-				{format}
-				""";
-        PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
-        Prompt prompt = new Prompt(promptTemplate.createMessage());
-        ChatResponse response = chatClient.call(prompt);
-        Generation generation = response.getResult();
+        return chatService.generateAsBean(actor);
+    }
 
-        return outputParser.parse(generation.getOutput().getContent());
+    @PostMapping("/rag")
+    AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) throws IOException {
+        return chatService.ragGenerate(aiChatRequest.query());
     }
 }
diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
new file mode 100644
index 0000000..7691f52
--- /dev/null
+++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
@@ -0,0 +1,128 @@
+package com.example.ai.service;
+
+import com.example.ai.model.response.AIChatResponse;
+import com.example.ai.model.response.ActorsFilms;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.ChatClient;
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.Generation;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.chat.prompt.PromptTemplate;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.parser.BeanOutputParser;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.SimpleVectorStore;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.core.io.Resource;
+import org.springframework.stereotype.Service;
+
+@Service
+public class ChatService {
+
+    private static final Logger logger = LoggerFactory.getLogger(ChatService.class);
+
+    @Value("classpath:/data/restaurants.json")
+    private Resource restaurantsResource;
+
+    @Value("classpath:/rag-prompt-template.st")
+    private Resource ragPromptTemplate;
+
+    private final EmbeddingClient embeddingClient;
+    private final ChatClient chatClient;
+
+    public ChatService(EmbeddingClient embeddingClient, ChatClient chatClient) {
+        this.embeddingClient = embeddingClient;
+        this.chatClient = chatClient;
+    }
+
+    public AIChatResponse chat(String query) {
+        String answer = chatClient.call(query);
+        return new AIChatResponse(answer);
+    }
+
+    public AIChatResponse chatWithPrompt(String query) {
+        PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
+        Prompt prompt = promptTemplate.create(Map.of("subject", query));
+        ChatResponse response = chatClient.call(prompt);
+        Generation generation = response.getResult();
+        String answer = (generation != null) ? generation.getOutput().getContent() : "";
+        return new AIChatResponse(answer);
+    }
+
+    public AIChatResponse chatWithSystemPrompt(String query) {
+        SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
+        UserMessage userMessage = new UserMessage("Tell me a joke about " + query);
+        Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
+        ChatResponse response = chatClient.call(prompt);
+        String answer = response.getResult().getOutput().getContent();
+        return new AIChatResponse(answer);
+    }
+
+    public AIChatResponse getEmbeddings(String query) {
+        List<Double> embed = embeddingClient.embed(query);
+        return new AIChatResponse(embed.toString());
+    }
+
+    public ActorsFilms generateAsBean(String actor) {
+        BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class);
+
+        String format = outputParser.getFormat();
+        String template = """
+				Generate the filmography for the actor {actor}.
+				{format}
+				""";
+        PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
+        Prompt prompt = new Prompt(promptTemplate.createMessage());
+        ChatResponse response = chatClient.call(prompt);
+        Generation generation = response.getResult();
+
+        return outputParser.parse(generation.getOutput().getContent());
+    }
+
+    public AIChatResponse ragGenerate(String query) throws IOException {
+
+        // Step 1 - Load JSON document as Documents and save
+        logger.info("Loading JSON as Documents and save");
+        SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient);
+        List<Document> linesDocuments = new ArrayList<>();
+
+        if (restaurantsResource.exists()) { // load existing vector store if exists
+            String contentAsString = restaurantsResource.getContentAsString(StandardCharsets.UTF_8);
+            // Convert lines to Documents in parallel
+            linesDocuments =
+                    contentAsString.lines().parallel().map(Document::new).toList();
+            simpleVectorStore.accept(linesDocuments);
+        }
+
+        // Step 2 retrieve related documents to query
+        logger.info("Retrieving relevant documents");
+        List<Document> similarDocuments =
+                simpleVectorStore.similaritySearch(SearchRequest.query(query).withTopK(2));
+        logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));
+
+        List<String> contentList =
+                similarDocuments.stream().map(Document::getContent).toList();
+        PromptTemplate promptTemplate = new PromptTemplate(ragPromptTemplate);
+        Map<String, Object> promptParameters = new HashMap<>();
+
+        promptParameters.put("input", query);
+        promptParameters.put("documents", String.join("\n", contentList));
+        Prompt prompt = promptTemplate.create(promptParameters);
+
+        ChatResponse response = chatClient.call(prompt);
+        Generation generation = response.getResult();
+        String answer = (generation != null) ? generation.getOutput().getContent() : "";
+        simpleVectorStore.delete(linesDocuments.stream().map(Document::getId).toList());
+        return new AIChatResponse(answer);
+    }
+}
diff --git a/chatmodel-springai/src/main/resources/data/restaurants.json b/chatmodel-springai/src/main/resources/data/restaurants.json
new file mode 100644
index 0000000..eb6ecdf
--- /dev/null
+++ b/chatmodel-springai/src/main/resources/data/restaurants.json
@@ -0,0 +1,10 @@
+{"address": {"building": "1007", "coord": [-73.856077, 40.848447], "street": "Morris Park Ave", "zipcode": "10462"}, "borough": "Bronx", "cuisine": "Bakery", "grades": [{"date": {"$date": 1393804800000}, "grade": "A", "score": 2}, {"date": {"$date": 1378857600000}, "grade": "A", "score": 6}, {"date": {"$date": 1358985600000}, "grade": "A", "score": 10}, {"date": {"$date": 1322006400000}, "grade": "A", "score": 9}, {"date": {"$date": 1299715200000}, "grade": "B", "score": 14}], "name": "Morris Park Bake Shop", "restaurant_id": "30075445"}
+{"address": {"building": "469", "coord": [-73.961704, 40.662942], "street": "Flatbush Avenue", "zipcode": "11225"}, "borough": "Brooklyn", "cuisine": "Hamburgers", "grades": [{"date": {"$date": 1419897600000}, "grade": "A", "score": 8}, {"date": {"$date": 1404172800000}, "grade": "B", "score": 23}, {"date": {"$date": 1367280000000}, "grade": "A", "score": 12}, {"date": {"$date": 1336435200000}, "grade": "A", "score": 12}], "name": "Wendy'S", "restaurant_id": "30112340"}
+{"address": {"building": "351", "coord": [-73.98513559999999, 40.7676919], "street": "West   57 Street", "zipcode": "10019"}, "borough": "Manhattan", "cuisine": "Irish", "grades": [{"date": {"$date": 1409961600000}, "grade": "A", "score": 2}, {"date": {"$date": 1374451200000}, "grade": "A", "score": 11}, {"date": {"$date": 1343692800000}, "grade": "A", "score": 12}, {"date": {"$date": 1325116800000}, "grade": "A", "score": 12}], "name": "Dj Reynolds Pub And Restaurant", "restaurant_id": "30191841"}
+{"address": {"building": "2780", "coord": [-73.98241999999999, 40.579505], "street": "Stillwell Avenue", "zipcode": "11224"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1402358400000}, "grade": "A", "score": 5}, {"date": {"$date": 1370390400000}, "grade": "A", "score": 7}, {"date": {"$date": 1334275200000}, "grade": "A", "score": 12}, {"date": {"$date": 1318377600000}, "grade": "A", "score": 12}], "name": "Riviera Caterer", "restaurant_id": "40356018"}
+{"address": {"building": "97-22", "coord": [-73.8601152, 40.7311739], "street": "63 Road", "zipcode": "11374"}, "borough": "Queens", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1416787200000}, "grade": "Z", "score": 20}, {"date": {"$date": 1358380800000}, "grade": "A", "score": 13}, {"date": {"$date": 1343865600000}, "grade": "A", "score": 13}, {"date": {"$date": 1323907200000}, "grade": "B", "score": 25}], "name": "Tov Kosher Kitchen", "restaurant_id": "40356068"}
+{"address": {"building": "8825", "coord": [-73.8803827, 40.7643124], "street": "Astoria Boulevard", "zipcode": "11369"}, "borough": "Queens", "cuisine": "American ", "grades": [{"date": {"$date": 1416009600000}, "grade": "Z", "score": 38}, {"date": {"$date": 1398988800000}, "grade": "A", "score": 10}, {"date": {"$date": 1362182400000}, "grade": "A", "score": 7}, {"date": {"$date": 1328832000000}, "grade": "A", "score": 13}], "name": "Brunos On The Boulevard", "restaurant_id": "40356151"}
+{"address": {"building": "2206", "coord": [-74.1377286, 40.6119572], "street": "Victory Boulevard", "zipcode": "10314"}, "borough": "Staten Island", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1412553600000}, "grade": "A", "score": 9}, {"date": {"$date": 1400544000000}, "grade": "A", "score": 12}, {"date": {"$date": 1365033600000}, "grade": "A", "score": 12}, {"date": {"$date": 1327363200000}, "grade": "A", "score": 9}], "name": "Kosher Island", "restaurant_id": "40356442"}
+{"address": {"building": "7114", "coord": [-73.9068506, 40.6199034], "street": "Avenue U", "zipcode": "11234"}, "borough": "Brooklyn", "cuisine": "Delicatessen", "grades": [{"date": {"$date": 1401321600000}, "grade": "A", "score": 10}, {"date": {"$date": 1389657600000}, "grade": "A", "score": 10}, {"date": {"$date": 1375488000000}, "grade": "A", "score": 8}, {"date": {"$date": 1342569600000}, "grade": "A", "score": 10}, {"date": {"$date": 1331251200000}, "grade": "A", "score": 13}, {"date": {"$date": 1318550400000}, "grade": "A", "score": 9}], "name": "Wilken'S Fine Food", "restaurant_id": "40356483"}
+{"address": {"building": "6409", "coord": [-74.00528899999999, 40.628886], "street": "11 Avenue", "zipcode": "11219"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1405641600000}, "grade": "A", "score": 12}, {"date": {"$date": 1375142400000}, "grade": "A", "score": 12}, {"date": {"$date": 1360713600000}, "grade": "A", "score": 11}, {"date": {"$date": 1345075200000}, "grade": "A", "score": 2}, {"date": {"$date": 1313539200000}, "grade": "A", "score": 11}], "name": "Regina Caterers", "restaurant_id": "40356649"}
+{"address": {"building": "1839", "coord": [-73.9482609, 40.6408271], "street": "Nostrand Avenue", "zipcode": "11226"}, "borough": "Brooklyn", "cuisine": "Ice Cream, Gelato, Yogurt, Ices", "grades": [{"date": {"$date": 1405296000000}, "grade": "A", "score": 12}, {"date": {"$date": 1373414400000}, "grade": "A", "score": 8}, {"date": {"$date": 1341964800000}, "grade": "A", "score": 5}, {"date": {"$date": 1329955200000}, "grade": "A", "score": 8}], "name": "Taste The Tropics Ice Cream", "restaurant_id": "40356731"}
diff --git a/chatmodel-springai/src/main/resources/rag-prompt-template.st b/chatmodel-springai/src/main/resources/rag-prompt-template.st
new file mode 100644
index 0000000..f46a20b
--- /dev/null
+++ b/chatmodel-springai/src/main/resources/rag-prompt-template.st
@@ -0,0 +1,9 @@
+You are a helpful assistant, conversing with a user about the subjects contained in a set of documents.
+Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer
+isn't found in the DOCUMENTS section, simply state that you don't know the answer.
+
+QUESTION:
+{input}
+
+DOCUMENTS:
+{documents}
diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
index f1afeca..a58e818 100644
--- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
+++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
@@ -35,7 +35,8 @@ void testChat() {
                 .when()
                 .post("/api/ai/chat")
                 .then()
-                .statusCode(200)
+                .statusCode(HttpStatus.SC_OK)
+                .contentType(ContentType.JSON)
                 .body("answer", containsString("Hello!"));
     }
 
@@ -46,7 +47,8 @@ void chatWithPrompt() {
                 .when()
                 .post("/api/ai/chat-with-prompt")
                 .then()
-                .statusCode(200)
+                .statusCode(HttpStatus.SC_OK)
+                .contentType(ContentType.JSON)
                 .body("answer", containsString("Java"));
     }
 
@@ -57,7 +59,8 @@ void chatWithSystemPrompt() {
                 .when()
                 .post("/api/ai/chat-with-system-prompt")
                 .then()
-                .statusCode(200)
+                .statusCode(HttpStatus.SC_OK)
+                .contentType(ContentType.JSON)
                 .body("answer", containsString("cricket"));
     }
 
@@ -72,4 +75,16 @@ void outputParser() {
                 .body("actor", is("Jr NTR"))
                 .body("movies", hasSize(greaterThanOrEqualTo(25)));
     }
+
+    @Test
+    void ragWithSimpleStore() {
+        given().contentType(ContentType.JSON)
+                .body(new AIChatRequest("which is the restaurant with highest grade that has cuisine as American ?"))
+                .when()
+                .post("/api/ai/rag")
+                .then()
+                .statusCode(HttpStatus.SC_OK)
+                .contentType(ContentType.JSON)
+                .body("answer", containsString("American cuisine is \"Regina Caterers\""));
+    }
 }