Skip to content

Commit

Permalink
move non-standard standard test to openai-specific tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Jul 23, 2024
1 parent a5691e2 commit 54ec656
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Test ChatOpenAI chat model."""

import base64
from typing import Any, AsyncIterator, List, Optional, cast

import httpx
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
Expand Down Expand Up @@ -684,3 +686,67 @@ def test_openai_response_headers_invoke() -> None:
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers


def test_image_token_counting_jpeg() -> None:
model = ChatOpenAI(model="gpt-4o", temperature=0)
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual


def test_image_token_counting_png() -> None:
model = ChatOpenAI(model="gpt-4o", temperature=0)
image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
message = HumanMessage(
content=[
{"type": "text", "text": "how many dice are in this image"},
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
message = HumanMessage(
content=[
{"type": "text", "text": "how many dice are in this image"},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_data}"},
},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Standard LangChain interface tests"""

import base64
from typing import Type, cast
from typing import Type

import httpx
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_openai import ChatOpenAI
Expand All @@ -23,64 +20,3 @@ def chat_model_params(self) -> dict:
@property
def supports_image_inputs(self) -> bool:
return True

# TODO: Add to standard tests if reliable token counting is added to other models.
def test_image_token_counting_jpeg(self, model: BaseChatModel) -> None:
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

def test_image_token_counting_png(self, model: BaseChatModel) -> None:
image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
message = HumanMessage(
content=[
{"type": "text", "text": "how many dice are in this image"},
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
message = HumanMessage(
content=[
{"type": "text", "text": "how many dice are in this image"},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_data}"},
},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual

0 comments on commit 54ec656

Please sign in to comment.