Skip to content

Commit

Permalink
feat : convert from get to post endpoint (#27)
Browse files Browse the repository at this point in the history
* feat : convert from get to post endpoint

* rename : project name

* add request and response, convert endpoint to post
  • Loading branch information
rajadilipkolli authored Mar 28, 2024
1 parent 4931dae commit 6adb623
Show file tree
Hide file tree
Showing 29 changed files with 143 additions and 91 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
name: llm-rag-with-langchain4j-spring-boot CI Build
name: rag-langchain4j-AllMiniLmL6V2-llm CI Build

on:
push:
paths:
- "llm-rag-with-langchain4j-spring-boot/**"
- "rag-langchain4j-AllMiniLmL6V2-llm/**"
branches: [main]
pull_request:
paths:
- "llm-rag-with-langchain4j-spring-boot/**"
- "rag-langchain4j-AllMiniLmL6V2-llm/**"
types:
- opened
- synchronize
Expand All @@ -19,7 +19,7 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
working-directory: llm-rag-with-langchain4j-spring-boot
working-directory: rag-langchain4j-AllMiniLmL6V2-llm
strategy:
matrix:
distribution: [ 'temurin' ]
Expand Down

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>org.example.ai</groupId>
<artifactId>llm-rag-with-langchain4j-spring-boot</artifactId>
<artifactId>rag-langchain4j-AllMiniLmL6V2-llm</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>llm-rag-with-langchain4j-spring-boot</name>
<name>rag-langchain4j-AllMiniLmL6V2-llm</name>
<description>Demo project for Spring Boot</description>

<properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;

import com.zaxxer.hikari.HikariDataSource;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
Expand All @@ -22,7 +21,7 @@
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import java.io.IOException;
import java.net.URI;
import javax.sql.DataSource;
import org.springframework.boot.autoconfigure.jdbc.JdbcConnectionDetails;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
Expand Down Expand Up @@ -62,13 +61,13 @@ EmbeddingModel embeddingModel() {

@Bean
EmbeddingStore<TextSegment> embeddingStore(
EmbeddingModel embeddingModel, ResourceLoader resourceLoader, DataSource dataSource) throws IOException {
EmbeddingModel embeddingModel, ResourceLoader resourceLoader, JdbcConnectionDetails jdbcConnectionDetails)
throws IOException {

// Normally, you would already have your embedding store filled with your data.
// However, for the purpose of this demonstration, we will:

HikariDataSource hikariDataSource = (HikariDataSource) dataSource;
String jdbcUrl = hikariDataSource.getJdbcUrl();
String jdbcUrl = jdbcConnectionDetails.getJdbcUrl();
URI uri = URI.create(jdbcUrl.substring(5));
String host = uri.getHost();
int dbPort = uri.getPort();
Expand All @@ -78,8 +77,8 @@ EmbeddingStore<TextSegment> embeddingStore(
EmbeddingStore<TextSegment> embeddingStore = PgVectorEmbeddingStore.builder()
.host(host)
.port(dbPort != -1 ? dbPort : 5432)
.user(hikariDataSource.getUsername())
.password(hikariDataSource.getPassword())
.user(jdbcConnectionDetails.getUsername())
.password(jdbcConnectionDetails.getPassword())
.database(path.substring(1))
.table("ai_vector_store")
.dimension(384)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.learning.ai.config;

import com.learning.ai.domain.AICustomerSupportResponse;
import com.learning.ai.domain.response.AICustomerSupportResponse;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.V;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class ChatTools {
/**
* This tool is available to {@link AICustomerSupportAgent}
*/
@Tool
@Tool("chatAssistantTools")
String currentTime() {
log.info("Inside ChatTools");
return LocalTime.now().toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@
import org.springframework.context.annotation.Configuration;

@Configuration(proxyBeanMethods = false)
@OpenAPIDefinition(info = @Info(title = "llm-rag-with-langchain4j", version = "v1.0.0"), servers = @Server(url = "/"))
@OpenAPIDefinition(
info = @Info(title = "rag-langchain4j-AllMiniLmL6V2-llm", version = "v1.0.0"),
servers = @Server(url = "/"))
public class SwaggerConfig {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.learning.ai.controller;

import com.learning.ai.config.AICustomerSupportAgent;
import com.learning.ai.domain.request.AIChatRequest;
import com.learning.ai.domain.response.AICustomerSupportResponse;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/api/ai")
@Validated
public class CustomerSupportController {

private final AICustomerSupportAgent aiCustomerSupportAgent;

public CustomerSupportController(AICustomerSupportAgent aiCustomerSupportAgent) {
this.aiCustomerSupportAgent = aiCustomerSupportAgent;
}

@PostMapping("/chat")
public AICustomerSupportResponse customerSupportChat(@RequestBody @Valid AIChatRequest aiChatRequest) {
return aiCustomerSupportAgent.chat(aiChatRequest.question());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.learning.ai.domain.request;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import java.io.Serializable;

public record AIChatRequest(
@NotBlank(message = "Query cannot be empty")
@Size(max = 800, message = "Query exceeds maximum length")
@Pattern(regexp = "^[a-zA-Z0-9 ?]*$", message = "Invalid characters in query")
String question)
implements Serializable {}
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package com.learning.ai.domain;
package com.learning.ai.domain.response;

public record AICustomerSupportResponse(String response) {}
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
spring.application.name=rag-langchain4j-AllMiniLmL6V2-llm

langchain4j.open-ai.chat-model.api-key=demo
langchain4j.open-ai.chat-model.model-name=gpt-3.5-turbo
langchain4j.open-ai.chat-model.temperature=0.7
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.learning.ai;

import static io.restassured.RestAssured.given;
import static io.restassured.RestAssured.when;
import static org.hamcrest.Matchers.notNullValue;

import com.learning.ai.domain.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import io.restassured.http.Method;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -26,15 +27,22 @@ public void setUp() {
}

@Test
void whenRequestGet_thenOK() {
when().request(Method.GET, "/api/chat").then().statusCode(HttpStatus.SC_OK);
void whenRequestPost_thenOK() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest(
"what should I know about the transition to consumer direct care network washington?"))
.when()
.request(Method.POST, "/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK);
}

@Test
void whenRequestGetTime_thenOK() {
given().param("message", "What is the time now?")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("What is the time now?"))
.when()
.request(Method.GET, "/api/chat")
.request(Method.POST, "/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.body("response", notNullValue());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package com.learning.ai.llmragwithspringai.controller;

import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import com.learning.ai.llmragwithspringai.model.response.AIChatResponse;
import com.learning.ai.llmragwithspringai.service.AIChatService;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import java.util.Map;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand All @@ -22,14 +21,9 @@ public AiController(AIChatService aiChatService) {
this.aiChatService = aiChatService;
}

@GetMapping("/chat")
Map<String, String> ragService(
@RequestParam
@NotBlank(message = "Query cannot be empty")
@Size(max = 255, message = "Query exceeds maximum length")
@Pattern(regexp = "^[a-zA-Z0-9 ]*$", message = "Invalid characters in query")
String question) {
String chatResponse = aiChatService.chat(question);
return Map.of("response", chatResponse);
@PostMapping("/chat")
AIChatResponse ragService(@Valid @RequestBody AIChatRequest aiChatRequest) {
String chatResponse = aiChatService.chat(aiChatRequest.question());
return new AIChatResponse(chatResponse);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.learning.ai.llmragwithspringai.model.request;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import java.io.Serializable;

public record AIChatRequest(
@NotBlank(message = "Query cannot be empty")
@Size(max = 800, message = "Query exceeds maximum length")
@Pattern(regexp = "^[a-zA-Z0-9 ?]*$", message = "Invalid characters in query")
String question)
implements Serializable {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.learning.ai.llmragwithspringai.model.response;

import java.io.Serializable;

public record AIChatResponse(String response) implements Serializable {}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ public class AIChatService {

private static final String template =
"""
You're assisting with questions about cricket
You're assisting with questions about cricketers
Cricket is a bat-and-ball game that is played between two teams of eleven players on a field at the centre of which is a 22-yard (20-metre) pitch with a wicket at each end,
each comprising two bails balanced on three stumps.
Two players from the batting team (the striker and nonstriker) stand in front of either wicket,
with one player from the fielding team (the bowler) bowling the ball towards the striker's wicket from the opposite end of the pitch.
The striker's goal is to hit the bowled ball and then switch places with the nonstriker,
with the batting team scoring one run for each exchange.
The striker's goal is to hit the bowled ball and then switch places with the nonstriker, with the batting team scoring one run for each exchange.
Runs are also scored when the ball reaches or crosses the boundary of the field or when the ball is bowled illegally.
Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
Expand All @@ -44,16 +43,16 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {
this.vectorStore = vectorStore;
}

public String chat(String message) {
public String chat(String searchQuery) {
// Querying the VectorStore using natural language looking for the information about info asked.
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(message);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(searchQuery);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
// Constructing the systemMessage to indicate the AI model to use the passed information
// to answer the question.
Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents));
UserMessage userMessage = new UserMessage(message);
UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
Expand Down
Loading

0 comments on commit 6adb623

Please sign in to comment.