diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index bb959e4b..d13120c3 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -8,11 +8,18 @@ from enum import StrEnum from typing import Any, cast -from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse +from anthropic import ( + Anthropic, + AnthropicBedrock, + AnthropicVertex, + APIResponse, + BaseModel, +) from anthropic.types import ( ToolResultBlockParam, ) from anthropic.types.beta import ( + BetaCacheControlEphemeralParam, BetaContentBlock, BetaContentBlockParam, BetaImageBlockParam, @@ -24,8 +31,6 @@ from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult -BETA_FLAG = "computer-use-2024-10-22" - class APIProvider(StrEnum): ANTHROPIC = "anthropic" @@ -33,6 +38,8 @@ class APIProvider(StrEnum): VERTEX = "vertex" +MAX_PROMPT_CACHING_BREAKPOINTS = 4 + PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022", APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0", @@ -74,6 +81,7 @@ async def sampling_loop( api_key: str, only_n_most_recent_images: int | None = None, max_tokens: int = 4096, + prompt_caching: bool = True, ): """ Agentic sampling loop for the assistant/tool interaction of computer use. @@ -98,6 +106,11 @@ async def sampling_loop( elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() + betas = ["computer-use-2024-10-22"] + if prompt_caching: + betas.append("prompt-caching-2024-07-31") + _add_prompt_caching_headers(messages) + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -108,7 +121,7 @@ async def sampling_loop( model=model, system=system, tools=tool_collection.to_params(), - betas=["computer-use-2024-10-22"], + betas=betas, ) api_response_callback(cast(APIResponse[BetaMessage], raw_response)) @@ -230,3 +243,33 @@ def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str): if result.system: result_text = f"{result.system}\n{result_text}" return result_text + + +def _add_prompt_caching_headers( + messages: list[BetaMessageParam], +): + prompt_caching_breakpoints = 0 + for message in messages: + if isinstance(message["content"], str): + continue + + params: list[BetaContentBlockParam] = [] + for content_block in message["content"]: + if isinstance(content_block, BaseModel): + content_block_param = cast( + BetaContentBlockParam, content_block.to_dict() + ) + else: + content_block_param = content_block + params.append(content_block_param) + + if ( + isinstance(content_block_param, dict) + and content_block_param.get("type") == "image" + and prompt_caching_breakpoints < MAX_PROMPT_CACHING_BREAKPOINTS + ): + content_block_param["cache_control"] = BetaCacheControlEphemeralParam( + type="ephemeral" + ) + prompt_caching_breakpoints += 1 + message["content"] = params diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 6750029c..97bb4fcb 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -194,7 +194,7 @@ def _reset_api_provider(): st.session_state.messages.append( { "role": Sender.USER, - "content": [TextBlock(type="text", text=new_message)], + "content": [BetaTextBlock(type="text", text=new_message)], } ) _render_message(Sender.USER, new_message) diff --git a/computer-use-demo/tests/loop_test.py b/computer-use-demo/tests/loop_test.py index 4985dbee..acce00ae 100644 --- a/computer-use-demo/tests/loop_test.py +++ b/computer-use-demo/tests/loop_test.py @@ -1,7 +1,11 @@ from unittest import mock -from anthropic.types import TextBlock, ToolUseBlock -from anthropic.types.beta import BetaMessage, BetaMessageParam +from anthropic.types.beta import ( + BetaMessage, + BetaMessageParam, + BetaTextBlock, + BetaToolUseBlock, +) from computer_use_demo.loop import APIProvider, sampling_loop @@ -13,13 +17,13 @@ async def test_loop(): mock.Mock( spec=BetaMessage, content=[ - TextBlock(type="text", text="Hello"), - ToolUseBlock( + BetaTextBlock(type="text", text="Hello"), + BetaToolUseBlock( type="tool_use", id="1", name="computer", input={"action": "test"} ), ], ), - mock.Mock(spec=BetaMessage, content=[TextBlock(type="text", text="Done!")]), + mock.Mock(spec=BetaMessage, content=[BetaTextBlock(type="text", text="Done!")]), ] tool_collection = mock.AsyncMock() @@ -49,7 +53,8 @@ async def test_loop(): ) assert len(result) == 4 - assert result[0] == {"role": "user", "content": "Test message"} + assert result[0]["role"] == "user" + assert result[0]["content"] == "Test message" assert result[1]["role"] == "assistant" assert result[2]["role"] == "user" assert result[3]["role"] == "assistant" @@ -58,7 +63,7 @@ async def test_loop(): tool_collection.run.assert_called_once_with( name="computer", tool_input={"action": "test"} ) - output_callback.assert_called_with(TextBlock(text="Done!", type="text")) + output_callback.assert_called_with(BetaTextBlock(text="Done!", type="text")) assert output_callback.call_count == 3 assert tool_output_callback.call_count == 1 assert api_response_callback.call_count == 2 diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 25cd586b..8235a9dd 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,9 +1,10 @@ from unittest import mock import pytest +from anthropic.types.beta import BetaTextBlock from streamlit.testing.v1 import AppTest -from computer_use_demo.streamlit import Sender, TextBlock +from computer_use_demo.streamlit import Sender @pytest.fixture @@ -18,6 +19,6 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]} + {"role": Sender.USER, "content": [BetaTextBlock(text="Hello", type="text")]} ] assert not streamlit_app.exception