Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : implement simple RAG #38

Merged
merged 2 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@
import com.example.ai.model.request.AIChatRequest;
import com.example.ai.model.response.AIChatResponse;
import com.example.ai.model.response.ActorsFilms;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.parser.BeanOutputParser;
import com.example.ai.service.ChatService;
import java.io.IOException;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand All @@ -25,61 +16,39 @@
@RequestMapping("/api/ai")
public class ChatController {

private final ChatClient chatClient;
private final ChatService chatService;

private final EmbeddingClient embeddingClient;

ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) {
this.chatClient = chatClient;
this.embeddingClient = embeddingClient;
ChatController(ChatService chatService) {
this.chatService = chatService;
}

@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
var answer = chatClient.call(aiChatRequest.query());
return new AIChatResponse(answer);
return chatService.chat(aiChatRequest.query());
}

@PostMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query()));
ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
return new AIChatResponse(answer);
return chatService.chatWithPrompt(aiChatRequest.query());
}

@PostMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query());
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return new AIChatResponse(answer);
return chatService.chatWithSystemPrompt(aiChatRequest.query());
}

@PostMapping("/emebedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
List<Double> embed = embeddingClient.embed(aiChatRequest.query());
return new AIChatResponse(embed.toString());
return chatService.getEmbeddings(aiChatRequest.query());
}

@GetMapping("/output")
public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NTR") String actor) {
BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class);

String format = outputParser.getFormat();
String template = """
Generate the filmography for the actor {actor}.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();
return chatService.generateAsBean(actor);
}

return outputParser.parse(generation.getOutput().getContent());
@PostMapping("/rag")
AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) throws IOException {
return chatService.ragGenerate(aiChatRequest.query());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package com.example.ai.service;

import com.example.ai.model.response.AIChatResponse;
import com.example.ai.model.response.ActorsFilms;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

@Service
public class ChatService {

private static final Logger logger = LoggerFactory.getLogger(ChatService.class);

@Value("classpath:/data/restaurants.json")
private Resource restaurantsResource;

@Value("classpath:/rag-prompt-template.st")
private Resource ragPromptTemplate;

private final EmbeddingClient embeddingClient;
private final ChatClient chatClient;

public ChatService(EmbeddingClient embeddingClient, ChatClient chatClient) {
this.embeddingClient = embeddingClient;
this.chatClient = chatClient;
}

public AIChatResponse chat(String query) {
String answer = chatClient.call(query);
return new AIChatResponse(answer);
}

public AIChatResponse chatWithPrompt(String query) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", query));
ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
return new AIChatResponse(answer);
}

public AIChatResponse chatWithSystemPrompt(String query) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + query);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return new AIChatResponse(answer);
}

public AIChatResponse getEmbeddings(String query) {
List<Double> embed = embeddingClient.embed(query);
return new AIChatResponse(embed.toString());
}

public ActorsFilms generateAsBean(String actor) {
BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class);

String format = outputParser.getFormat();
String template = """
Generate the filmography for the actor {actor}.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();

return outputParser.parse(generation.getOutput().getContent());
}

public AIChatResponse ragGenerate(String query) throws IOException {

// Step 1 - Load JSON document as Documents and save
logger.info("Loading JSON as Documents and save");
SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient);
List<Document> linesDocuments = new ArrayList<>();

if (restaurantsResource.exists()) { // load existing vector store if exists
String contentAsString = restaurantsResource.getContentAsString(StandardCharsets.UTF_8);
// Convert lines to Documents in parallel
linesDocuments =
contentAsString.lines().parallel().map(Document::new).toList();
simpleVectorStore.accept(linesDocuments);
}

// Step 2 retrieve related documents to query
logger.info("Retrieving relevant documents");
List<Document> similarDocuments =
simpleVectorStore.similaritySearch(SearchRequest.query(query).withTopK(2));
logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));

List<String> contentList =
similarDocuments.stream().map(Document::getContent).toList();
PromptTemplate promptTemplate = new PromptTemplate(ragPromptTemplate);
Map<String, Object> promptParameters = new HashMap<>();

promptParameters.put("input", query);
promptParameters.put("documents", String.join("\n", contentList));
Prompt prompt = promptTemplate.create(promptParameters);

ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
simpleVectorStore.delete(linesDocuments.stream().map(Document::getId).toList());
return new AIChatResponse(answer);
}
}
10 changes: 10 additions & 0 deletions chatmodel-springai/src/main/resources/data/restaurants.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{"address": {"building": "1007", "coord": [-73.856077, 40.848447], "street": "Morris Park Ave", "zipcode": "10462"}, "borough": "Bronx", "cuisine": "Bakery", "grades": [{"date": {"$date": 1393804800000}, "grade": "A", "score": 2}, {"date": {"$date": 1378857600000}, "grade": "A", "score": 6}, {"date": {"$date": 1358985600000}, "grade": "A", "score": 10}, {"date": {"$date": 1322006400000}, "grade": "A", "score": 9}, {"date": {"$date": 1299715200000}, "grade": "B", "score": 14}], "name": "Morris Park Bake Shop", "restaurant_id": "30075445"}
{"address": {"building": "469", "coord": [-73.961704, 40.662942], "street": "Flatbush Avenue", "zipcode": "11225"}, "borough": "Brooklyn", "cuisine": "Hamburgers", "grades": [{"date": {"$date": 1419897600000}, "grade": "A", "score": 8}, {"date": {"$date": 1404172800000}, "grade": "B", "score": 23}, {"date": {"$date": 1367280000000}, "grade": "A", "score": 12}, {"date": {"$date": 1336435200000}, "grade": "A", "score": 12}], "name": "Wendy'S", "restaurant_id": "30112340"}
{"address": {"building": "351", "coord": [-73.98513559999999, 40.7676919], "street": "West 57 Street", "zipcode": "10019"}, "borough": "Manhattan", "cuisine": "Irish", "grades": [{"date": {"$date": 1409961600000}, "grade": "A", "score": 2}, {"date": {"$date": 1374451200000}, "grade": "A", "score": 11}, {"date": {"$date": 1343692800000}, "grade": "A", "score": 12}, {"date": {"$date": 1325116800000}, "grade": "A", "score": 12}], "name": "Dj Reynolds Pub And Restaurant", "restaurant_id": "30191841"}
{"address": {"building": "2780", "coord": [-73.98241999999999, 40.579505], "street": "Stillwell Avenue", "zipcode": "11224"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1402358400000}, "grade": "A", "score": 5}, {"date": {"$date": 1370390400000}, "grade": "A", "score": 7}, {"date": {"$date": 1334275200000}, "grade": "A", "score": 12}, {"date": {"$date": 1318377600000}, "grade": "A", "score": 12}], "name": "Riviera Caterer", "restaurant_id": "40356018"}
{"address": {"building": "97-22", "coord": [-73.8601152, 40.7311739], "street": "63 Road", "zipcode": "11374"}, "borough": "Queens", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1416787200000}, "grade": "Z", "score": 20}, {"date": {"$date": 1358380800000}, "grade": "A", "score": 13}, {"date": {"$date": 1343865600000}, "grade": "A", "score": 13}, {"date": {"$date": 1323907200000}, "grade": "B", "score": 25}], "name": "Tov Kosher Kitchen", "restaurant_id": "40356068"}
{"address": {"building": "8825", "coord": [-73.8803827, 40.7643124], "street": "Astoria Boulevard", "zipcode": "11369"}, "borough": "Queens", "cuisine": "American ", "grades": [{"date": {"$date": 1416009600000}, "grade": "Z", "score": 38}, {"date": {"$date": 1398988800000}, "grade": "A", "score": 10}, {"date": {"$date": 1362182400000}, "grade": "A", "score": 7}, {"date": {"$date": 1328832000000}, "grade": "A", "score": 13}], "name": "Brunos On The Boulevard", "restaurant_id": "40356151"}
{"address": {"building": "2206", "coord": [-74.1377286, 40.6119572], "street": "Victory Boulevard", "zipcode": "10314"}, "borough": "Staten Island", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1412553600000}, "grade": "A", "score": 9}, {"date": {"$date": 1400544000000}, "grade": "A", "score": 12}, {"date": {"$date": 1365033600000}, "grade": "A", "score": 12}, {"date": {"$date": 1327363200000}, "grade": "A", "score": 9}], "name": "Kosher Island", "restaurant_id": "40356442"}
{"address": {"building": "7114", "coord": [-73.9068506, 40.6199034], "street": "Avenue U", "zipcode": "11234"}, "borough": "Brooklyn", "cuisine": "Delicatessen", "grades": [{"date": {"$date": 1401321600000}, "grade": "A", "score": 10}, {"date": {"$date": 1389657600000}, "grade": "A", "score": 10}, {"date": {"$date": 1375488000000}, "grade": "A", "score": 8}, {"date": {"$date": 1342569600000}, "grade": "A", "score": 10}, {"date": {"$date": 1331251200000}, "grade": "A", "score": 13}, {"date": {"$date": 1318550400000}, "grade": "A", "score": 9}], "name": "Wilken'S Fine Food", "restaurant_id": "40356483"}
{"address": {"building": "6409", "coord": [-74.00528899999999, 40.628886], "street": "11 Avenue", "zipcode": "11219"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1405641600000}, "grade": "A", "score": 12}, {"date": {"$date": 1375142400000}, "grade": "A", "score": 12}, {"date": {"$date": 1360713600000}, "grade": "A", "score": 11}, {"date": {"$date": 1345075200000}, "grade": "A", "score": 2}, {"date": {"$date": 1313539200000}, "grade": "A", "score": 11}], "name": "Regina Caterers", "restaurant_id": "40356649"}
{"address": {"building": "1839", "coord": [-73.9482609, 40.6408271], "street": "Nostrand Avenue", "zipcode": "11226"}, "borough": "Brooklyn", "cuisine": "Ice Cream, Gelato, Yogurt, Ices", "grades": [{"date": {"$date": 1405296000000}, "grade": "A", "score": 12}, {"date": {"$date": 1373414400000}, "grade": "A", "score": 8}, {"date": {"$date": 1341964800000}, "grade": "A", "score": 5}, {"date": {"$date": 1329955200000}, "grade": "A", "score": 8}], "name": "Taste The Tropics Ice Cream", "restaurant_id": "40356731"}
9 changes: 9 additions & 0 deletions chatmodel-springai/src/main/resources/rag-prompt-template.st
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
You are a helpful assistant, conversing with a user about the subjects contained in a set of documents.
Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer
isn't found in the DOCUMENTS section, simply state that you don't know the answer.

QUESTION:
{input}

DOCUMENTS:
{documents}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void testChat() {
.when()
.post("/api/ai/chat")
.then()
.statusCode(200)
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("Hello!"));
}

Expand All @@ -46,7 +47,8 @@ void chatWithPrompt() {
.when()
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(200)
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("Java"));
}

Expand All @@ -57,7 +59,8 @@ void chatWithSystemPrompt() {
.when()
.post("/api/ai/chat-with-system-prompt")
.then()
.statusCode(200)
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("cricket"));
}

Expand All @@ -72,4 +75,16 @@ void outputParser() {
.body("actor", is("Jr NTR"))
.body("movies", hasSize(greaterThanOrEqualTo(25)));
}

@Test
void ragWithSimpleStore() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("which is the restaurant with highest grade that has cuisine as American ?"))
.when()
.post("/api/ai/rag")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("American cuisine is \"Regina Caterers\""));
}
}