From c04d756d9091edd63943357b04e29a810bc27a61 Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Tue, 24 Dec 2024 17:41:48 +0000 Subject: [PATCH] feat : adds streaming endpoint --- .../example/ai/controller/ChatController.java | 12 ++++++----- .../com/example/ai/service/ChatService.java | 20 +++++++------------ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index df70268..32bbbe2 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -2,6 +2,7 @@ import com.example.ai.model.request.AIChatRequest; import com.example.ai.model.response.AIChatResponse; +import com.example.ai.model.response.AIStreamChatResponse; import com.example.ai.model.response.ActorsFilms; import com.example.ai.service.ChatService; import org.springframework.web.bind.annotation.GetMapping; @@ -10,6 +11,7 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; @RestController @RequestMapping("/api/ai") @@ -56,9 +58,9 @@ AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) { return chatService.ragGenerate(aiChatRequest.query()); } - // @PostMapping("/chat/stream") - // AIStreamChatResponse streamChat(@RequestBody AIChatRequest aiChatRequest) { - // Flux streamChat = chatService.streamChat(aiChatRequest.query()); - // return new AIStreamChatResponse(streamChat); - // } + @PostMapping("/chat/stream") + AIStreamChatResponse streamChat(@RequestBody AIChatRequest aiChatRequest) { + Flux streamChat = chatService.streamChat(aiChatRequest.query()); + return new AIStreamChatResponse(streamChat); + } } diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java index 5bf4bc2..ca95e7f 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java +++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java @@ -10,8 +10,6 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.AssistantPromptTemplate; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; @@ -26,6 +24,7 @@ import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; @Service public class ChatService { @@ -57,9 +56,7 @@ public AIChatResponse chat(String query) { public AIChatResponse chatWithPrompt(String query) { PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}"); Prompt prompt = promptTemplate.create(Map.of("subject", query)); - ChatResponse response = chatClient.prompt(prompt).call().chatResponse(); - Generation generation = response.getResult(); - String answer = (generation != null) ? generation.getOutput().getContent() : ""; + String answer = chatClient.prompt(prompt).call().content(); return new AIChatResponse(answer); } @@ -67,17 +64,14 @@ public AIChatResponse chatWithSystemPrompt(String query) { SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot"); UserMessage userMessage = new UserMessage("Tell me a joke about " + query); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatClient.prompt(prompt).call().chatResponse(); - String answer = response.getResult().getOutput().getContent(); + String answer = chatClient.prompt(prompt).call().content(); return new AIChatResponse(answer); } public AIChatResponse analyzeSentiment(String query) { AssistantPromptTemplate promptTemplate = new AssistantPromptTemplate(SENTIMENT_ANALYSIS_TEMPLATE); Prompt prompt = promptTemplate.create(Map.of("query", query)); - ChatResponse response = chatClient.prompt(prompt).call().chatResponse(); - Generation generation = response.getResult(); - String answer = (generation != null) ? generation.getOutput().getContent() : ""; + String answer = chatClient.prompt(prompt).call().content(); return new AIChatResponse(answer); } @@ -138,7 +132,7 @@ public AIChatResponse ragGenerate(String query) { return new AIChatResponse(response); } - // public Flux streamChat(String query) { - // return streamingChatClient.stream(query); - // } + public Flux streamChat(String query) { + return chatClient.prompt(query).stream().content(); + } }