Skip to content

Commit

Permalink
Merge pull request #9 from 6drf21e/experimental/chat-history
Browse files Browse the repository at this point in the history
[Experimental] Add Chat History Feature
  • Loading branch information
fatwang2 authored Mar 27, 2024
2 parents fbad600 + 8526325 commit c98b899
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 20 deletions.
140 changes: 125 additions & 15 deletions search4all.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
# does not respond within this time, we will return an error.
DEFAULT_SEARCH_ENGINE_TIMEOUT = 5

# Default length of the recorded chat history, i.e. the slice size used to
# trim the stored history list when a new entry is appended.
MAX_HISTORY_LEN = 10


# If the user did not provide a query, we will use this default query.
_default_query = "Who said 'live long and prosper'?"
Expand Down Expand Up @@ -104,6 +107,33 @@ def get(self, key: str):
def put(self, key: str, value: str):
# Store ``value`` under ``key`` in the backing store, then commit so the
# write is flushed immediately.
# NOTE(review): the query handler also passes a dict
# ({"query": ..., "txt": ...}) as ``value``, so the ``str`` annotation
# looks stale -- confirm and widen it if so.
self._db[key] = value
self._db.commit()

def append(self, key: str, value):
    """Append one chat-history entry to the list stored under ``key``.

    A missing key is treated as an empty history. After ``value`` is added,
    only the most recent ``MAX_HISTORY_LEN`` entries are written back, so
    the stored list never grows beyond that cap.

    Args:
        key: Storage key of the history list.
        value: The entry to record; it must be storable by the backing store.
    """
    # Build the new list separately and write it back with one assignment
    # instead of mutating the fetched object in place.
    # NOTE(review): presumably ``self._db`` returns deserialized copies, in
    # which case in-place mutation would not persist -- confirm.
    history = list(self._db.get(key, []))
    history.append(value)
    # Trim after appending so the newest entry counts toward the cap.
    self._db[key] = history[-MAX_HISTORY_LEN:]
    self._db.commit()

# Output formatting helpers
def extract_all_sections(text: str) -> tuple:
    """Split combined output text into its three marker-delimited sections.

    The expected layout is::

        <search results>__LLM_RESPONSE__<answer>[__RELATED_QUESTIONS__<questions>]

    The text is split at the first occurrence of each marker, and every
    returned section has surrounding whitespace stripped.

    Args:
        text: The full output text to split.

    Returns:
        A ``(search_results, llm_response, related_questions)`` tuple.
        ``related_questions`` is ``""`` when the ``__RELATED_QUESTIONS__``
        marker is absent. All three items are ``None`` when the
        ``__LLM_RESPONSE__`` marker is absent.
    """
    # Plain delimiter splitting: str.partition is linear, whereas a regex
    # starting with a lazy ``.*?`` rescans from every offset when the marker
    # is missing.
    search_results, llm_marker, remainder = text.partition("__LLM_RESPONSE__")
    if not llm_marker:
        # Without the response marker nothing can be attributed to a section.
        return None, None, None

    # Everything after the (optional) related-questions marker belongs to the
    # related questions; with no marker that part is simply empty.
    llm_response, _, related_questions = remainder.partition("__RELATED_QUESTIONS__")
    return search_results.strip(), llm_response.strip(), related_questions.strip()

def search_with_search1api(query: str, search1api_key: str):
"""
Expand Down Expand Up @@ -417,6 +447,10 @@ async def server_init(_app, loop):
_app.ctx.should_do_related_questions = bool(
os.getenv("RELATED_QUESTIONS") in ("1", "yes", "true")
)
# 是否开始聊天历史的环境变量
_app.ctx.should_do_chat_history = bool(
os.getenv("CHAT_HISTORY") in ("1", "yes", "true")
)
# Create httpx Session
_app.ctx.http_session = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10, read=120, write=120, pool=10),
Expand Down Expand Up @@ -588,21 +622,72 @@ async def query_function(request: sanic.Request):
generate_related_questions = params.get("generate_related_questions", True)
if not query:
raise HTTPException("query must be provided.")

# 定义传递给生成答案的聊天历史 以及搜索结果
chat_history = []
contexts = ""

# Note that, if uuid exists, we don't check if the stored query is the same
# as the current query, and simply return the stored result. This is to enable
# the user to share a searched link to others and have others see the same result.
if search_uuid:
try:
result = await _app.loop.run_in_executor(
_app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid
)
return sanic.text(result)
except KeyError:
logger.info(f"Key {search_uuid} not found, will generate again.")
except Exception as e:
logger.error(
f"KV error: {e}\n{traceback.format_exc()}, will generate again."
)
if _app.ctx.should_do_chat_history:
# 开启了历史记录,读取历史记录
history = []
try:
history = await _app.loop.run_in_executor(
_app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), f"{search_uuid}_history"
)
result = await _app.loop.run_in_executor(
_app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid
)
# return sanic.text(result)
except KeyError:
logger.info(f"Key {search_uuid} not found, will generate again.")
except Exception as e:
logger.error(
f"KV error: {e}\n{traceback.format_exc()}, will generate again."
)
# 如果存在历史记录
if history:
# 获取最后一次记录
last_entry = history[-1]
# 确定最后一次记录的数据完整性
old_query, search_results, llm_response = last_entry.get("query", ""), last_entry.get("search_results", ""), last_entry.get("llm_response", "")
# 如果存在旧查询和搜索结果
if old_query and search_results:
if old_query != query:
# 从历史记录中获取搜索结果(最后一条)
contexts = history[-1]["search_results"]
# 将历史聊天的提问和回答提取
chat_history = []
for entry in history:
if "query" in entry and "llm_response" in entry:
chat_history.append({"role": "user", "content": entry["query"]})
chat_history.append({"role": "assistant", "content": entry["llm_response"]})
else:
return sanic.text(result["txt"]) # 查询未改变,直接返回结果
else:
try:
result = await _app.loop.run_in_executor(
_app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid
)
# debug
if isinstance(result, dict):
# 只有相同的查询才返回同一个结果, 兼容多轮对话。
if result["query"] == query:
return sanic.text(result["txt"])
else:
# TODO: 兼容旧数据代码 之后删除
# 旧数据强制刷新
# return sanic.text(result)
pass
except KeyError:
logger.info(f"Key {search_uuid} not found, will generate again.")
except Exception as e:
logger.error(
f"KV error: {e}\n{traceback.format_exc()}, will generate again."
)
else:
raise HTTPException("search_uuid must be provided.")

Expand All @@ -619,9 +704,11 @@ async def query_function(request: sanic.Request):
# query = query or _default_query
# Basic attack protection: remove "[INST]" or "[/INST]" from the query
query = re.sub(r"\[/?INST\]", "", query)
contexts = await _app.loop.run_in_executor(
_app.ctx.executor, _app.ctx.search_function, query
)
# 开启聊天历史并且有有效数据 则不再重新请求搜索
if not _app.ctx.should_do_chat_history or contexts in ("", None):
contexts = await _app.loop.run_in_executor(
_app.ctx.executor, _app.ctx.search_function, query
)

system_prompt = _rag_query_text.format(
context="\n\n".join(
Expand All @@ -630,6 +717,13 @@ async def query_function(request: sanic.Request):
)
try:
openai_client = new_async_client(_app)
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query},
]
if chat_history and len(chat_history) % 2 == 0:
# 将历史插入到消息中 index = 1 的位置
messages[1:1] = chat_history
llm_response = await openai_client.chat.completions.create(
model=_app.ctx.model,
messages=[
Expand Down Expand Up @@ -664,8 +758,24 @@ async def query_function(request: sanic.Request):
# Second, upload to KV. Note that if uploading to KV fails, we will silently
# ignore it, because we don't want to affect the user experience.
await response.eof()
if _app.ctx.should_do_chat_history:
# 保存聊天历史
_search_results, _llm_response, _related_questions = await _app.loop.run_in_executor(
_app.ctx.executor, extract_all_sections, "".join(all_yielded_results)
)
if _search_results:
_search_results = json.loads(_search_results)
if _related_questions:
_related_questions = json.loads(_related_questions)
_ = _app.ctx.executor.submit(
_app.ctx.kv.append, f"{search_uuid}_history", {
"query": query,
"search_results": _search_results,
"llm_response": _llm_response,
"related_questions": _related_questions
})
_ = _app.ctx.executor.submit(
_app.ctx.kv.put, search_uuid, "".join(all_yielded_results)
_app.ctx.kv.put, search_uuid, {"query": query, "txt": "".join(all_yielded_results)} # 原来的缓存是直接根据sid返回结果,开启聊天历史后 同一个sid存储多轮对话,因此需要存储 query 兼容多轮对话
)


Expand Down
22 changes: 18 additions & 4 deletions web/src/app/components/search.tsx
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
"use client";
import { getSearchUrl } from "@/app/utils/get-search-url";
import { ArrowRight } from "lucide-react";
import { ArrowRight, ArrowUp } from "lucide-react";
import { nanoid } from "nanoid";
import { useRouter } from "next/navigation";
import React, { FC, useState } from "react";
import { useSearchParams } from "next/navigation";

export const Search: FC = () => {
interface SearchProps {
useContinueButton?: boolean; // true: 使用“继续对话”按钮; false: 使用“新的搜索”按钮
}
export const Search: FC<SearchProps> = ({ useContinueButton = false }) => {
const [value, setValue] = useState("");
const router = useRouter();
const searchParams = useSearchParams();
const old_rid = decodeURIComponent(searchParams.get("rid") || "");
const handleNewSearch = () => {
// 可以在这里重置任何需要的状态,以准备一个新的搜索
if (value) {
setValue(""); // 清空搜索框
router.push(getSearchUrl(encodeURIComponent(value), nanoid()));
}
};
return (
<form
onSubmit={(e) => {
e.preventDefault();
if (value) {
setValue("");
router.push(getSearchUrl(encodeURIComponent(value), nanoid()));
const rid = useContinueButton ? old_rid : nanoid();
router.push(getSearchUrl(encodeURIComponent(value), rid));
}
}}
>
Expand All @@ -34,7 +48,7 @@ export const Search: FC = () => {
type="submit"
className="w-auto py-1 px-2 bg-black border-black text-white fill-white active:scale-95 border overflow-hidden relative rounded-xl"
>
<ArrowRight size={16} />
{useContinueButton ? <ArrowUp size={16} /> : <ArrowRight size={16} />}
</button>
</label>
</form>
Expand Down
2 changes: 1 addition & 1 deletion web/src/app/search/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export default function SearchPage() {
<div className="h-80 pointer-events-none w-full rounded-b-2xl backdrop-filter absolute bottom-0 bg-gradient-to-b from-transparent to-white [mask-image:linear-gradient(to_top,white,transparent)]"></div>
<div className="absolute z-10 flex items-center justify-center bottom-6 px-4 md:px-8 w-full">
<div className="w-full">
<Search></Search>
<Search useContinueButton={true}></Search>
</div>
</div>
</div>
Expand Down

0 comments on commit c98b899

Please sign in to comment.