diff --git a/api/catalog/api/utils/validate_images.py b/api/catalog/api/utils/validate_images.py index 11de88ce4..21560a673 100644 --- a/api/catalog/api/utils/validate_images.py +++ b/api/catalog/api/utils/validate_images.py @@ -33,9 +33,9 @@ def _get_expiry(status, default): async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]: try: - async with session.head(url, timeout=2, allow_redirects=False) as response: + async with session.head(url, allow_redirects=False) as response: return url, response.status - except aiohttp.ClientError as exception: + except (aiohttp.ClientError, asyncio.TimeoutError) as exception: _log_validation_failure(exception) return url, -1 @@ -44,7 +44,8 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]: @async_to_sync async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]: tasks = [] - async with aiohttp.ClientSession(headers=HEADERS) as session: + timeout = aiohttp.ClientTimeout(total=2) + async with aiohttp.ClientSession(headers=HEADERS, timeout=timeout) as session: tasks = [asyncio.ensure_future(_head(url, session)) for url in urls] responses = asyncio.gather(*tasks) await responses diff --git a/api/test/unit/utils/validate_images_test.py b/api/test/unit/utils/validate_images_test.py index 9987709b1..e38e42284 100644 --- a/api/test/unit/utils/validate_images_test.py +++ b/api/test/unit/utils/validate_images_test.py @@ -1,3 +1,4 @@ +import asyncio from unittest import mock import aiohttp @@ -43,4 +44,27 @@ def test_sends_user_agent(wrapped_client_session: mock.AsyncMock): for url in image_urls: assert url in requested_urls - wrapped_client_session.assert_called_once_with(headers=HEADERS) + wrapped_client_session.assert_called_once_with(headers=HEADERS, timeout=mock.ANY) + + +def test_handles_timeout(): + """ + Note: This test takes just over 3 seconds to run as it simulates network delay of 3 seconds. + """ + query_hash = "test_handles_timeout" + results = [{"identifier": i} for i in range(1)] + image_urls = [f"https://example.org/{i}" for i in range(len(results))] + start_slice = 0 + + def raise_timeout_error(*args, **kwargs): + raise asyncio.TimeoutError() + + with mock.patch( + "aiohttp.client.ClientSession._request", side_effect=raise_timeout_error + ): + validate_images(query_hash, start_slice, results, image_urls) + + # `validate_images` directly modifies the results list + # if the results are timing out then they're considered dead and discarded + # so should not appear in the final list of results. + assert len(results) == 0