diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java index d6fa248..c2a9289 100644 --- a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java @@ -85,9 +85,14 @@ EmbeddingStore embeddingStore( .build(); // 2. Load an example document (medicaid-wa-faqs.pdf) - Resource pdfResource = resourceLoader.getResource("classpath:medicaid-wa-faqs.pdf"); + Resource pdfResource = resourceLoader.getResource("classpath:Rohit.pdf"); Document document = loadDocument(pdfResource.getFile().toPath(), new ApachePdfBoxDocumentParser()); + // URL url = new URL("https://en.wikipedia.org/wiki/MS_Dhoni"); + // Document htmlDocument = UrlDocumentLoader.load(url, new TextDocumentParser()); + // HtmlTextExtractor transformer = new HtmlTextExtractor(null, null, true); + // Document dhoniDocument = transformer.transform(htmlDocument); + // 3. Split the document into segments 500 tokens each // 4. Convert segments into embeddings // 5. Store embeddings into embedding store @@ -99,7 +104,7 @@ EmbeddingStore embeddingStore( .embeddingModel(embeddingModel) .embeddingStore(embeddingStore) .build(); - ingestor.ingest(document); + ingestor.ingest(document /*, dhoniDocument*/); return embeddingStore; } diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java index fe480e2..f941256 100644 --- a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java @@ -1,25 +1,18 @@ package com.learning.ai.config; import com.learning.ai.domain.response.AICustomerSupportResponse; -import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.V; public interface AICustomerSupportAgent { - @SystemMessage({ + @UserMessage({ """ - You're assisting with questions about services offered by Carina. - Carina is a two-sided healthcare marketplace focusing on home care aides (caregivers) - and their Medicaid in-home care clients (adults and children with developmental disabilities and low income elderly population). - Carina's mission is to build online tools to bring good jobs to care workers, so care workers can provide the - best possible care for those who need it. + Tell me about {{question}}? as of {{current_date}} - Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately. - If unsure, simply state that you don't know. - - DOCUMENTS: - {documents} - """ + Use the following information to answer the question: + {{information}} + """ }) - AICustomerSupportResponse chat(@V("documents") String documents); + AICustomerSupportResponse chat(@V("question") String question, @V("information") String information); } diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java index 6860ad5..6f08ff9 100644 --- a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java @@ -1,8 +1,8 @@ package com.learning.ai.controller; -import com.learning.ai.config.AICustomerSupportAgent; import com.learning.ai.domain.request.AIChatRequest; import com.learning.ai.domain.response.AICustomerSupportResponse; +import com.learning.ai.service.CustomerSupportService; import jakarta.validation.Valid; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.PostMapping; @@ -15,14 +15,14 @@ @Validated public class CustomerSupportController { - private final AICustomerSupportAgent aiCustomerSupportAgent; + private final CustomerSupportService customerSupportService; - public CustomerSupportController(AICustomerSupportAgent aiCustomerSupportAgent) { - this.aiCustomerSupportAgent = aiCustomerSupportAgent; + public CustomerSupportController(CustomerSupportService customerSupportService) { + this.customerSupportService = customerSupportService; } @PostMapping("/chat") public AICustomerSupportResponse customerSupportChat(@RequestBody @Valid AIChatRequest aiChatRequest) { - return aiCustomerSupportAgent.chat(aiChatRequest.question()); + return customerSupportService.chat(aiChatRequest.question()); } } diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java index 033cc28..2c455cd 100644 --- a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java @@ -1,3 +1,5 @@ package com.learning.ai.domain.response; -public record AICustomerSupportResponse(String response) {} +import java.util.List; + +public record AICustomerSupportResponse(String name, int age, List records, List trophiesWon) {} diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java new file mode 100644 index 0000000..ef1fe2a --- /dev/null +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java @@ -0,0 +1,38 @@ +package com.learning.ai.service; + +import com.learning.ai.config.AICustomerSupportAgent; +import com.learning.ai.domain.response.AICustomerSupportResponse; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import java.util.List; +import org.springframework.stereotype.Service; + +@Service +public class CustomerSupportService { + + private final EmbeddingModel embeddingModel; + private final EmbeddingStore embeddingStore; + private final AICustomerSupportAgent aiCustomerSupportAgent; + + public CustomerSupportService( + EmbeddingModel embeddingModel, + EmbeddingStore embeddingStore, + AICustomerSupportAgent aiCustomerSupportAgent) { + this.embeddingModel = embeddingModel; + this.embeddingStore = embeddingStore; + this.aiCustomerSupportAgent = aiCustomerSupportAgent; + } + + public AICustomerSupportResponse chat(String question) { + + Embedding queryEmbedding = embeddingModel.embed(question).content(); + List> relevant = embeddingStore.findRelevant(queryEmbedding, 1); + EmbeddingMatch embeddingMatch = relevant.get(0); + + String embeddedText = embeddingMatch.embedded().text(); + return aiCustomerSupportAgent.chat(question, embeddedText); + } +} diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/Rohit.pdf b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/Rohit.pdf new file mode 100644 index 0000000..b2215c8 Binary files /dev/null and b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/Rohit.pdf differ diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/medicaid-wa-faqs.pdf b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/medicaid-wa-faqs.pdf deleted file mode 100644 index fcfcfc0..0000000 Binary files a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/medicaid-wa-faqs.pdf and /dev/null differ diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java index 0e4ecc2..af996aa 100644 --- a/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java +++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java @@ -1,7 +1,7 @@ package com.learning.ai; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.is; import com.learning.ai.domain.request.AIChatRequest; import io.restassured.RestAssured; @@ -27,24 +27,15 @@ public void setUp() { } @Test - void whenRequestPost_thenOK() { + void whenRequestGetFromPdf_thenOK() { given().contentType(ContentType.JSON) - .body(new AIChatRequest( - "what should I know about the transition to consumer direct care network washington?")) - .when() - .request(Method.POST, "/api/ai/chat") - .then() - .statusCode(HttpStatus.SC_OK); - } - - @Test - void whenRequestGetTime_thenOK() { - given().contentType(ContentType.JSON) - .body(new AIChatRequest("What is the time now?")) + .body(new AIChatRequest("Who is Rohit")) .when() .request(Method.POST, "/api/ai/chat") .then() .statusCode(HttpStatus.SC_OK) - .body("response", notNullValue()); + .body("name", is("Rohit Gurunath Sharma")) + .log() + .all(); } }