Skip to content

Commit

Permalink
feat: improve async sharding (#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Jun 11, 2024
1 parent c67f275 commit fd1f7da
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 77 deletions.
63 changes: 34 additions & 29 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,43 +739,48 @@ async def read_rows_sharded(
"""
if not sharded_query:
raise ValueError("empty sharded_query")
# reduce operation_timeout between batches
operation_timeout, attempt_timeout = _get_timeouts(
operation_timeout, attempt_timeout, self
)
timeout_generator = _attempt_timeout_generator(
# make sure each rpc stays within overall operation timeout
rpc_timeout_generator = _attempt_timeout_generator(
operation_timeout, operation_timeout
)
# submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT
batched_queries = [
sharded_query[i : i + _CONCURRENCY_LIMIT]
for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
]
# run batches and collect results
results_list = []
error_dict = {}
shard_idx = 0
for batch in batched_queries:
batch_operation_timeout = next(timeout_generator)
routine_list = [
self.read_rows(

# limit the number of concurrent requests using a semaphore
concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT)

async def read_rows_with_semaphore(query):
async with concurrency_sem:
# calculate new timeout based on time left in overall operation
shard_timeout = next(rpc_timeout_generator)
if shard_timeout <= 0:
raise DeadlineExceeded(
"Operation timeout exceeded before starting query"
)
return await self.read_rows(
query,
operation_timeout=batch_operation_timeout,
attempt_timeout=min(attempt_timeout, batch_operation_timeout),
operation_timeout=shard_timeout,
attempt_timeout=min(attempt_timeout, shard_timeout),
retryable_errors=retryable_errors,
)
for query in batch
]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1

routine_list = [read_rows_with_semaphore(query) for query in sharded_query]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)

# collect results and errors
error_dict = {}
shard_idx = 0
results_list = []
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1
if error_dict:
# if any sub-request failed, raise an exception instead of returning results
raise ShardedReadRowsExceptionGroup(
Expand Down
155 changes: 107 additions & 48 deletions tests/unit/data/_async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,62 +1927,121 @@ async def mock_call(*args, **kwargs):
assert call_time < 0.2

@pytest.mark.asyncio
async def test_read_rows_sharded_batching(self):
async def test_read_rows_sharded_concurrency_limit(self):
"""
Large queries should be processed in batches to limit concurrency
operation timeout should change between batches
Only 10 queries should be processed concurrently. Others should be queued
Should start a new query as soon as previous finishes
"""
from google.cloud.bigtable.data._async.client import TableAsync
from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT

assert _CONCURRENCY_LIMIT == 10 # change this test if this changes
num_queries = 15

n_queries = 90
expected_num_batches = n_queries // _CONCURRENCY_LIMIT
query_list = [ReadRowsQuery() for _ in range(n_queries)]

table_mock = AsyncMock()
start_operation_timeout = 10
start_attempt_timeout = 3
table_mock.default_read_rows_operation_timeout = start_operation_timeout
table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
# clock ticks one second on each check
with mock.patch("time.monotonic", side_effect=range(0, 100000)):
with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
await TableAsync.read_rows_sharded(table_mock, query_list)
# should have individual calls for each query
assert table_mock.read_rows.call_count == n_queries
# should have single gather call for each batch
assert gather_mock.call_count == expected_num_batches
# ensure that timeouts decrease over time
kwargs = [
table_mock.read_rows.call_args_list[idx][1]
for idx in range(n_queries)
]
for batch_idx in range(expected_num_batches):
batch_kwargs = kwargs[
batch_idx
* _CONCURRENCY_LIMIT : (batch_idx + 1)
* _CONCURRENCY_LIMIT
# each of the first 10 queries take longer than the last
# later rpcs will have to wait on first 10
increment_time = 0.05
max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]

async def mock_call(*args, **kwargs):
next_sleep = rpc_times.pop(0)
await asyncio.sleep(next_sleep)
return [mock.Mock()]

starting_timeout = 10

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(num_queries)]
await table.read_rows_sharded(
queries, operation_timeout=starting_timeout
)
assert read_rows.call_count == num_queries
# check operation timeouts to see how far into the operation each rpc started
rpc_start_list = [
starting_timeout - kwargs["operation_timeout"]
for _, kwargs in read_rows.call_args_list
]
for req_kwargs in batch_kwargs:
# each batch should have the same operation_timeout, and it should decrease in each batch
expected_operation_timeout = start_operation_timeout - (
batch_idx + 1
)
assert (
req_kwargs["operation_timeout"]
== expected_operation_timeout
)
# each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
expected_attempt_timeout = min(
start_attempt_timeout, expected_operation_timeout
eps = 0.01
# first 10 should start immediately
assert all(
rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)
)
# next rpcs should start as first ones finish
for i in range(num_queries - _CONCURRENCY_LIMIT):
idx = i + _CONCURRENCY_LIMIT
assert rpc_start_list[idx] - (i * increment_time) < eps

@pytest.mark.asyncio
async def test_read_rows_sharded_expirary(self):
"""
If the operation times out before all shards complete, should raise
a ShardedReadRowsExceptionGroup
"""
from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
from google.api_core.exceptions import DeadlineExceeded

operation_timeout = 0.1

# let the first batch complete, but the next batch times out
num_queries = 15
sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
num_queries - _CONCURRENCY_LIMIT
)

async def mock_call(*args, **kwargs):
next_item = sleeps.pop(0)
if isinstance(next_item, Exception):
raise next_item
else:
await asyncio.sleep(next_item)
return [mock.Mock()]

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(num_queries)]
with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
await table.read_rows_sharded(
queries, operation_timeout=operation_timeout
)
assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
# await all created coroutines to avoid warnings
for i in range(len(gather_mock.call_args_list)):
for j in range(len(gather_mock.call_args_list[i][0])):
await gather_mock.call_args_list[i][0][j]
assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
# should keep successful queries
assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT

@pytest.mark.asyncio
async def test_read_rows_sharded_negative_batch_timeout(self):
"""
try to run with batch that starts after operation timeout
They should raise DeadlineExceeded errors
"""
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
from google.api_core.exceptions import DeadlineExceeded

async def mock_call(*args, **kwargs):
await asyncio.sleep(0.05)
return [mock.Mock()]

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(15)]
with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
await table.read_rows_sharded(queries, operation_timeout=0.01)
assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
assert len(exc.value.exceptions) == 5
assert all(
isinstance(e.__cause__, DeadlineExceeded)
for e in exc.value.exceptions
)


class TestSampleRowKeys:
Expand Down

0 comments on commit fd1f7da

Please sign in to comment.