Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Jun 8, 2022
1 parent b98f904 commit 5e80e46
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 154 deletions.
1 change: 1 addition & 0 deletions mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _wrap_original_deploy_band_resources(*args, **kwargs):


@pytest.mark.asyncio
@pytest.mark.skipif(vineyard is None, reason="vineyard not installed")
async def test_vineyard_operators(create_cluster):
param = create_cluster[1]
if param != "vineyard":
Expand Down
18 changes: 14 additions & 4 deletions mars/services/scheduling/api/oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ async def update_subtask_priority(self, args_list, kwargs_list):
)

async def cancel_subtasks(
self, subtask_ids: List[str], kill_timeout: Union[float, int] = None
self,
subtask_ids: List[str],
kill_timeout: Union[float, int] = None,
wait: bool = False,
):
"""
Cancel pending and running subtasks.
Expand All @@ -108,7 +111,14 @@ async def cancel_subtasks(
kill_timeout
timeout seconds to kill actor process forcibly
"""
await self._manager_ref.cancel_subtasks(subtask_ids, kill_timeout=kill_timeout)
if wait:
await self._manager_ref.cancel_subtasks(
subtask_ids, kill_timeout=kill_timeout
)
else:
await self._manager_ref.cancel_subtasks.tell(
subtask_ids, kill_timeout=kill_timeout
)

async def finish_subtasks(
self,
Expand All @@ -122,8 +132,8 @@ async def finish_subtasks(
Parameters
----------
subtask_ids
ids of subtasks to mark as finished
subtask_results
results of subtasks, which must be in finished states
bands
bands of subtasks to mark as finished
schedule_next
Expand Down
103 changes: 42 additions & 61 deletions mars/services/scheduling/supervisor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
from .... import oscar as mo
from ....lib.aio import alru_cache
from ....metrics import Metrics
from ....oscar.backends.context import ProfilingContext
from ....oscar.errors import MarsError
from ....oscar.profiling import ProfilingData, MARS_ENABLE_PROFILING
from ....typing import BandType
from ....utils import dataslots, Timer
from ....utils import dataslots
from ...subtask import Subtask, SubtaskResult, SubtaskStatus
from ...task import TaskAPI
from ..core import SubtaskScheduleSummary
Expand Down Expand Up @@ -127,14 +125,6 @@ async def __post_create__(self):
)
await self._speculation_execution_scheduler.start()

async def dump_running():
while True:
if self._subtask_infos:
logger.warning("RUNNING: %r", list(self._subtask_infos))
await asyncio.sleep(5)

asyncio.create_task(dump_running())

async def __pre_destroy__(self):
await self._speculation_execution_scheduler.stop()

Expand Down Expand Up @@ -186,7 +176,7 @@ async def _handle_subtask_result(
self, info: SubtaskScheduleInfo, result: SubtaskResult, band: BandType
):
subtask_id = info.subtask.subtask_id
async with redirect_subtask_errors(self, [info.subtask]):
async with redirect_subtask_errors(self, [info.subtask], reraise=False):
try:
info.band_futures[band].set_result(result)
if result.error is not None:
Expand Down Expand Up @@ -262,9 +252,9 @@ async def finish_subtasks(

if subtask_info is not None:
if subtask_band is not None:
logger.warning("BEFORE await self._handle_subtask_result(subtask_info, result, subtask_band)")
await self._handle_subtask_result(subtask_info, result, subtask_band)
logger.warning("AFTER await self._handle_subtask_result(subtask_info, result, subtask_band)")
await self._handle_subtask_result(
subtask_info, result, subtask_band
)

self._finished_subtask_count.record(
1,
Expand All @@ -275,16 +265,15 @@ async def finish_subtasks(
},
)
self._subtask_summaries[subtask_id] = subtask_info.to_summary(
is_finished=True, is_cancelled=result.status == SubtaskStatus.cancelled
is_finished=True,
is_cancelled=result.status == SubtaskStatus.cancelled,
)
subtask_info.end_time = time.time()
self._speculation_execution_scheduler.finish_subtask(subtask_info)
# Cancel subtask on other bands.
aio_task = subtask_info.band_futures.pop(subtask_band, None)
if aio_task:
logger.warning("BEFORE await aio_task")
await aio_task
logger.warning("AFTER await aio_task")
if schedule_next:
band_tasks[subtask_band] += 1
if subtask_info.band_futures:
Expand All @@ -304,7 +293,6 @@ async def finish_subtasks(
if schedule_next:
for band in subtask_info.band_futures.keys():
band_tasks[band] += 1
# await self._queueing_ref.remove_queued_subtasks(subtask_ids)
if band_tasks:
await self._queueing_ref.submit_subtasks.tell(dict(band_tasks))

Expand Down Expand Up @@ -345,7 +333,9 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
band_to_subtask_ids[band].append(subtask_id)

if res_release_delays:
await self._global_resource_ref.release_subtask_resource.batch(*res_release_delays)
await self._global_resource_ref.release_subtask_resource.batch(
*res_release_delays
)

for band, subtask_ids in band_to_subtask_ids.items():
asyncio.create_task(self._submit_subtasks_to_band(band, subtask_ids))
Expand Down Expand Up @@ -386,29 +376,22 @@ async def cancel_subtasks(
subtask_ids,
kill_timeout,
)
queued_subtask_ids = []
single_cancel_tasks = []

task_api = await self._get_task_api()

async def cancel_single_task(subtask, raw_tasks, cancel_tasks):
if cancel_tasks:
await asyncio.wait(cancel_tasks)
if raw_tasks:
dones, _ = await asyncio.wait(raw_tasks)
else:
dones = []
if not dones or all(fut.cancelled() for fut in dones):
await task_api.set_subtask_result(
SubtaskResult(
subtask_id=subtask.subtask_id,
session_id=subtask.session_id,
task_id=subtask.task_id,
stage_id=subtask.stage_id,
status=SubtaskStatus.cancelled,
)
)
async def cancel_task_in_band(band):
cancel_delays = band_to_cancel_delays.get(band) or []
execution_ref = await self._get_execution_ref(band)
if cancel_delays:
await execution_ref.cancel_subtask.batch(*cancel_delays)
band_futures = band_to_futures.get(band)
if band_futures:
await asyncio.wait(band_futures)

queued_subtask_ids = []
cancel_tasks = []
band_to_cancel_delays = defaultdict(list)
band_to_futures = defaultdict(list)
for subtask_id in subtask_ids:
if subtask_id not in self._subtask_infos:
# subtask may have already finished or may not have been submitted at all
Expand All @@ -423,35 +406,33 @@ async def cancel_single_task(subtask, raw_tasks, cancel_tasks):
raw_tasks_to_cancel = list(info.band_futures.values())

if not raw_tasks_to_cancel:
queued_subtask_ids.append(subtask_id)
single_cancel_tasks.append(
asyncio.create_task(
cancel_single_task(info.subtask, [], [])
)
# not submitted yet: mark subtasks as cancelled
result = SubtaskResult(
subtask_id=info.subtask.subtask_id,
session_id=info.subtask.session_id,
task_id=info.subtask.task_id,
stage_id=info.subtask.stage_id,
status=SubtaskStatus.cancelled,
)
cancel_tasks.append(task_api.set_subtask_result(result))
queued_subtask_ids.append(subtask_id)
else:
cancel_tasks = []
for band in info.band_futures.keys():
for band, future in info.band_futures.items():
execution_ref = await self._get_execution_ref(band)
cancel_tasks.append(
asyncio.create_task(
execution_ref.cancel_subtask(
subtask_id, kill_timeout=kill_timeout
)
)
band_to_cancel_delays[band].append(
execution_ref.cancel_subtask.delay(subtask_id, kill_timeout)
)
single_cancel_tasks.append(
asyncio.create_task(
cancel_single_task(
info.subtask, raw_tasks_to_cancel, cancel_tasks
)
)
)
band_to_futures[band].append(future)

for band in band_to_futures:
cancel_tasks.append(asyncio.create_task(cancel_task_in_band(band)))

if queued_subtask_ids:
# Don't use `finish_subtasks` because it may remove queued
await self._queueing_ref.remove_queued_subtasks(queued_subtask_ids)
if single_cancel_tasks:
yield asyncio.wait(single_cancel_tasks)

if cancel_tasks:
yield asyncio.gather(*cancel_tasks)

for subtask_id in subtask_ids:
info = self._subtask_infos.pop(subtask_id, None)
Expand Down
14 changes: 1 addition & 13 deletions mars/services/scheduling/supervisor/queueing.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ async def _submission_task_func(self):

async def _submit_subtask_request(self, band_to_limit: Dict[BandType, int] = None):
if band_to_limit:
logger.debug(
"TMP_QUEUE_PROBE: Submitting subtasks with limits: %r", band_to_limit
)
logger.debug("Submitting subtasks with limits: %r", band_to_limit)

if not self._band_to_resource or any(
not limit and band not in self._band_to_resource
Expand Down Expand Up @@ -274,8 +272,6 @@ def _load_items_to_submit():

await asyncio.to_thread(_load_items_to_submit)

logger.debug("TMP_QUEUE_PROBE: Finished picking top subtasks")

async with redirect_subtask_errors(
self,
(
Expand All @@ -288,11 +284,6 @@ def _load_items_to_submit():
*apply_delays
)

logger.debug(
"TMP_QUEUE_PROBE: Finished band resource allocation, %d subtasks submitted",
sum(len(ids) for ids in submitted_ids_list),
)

manager_ref = await self._get_manager_ref()
submit_delays = []

Expand Down Expand Up @@ -336,10 +327,7 @@ def _gather_submissions():
heapq.heappush(task_queue, submit_items[stid])

await asyncio.to_thread(_gather_submissions)

logger.debug("TMP_QUEUE_PROBE: Start subtask submission in batch")
await manager_ref.submit_subtask_to_band.batch(*submit_delays)
logger.debug("TMP_QUEUE_PROBE: Finished subtask submission")

def _ensure_top_item_valid(self, task_queue):
"""Clean invalid subtask item from the queue to ensure that when the queue is not empty,
Expand Down
42 changes: 18 additions & 24 deletions mars/services/scheduling/supervisor/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import time
from collections import defaultdict
from typing import List, Dict, Tuple, Set

Expand Down Expand Up @@ -91,40 +92,33 @@ async def run_subtask(

async def task_fun():
task_api = await TaskAPI.create(subtask.session_id, supervisor_address)
result = SubtaskResult(
subtask_id=subtask.subtask_id,
session_id=subtask.session_id,
task_id=subtask.task_id,
stage_id=subtask.stage_id,
bands=[(self.address, band_name)],
progress=1.0,
execution_start_time=time.time(),
)
try:
await asyncio.sleep(20)
except asyncio.CancelledError as ex:
await task_api.set_subtask_result(
SubtaskResult(
subtask_id=subtask.subtask_id,
session_id=subtask.session_id,
task_id=subtask.task_id,
stage_id=subtask.stage_id,
bands=[(self.address, band_name)],
status=SubtaskStatus.cancelled,
progress=1.0,
error=ex,
traceback=ex.__traceback__,
)
)
result.status = SubtaskStatus.cancelled
result.error = ex
result.traceback = ex.__traceback__
await task_api.set_subtask_result(result)
raise
else:
await task_api.set_subtask_result(
SubtaskResult(
subtask_id=subtask.subtask_id,
session_id=subtask.session_id,
task_id=subtask.task_id,
stage_id=subtask.stage_id,
status=SubtaskStatus.succeeded,
bands=[(self.address, band_name)],
progress=1.0,
)
)
result.status = SubtaskStatus.succeeded
result.execution_end_time = time.time()
await task_api.set_subtask_result(result)

self._subtask_aiotasks[subtask.subtask_id][band_name] = asyncio.create_task(
task_fun()
)

@mo.extensible
def cancel_subtask(self, subtask_id: str, kill_timeout: int = 5):
for task in self._subtask_aiotasks[subtask_id].values():
task.cancel()
Expand Down
4 changes: 1 addition & 3 deletions mars/services/scheduling/tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ async def set_subtask_result(self, subtask_result: SubtaskResult):
for event in self._events[subtask_result.subtask_id]:
event.set()
self._events.pop(subtask_result.subtask_id, None)
await scheduling_api.finish_subtasks(
[subtask_result], subtask_result.bands
)
await scheduling_api.finish_subtasks([subtask_result], subtask_result.bands)

def _return_result(self, subtask_id: str):
result = self._results[subtask_id]
Expand Down
7 changes: 5 additions & 2 deletions mars/services/scheduling/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ async def _get_task_api(actor: mo.Actor):


@contextlib.asynccontextmanager
async def redirect_subtask_errors(actor: mo.Actor, subtasks: Iterable[Subtask]):
async def redirect_subtask_errors(
actor: mo.Actor, subtasks: Iterable[Subtask], reraise: bool = True
):
try:
yield
except: # noqa: E722 # pylint: disable=bare-except
Expand Down Expand Up @@ -60,4 +62,5 @@ async def redirect_subtask_errors(actor: mo.Actor, subtasks: Iterable[Subtask]):
)
tasks = [asyncio.ensure_future(coro) for coro in coros]
await asyncio.wait(tasks)
raise
if reraise:
raise
Loading

0 comments on commit 5e80e46

Please sign in to comment.