Skip to content

Commit

Permalink
continue running the thread when switching chats
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed May 30, 2024
1 parent 728b74f commit 2b18336
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 43 deletions.
28 changes: 17 additions & 11 deletions frontend/app/components/ChatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ export function ChatWindow() {
loadMoreThreads,
areThreadsLoading,
} = useThreadList(userId);
const { streamState, startStream, stopStream } = useStreamState();
const { streamStates, startStream, stopStream } = useStreamState();
const streamState =
currentThread == null
? null
: streamStates[currentThread.thread_id] ?? null;
const { refreshMessages, messages, setMessages, next, areMessagesLoading } =
useThreadMessages(
currentThread?.thread_id ?? null,
Expand Down Expand Up @@ -203,18 +207,14 @@ export function ChatWindow() {

const selectThread = useCallback(
async (id: string | null) => {
if (currentThread) {
stopStream?.(true);
}

if (!id) {
const thread = await createThread("New chat");
insertUrlParam("threadId", thread["thread_id"]);
} else {
insertUrlParam("threadId", id);
}
},
[currentThread, stopStream, setMessages, createThread, insertUrlParam],
[setMessages, createThread, insertUrlParam],
);

const deleteThreadAndReset = async (id: string) => {
Expand Down Expand Up @@ -338,11 +338,11 @@ export function ChatWindow() {
className="flex flex-col-reverse w-full mb-2 overflow-auto max-h-[75vh]"
ref={messageContainerRef}
>
{messages.length > 0 ? (
{messages.length > 0 && currentThread != null ? (
<>
{next.length > 0 &&
streamState?.status !== "inflight" &&
currentThread != null && (
streamStates[currentThread.thread_id]?.status !==
"inflight" && (
<Button
key={"continue-button"}
backgroundColor={"rgb(58, 58, 61)"}
Expand All @@ -359,7 +359,9 @@ export function ChatWindow() {
<ChatMessageBubble
key={m.id}
message={{ ...m }}
feedbackUrls={streamState?.feedbackUrls}
feedbackUrls={
streamStates[currentThread.thread_id]?.feedbackUrls
}
aiEmoji="🦜"
isMostRecent={index === 0}
messageCompleted={!isLoading}
Expand Down Expand Up @@ -403,8 +405,12 @@ export function ChatWindow() {
type="submit"
onClick={(e) => {
e.preventDefault();
if (currentThread == null) {
return;
}

if (isLoading) {
stopStream?.();
stopStream?.(currentThread.thread_id);
} else {
sendMessage();
}
Expand Down
89 changes: 59 additions & 30 deletions frontend/app/hooks/useStreamState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ export interface StreamState {
}

export interface StreamStateProps {
streamState: StreamState | null;
streamStates: { [threadId: string]: StreamState | null };
startStream: (
input: Message[] | null,
threadId: string,
assistantId: string,
config?: Config,
) => Promise<void>;
stopStream?: (clear?: boolean) => void;
stopStream?: (threadId: string, clear?: boolean) => void;
}

export function mergeMessagesById(
Expand All @@ -44,7 +44,9 @@ export function mergeMessagesById(
}

export function useStreamState(): StreamStateProps {
const [current, setCurrent] = useState<StreamState | null>(null);
const [streamStates, setStreamStates] = useState<
StreamStateProps["streamStates"]
>({});
const [controller, setController] = useState<AbortController | null>(null);
const client = useLangGraphClient();

Expand All @@ -57,7 +59,10 @@ export function useStreamState(): StreamStateProps {
) => {
const controller = new AbortController();
setController(controller);
setCurrent({ status: "inflight", messages: messages || [] });
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: { status: "inflight", messages: messages || [] },
}));

const stream = client.runs.stream(threadId, assistantId, {
input: messages == null ? null : { messages },
Expand All @@ -70,33 +75,51 @@ export function useStreamState(): StreamStateProps {
for await (const chunk of stream) {
if (chunk.event === "messages/partial") {
const chunkMessages = chunk.data as Message[];
setCurrent((current) => ({
...current,
status: "inflight",
messages: mergeMessagesById(current?.messages, chunkMessages),
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
status: "inflight",
messages: mergeMessagesById(
streamStates[threadId]?.messages,
chunkMessages,
),
},
}));
} else if (chunk.event === "values") {
const data = chunk.data as Record<string, any>;
setCurrent((current) => ({
...current,
status: "inflight",
documents: data["documents"],
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
status: "inflight",
documents: data["documents"],
},
}));
} else if (chunk.event === "error") {
setCurrent((current) => ({
...current,
status: "error",
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
status: "error",
},
}));
} else if (chunk.event === "feedback") {
setCurrent((current) => ({
...current,
feedbackUrls: chunk.data,
status: "inflight",
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
feedbackUrls: chunk.data,
status: "inflight",
},
}));
} else if (chunk.event === "end") {
setCurrent((current) => ({
...current,
status: "done",
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
status: "done",
},
}));
}
}
Expand All @@ -105,17 +128,23 @@ export function useStreamState(): StreamStateProps {
);

const stopStream = useCallback(
(clear: boolean = false) => {
(threadId: string, clear: boolean = false) => {
controller?.abort();
setController(null);
if (clear) {
setCurrent({
status: "done",
});
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
status: "done",
},
}));
} else {
setCurrent((current) => ({
...current,
status: "done",
setStreamStates((streamStates) => ({
...streamStates,
[threadId]: {
...streamStates[threadId],
status: "done",
},
}));
}
},
Expand All @@ -125,6 +154,6 @@ export function useStreamState(): StreamStateProps {
return {
startStream,
stopStream,
streamState: current,
streamStates,
};
}
4 changes: 2 additions & 2 deletions frontend/app/hooks/useThreadMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function getSources(documents: Document[]) {
export function useThreadMessages(
threadId: string | null,
streamState: StreamState | null,
stopStream?: (clear?: boolean) => void,
stopStream?: (threadId: string, clear?: boolean) => void,
) {
const client = useLangGraphClient();
const [messages, setMessages] = useState<Message[]>([]);
Expand Down Expand Up @@ -82,7 +82,7 @@ export function useThreadMessages(
);
setMessages(messages);
setNext(next);
stopStream?.(true);
stopStream?.(threadId, true);
}
}

Expand Down

0 comments on commit 2b18336

Please sign in to comment.