Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stop/interrupt capability #174

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import subprocess
import traceback
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import StrEnum
from functools import partial
Expand All @@ -19,6 +20,7 @@
from anthropic.types.beta import (
BetaContentBlockParam,
BetaTextBlockParam,
BetaToolResultBlockParam,
)
from streamlit.delta_generator import DeltaGenerator

Expand All @@ -33,10 +35,14 @@
API_KEY_FILE = CONFIG_DIR / "api_key"
STREAMLIT_STYLE = """
<style>
/* Hide chat input while agent loop is running */
.stApp[data-teststate=running] .stChatInput textarea,
.stApp[data-test-script-state=running] .stChatInput textarea {
display: none;
/* Highlight the stop button in red */
button[kind=header] {
background-color: rgb(255, 75, 75);
border: 1px solid rgb(255, 75, 75);
color: rgb(255, 255, 255);
}
button[kind=header]:hover {
background-color: rgb(255, 51, 51);
}
/* Hide the streamlit deploy button */
.stAppDeployButton {
Expand All @@ -46,6 +52,8 @@
"""

WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
INTERRUPT_TEXT = "(user stopped or interrupted and wrote the following)"
INTERRUPT_TOOL_ERROR = "human stopped or interrupted tool execution"


class Sender(StrEnum):
Expand Down Expand Up @@ -82,6 +90,8 @@ def setup_state():
st.session_state.custom_system_prompt = load_from_storage("system_prompt") or ""
if "hide_images" not in st.session_state:
st.session_state.hide_images = False
if "in_sampling_loop" not in st.session_state:
st.session_state.in_sampling_loop = False


def _reset_model():
Expand Down Expand Up @@ -195,7 +205,10 @@ def _reset_api_provider():
st.session_state.messages.append(
{
"role": Sender.USER,
"content": [BetaTextBlockParam(type="text", text=new_message)],
"content": [
*maybe_add_interruption_blocks(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

niceee lol

BetaTextBlockParam(type="text", text=new_message),
],
}
)
_render_message(Sender.USER, new_message)
Expand All @@ -209,7 +222,7 @@ def _reset_api_provider():
# we don't have a user message to respond to, exit early
return

with st.spinner("Running Agent..."):
with track_sampling_loop():
# run the agent sampling loop with the newest message
st.session_state.messages = await sampling_loop(
system_prompt_suffix=st.session_state.custom_system_prompt,
Expand All @@ -230,6 +243,36 @@ def _reset_api_provider():
)


def maybe_add_interruption_blocks():
if not st.session_state.in_sampling_loop:
return []
# If this function is called while we're in the sampling loop, we can assume that the previous sampling loop was interrupted
# and we should annotate the conversation with additional context for the model and heal any incomplete tool use calls
result = []
last_message = st.session_state.messages[-1]
previous_tool_use_ids = [
block["id"] for block in last_message["content"] if block["type"] == "tool_use"
]
for tool_use_id in previous_tool_use_ids:
tool_result = BetaToolResultBlockParam(
tool_use_id=tool_use_id,
type="tool_result",
content=INTERRUPT_TOOL_ERROR,
is_error=True,
)
st.session_state.tools[tool_use_id] = tool_result
result.append(tool_result)
result.append(BetaTextBlockParam(type="text", text=INTERRUPT_TEXT))
return result


@contextmanager
def track_sampling_loop():
st.session_state.in_sampling_loop = True
yield
st.session_state.in_sampling_loop = False


def validate_auth(provider: APIProvider, api_key: str | None):
if provider == APIProvider.ANTHROPIC:
if not api_key:
Expand Down
Loading