Skip to content

Commit

Permalink
feat : implement RAG (#29)
Browse files Browse the repository at this point in the history
* feat : convert from get to post endpoint

* rename : project name

* add request and response, convert endpoint to post

* feat : implement RAG

* remove unused file
  • Loading branch information
rajadilipkolli authored Mar 28, 2024
1 parent 6adb623 commit 6d00a63
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ EmbeddingStore<TextSegment> embeddingStore(
.build();

// 2. Load an example document (medicaid-wa-faqs.pdf)
Resource pdfResource = resourceLoader.getResource("classpath:medicaid-wa-faqs.pdf");
Resource pdfResource = resourceLoader.getResource("classpath:Rohit.pdf");
Document document = loadDocument(pdfResource.getFile().toPath(), new ApachePdfBoxDocumentParser());

// URL url = new URL("https://en.wikipedia.org/wiki/MS_Dhoni");
// Document htmlDocument = UrlDocumentLoader.load(url, new TextDocumentParser());
// HtmlTextExtractor transformer = new HtmlTextExtractor(null, null, true);
// Document dhoniDocument = transformer.transform(htmlDocument);

// 3. Split the document into segments 500 tokens each
// 4. Convert segments into embeddings
// 5. Store embeddings into embedding store
Expand All @@ -99,7 +104,7 @@ EmbeddingStore<TextSegment> embeddingStore(
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
ingestor.ingest(document);
ingestor.ingest(document /*, dhoniDocument*/);

return embeddingStore;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
package com.learning.ai.config;

import com.learning.ai.domain.response.AICustomerSupportResponse;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;

public interface AICustomerSupportAgent {

@SystemMessage({
@UserMessage({
"""
You're assisting with questions about services offered by Carina.
Carina is a two-sided healthcare marketplace focusing on home care aides (caregivers)
and their Medicaid in-home care clients (adults and children with developmental disabilities and low income elderly population).
Carina's mission is to build online tools to bring good jobs to care workers, so care workers can provide the
best possible care for those who need it.
Tell me about {{question}}? as of {{current_date}}
Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
If unsure, simply state that you don't know.
DOCUMENTS:
{documents}
"""
Use the following information to answer the question:
{{information}}
"""
})
AICustomerSupportResponse chat(@V("documents") String documents);
AICustomerSupportResponse chat(@V("question") String question, @V("information") String information);
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.learning.ai.controller;

import com.learning.ai.config.AICustomerSupportAgent;
import com.learning.ai.domain.request.AIChatRequest;
import com.learning.ai.domain.response.AICustomerSupportResponse;
import com.learning.ai.service.CustomerSupportService;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
Expand All @@ -15,14 +15,14 @@
@Validated
public class CustomerSupportController {

private final AICustomerSupportAgent aiCustomerSupportAgent;
private final CustomerSupportService customerSupportService;

public CustomerSupportController(AICustomerSupportAgent aiCustomerSupportAgent) {
this.aiCustomerSupportAgent = aiCustomerSupportAgent;
public CustomerSupportController(CustomerSupportService customerSupportService) {
this.customerSupportService = customerSupportService;
}

@PostMapping("/chat")
public AICustomerSupportResponse customerSupportChat(@RequestBody @Valid AIChatRequest aiChatRequest) {
return aiCustomerSupportAgent.chat(aiChatRequest.question());
return customerSupportService.chat(aiChatRequest.question());
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package com.learning.ai.domain.response;

public record AICustomerSupportResponse(String response) {}
import java.util.List;

public record AICustomerSupportResponse(String name, int age, List<String> records, List<String> trophiesWon) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.learning.ai.service;

import com.learning.ai.config.AICustomerSupportAgent;
import com.learning.ai.domain.response.AICustomerSupportResponse;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.List;
import org.springframework.stereotype.Service;

@Service
public class CustomerSupportService {

private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
private final AICustomerSupportAgent aiCustomerSupportAgent;

public CustomerSupportService(
EmbeddingModel embeddingModel,
EmbeddingStore<TextSegment> embeddingStore,
AICustomerSupportAgent aiCustomerSupportAgent) {
this.embeddingModel = embeddingModel;
this.embeddingStore = embeddingStore;
this.aiCustomerSupportAgent = aiCustomerSupportAgent;
}

public AICustomerSupportResponse chat(String question) {

Embedding queryEmbedding = embeddingModel.embed(question).content();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(queryEmbedding, 1);
EmbeddingMatch<TextSegment> embeddingMatch = relevant.get(0);

String embeddedText = embeddingMatch.embedded().text();
return aiCustomerSupportAgent.chat(question, embeddedText);
}
}
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.learning.ai;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.is;

import com.learning.ai.domain.request.AIChatRequest;
import io.restassured.RestAssured;
Expand All @@ -27,24 +27,15 @@ public void setUp() {
}

@Test
void whenRequestPost_thenOK() {
void whenRequestGetFromPdf_thenOK() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(
"what should I know about the transition to consumer direct care network washington?"))
.when()
.request(Method.POST, "/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK);
}

@Test
void whenRequestGetTime_thenOK() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("What is the time now?"))
.body(new AIChatRequest("Who is Rohit"))
.when()
.request(Method.POST, "/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.body("response", notNullValue());
.body("name", is("Rohit Gurunath Sharma"))
.log()
.all();
}
}

0 comments on commit 6d00a63

Please sign in to comment.