From a33b5bf9ea3f6b5a7e785e1cac9596e7fc0273ff Mon Sep 17 00:00:00 2001 From: Tudor Brindus Date: Mon, 25 Dec 2023 17:54:26 -0500 Subject: [PATCH] judge: implement instant aborts The way this works is: - Worker creates a tempdir, and sets `tempfile.tempdir` to this directory. - Worker sends back the tempdir. The parent process is responsible for cleaning it up when the worker exits. Abortions are then implemented as sending `SIGKILL` to the worker. As a side benefit of this implementation, we also get to drop the hacky `CompiledExecutor` cache deletion. --- dmoj/graders/base.py | 5 +++ dmoj/judge.py | 96 ++++++++++++++++++++------------------------ 2 files changed, 48 insertions(+), 53 deletions(-) diff --git a/dmoj/graders/base.py b/dmoj/graders/base.py index 9966c0907..4151ec076 100644 --- a/dmoj/graders/base.py +++ b/dmoj/graders/base.py @@ -37,6 +37,7 @@ def grade(self, case: TestCase) -> Result: def _generate_binary(self) -> BaseExecutor: raise NotImplementedError +<<<<<<< HEAD def abort_grading(self) -> None: self._abort_requested = True if self._current_proc: @@ -47,6 +48,10 @@ def abort_grading(self) -> None: def _resolve_testcases(self, cfg, batch_no=0) -> List[BaseTestCase]: cases: List[BaseTestCase] = [] +======= + def _resolve_testcases(self, cfg, batch_no=0): + cases = [] +>>>>>>> 662a3127 (judge: implement instant aborts) for case_config in cfg: if 'batched' in case_config.raw_config: self._batch_counter += 1 diff --git a/dmoj/judge.py b/dmoj/judge.py index 5d252c564..8f2f6d12f 100644 --- a/dmoj/judge.py +++ b/dmoj/judge.py @@ -2,8 +2,10 @@ import logging import multiprocessing import os +import shutil import signal import sys +import tempfile import threading import traceback from enum import Enum @@ -41,14 +43,14 @@ class IPC(Enum): BATCH_END = 'BATCH-END' GRADING_BEGIN = 'GRADING-BEGIN' GRADING_END = 'GRADING-END' - GRADING_ABORTED = 'GRADING-ABORTED' UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION' - REQUEST_ABORT = 'REQUEST-ABORT' + + +class JudgeWorkerAborted(Exception): + pass # This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here. -# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of -# a `CompileError`.) IPC_TIMEOUT = 60 # seconds @@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal ) ) - # FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send - # an abort request down the pipe, possibly messing up the handshake. self.current_judge_worker = JudgeWorker(submission) ipc_ready_signal = threading.Event() @@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non assert self.current_judge_worker is not None try: + worker_tempdir = None + + def _ipc_hello(_report, tempdir: str): + nonlocal worker_tempdir + ipc_ready_signal.set() + worker_tempdir = tempdir + ipc_handler_dispatch: Dict[IPC, Callable] = { - IPC.HELLO: lambda _report: ipc_ready_signal.set(), + IPC.HELLO: _ipc_hello, IPC.COMPILE_ERROR: self._ipc_compile_error, IPC.COMPILE_MESSAGE: self._ipc_compile_message, IPC.GRADING_BEGIN: self._ipc_grading_begin, IPC.GRADING_END: self._ipc_grading_end, - IPC.GRADING_ABORTED: self._ipc_grading_aborted, IPC.BATCH_BEGIN: self._ipc_batch_begin, IPC.BATCH_END: self._ipc_batch_end, IPC.RESULT: self._ipc_result, @@ -176,12 +182,17 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non % (self.current_submission.problem_id, self.current_submission.id) ) ) + except JudgeWorkerAborted: + self.packet_manager.submission_aborted_packet() except Exception: # noqa: E722, we want to catch everything self.log_internal_error() finally: self.current_judge_worker.wait_with_timeout() self.current_judge_worker = None + if worker_tempdir: + shutil.rmtree(worker_tempdir) + # Might not have been set if an exception was encountered before HELLO message, so signal here to keep the # other side from waiting forever. ipc_ready_signal.set() @@ -232,10 +243,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None: def _ipc_batch_end(self, _report, _batch_number: int) -> None: self.packet_manager.batch_end_packet() - def _ipc_grading_aborted(self, report) -> None: - self.packet_manager.submission_aborted_packet() - report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)')) - def _ipc_unhandled_exception(self, _report, message: str) -> None: logger.error('Unhandled exception in worker process') self.log_internal_error(message=message) @@ -254,10 +261,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None: 'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id ) else: - logger.info('Received abortion request for %d', worker.submission.id) - # These calls are idempotent, so it doesn't matter if we raced and the worker has exited already. - worker.request_abort_grading() - worker.wait_with_timeout() + logger.info('Received abortion request for %d, killing worker', worker.submission.id) + # This call is idempotent, so it doesn't matter if we raced and the worker has exited already. + worker.abort_grading__kill_worker() def listen(self) -> None: """ @@ -270,7 +276,8 @@ def murder(self) -> None: """ End any submission currently executing, and exit the judge. """ - self.abort_grading() + if self.current_judge_worker: + self.current_judge_worker.abort_grading__kill_worker() self.updater_exit = True self.updater_signal.set() if self.packet_manager: @@ -304,8 +311,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio class JudgeWorker: def __init__(self, submission: Submission) -> None: self.submission = submission - self._abort_requested = False - self._sent_sigkill_to_worker_process = False + self._aborted = False + self._timed_out = False # FIXME(tbrindus): marked Any pending grader cleanups. self.grader: Any = None @@ -331,8 +338,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]: self.worker_process.kill() raise except EOFError: - if self._sent_sigkill_to_worker_process: - raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) + if self._aborted: + raise JudgeWorkerAborted() from None + + if self._timed_out: + raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None + raise except Exception: logger.error('Failed to read IPC message from worker!') @@ -354,16 +365,14 @@ def wait_with_timeout(self) -> None: finally: if self.worker_process.is_alive(): logger.error('Worker is still alive, sending SIGKILL!') - self._sent_sigkill_to_worker_process = True + self._timed_out = True self.worker_process.kill() - def request_abort_grading(self) -> None: - assert self.worker_process_conn - - try: - self.worker_process_conn.send((IPC.REQUEST_ABORT, ())) - except Exception: - logger.exception('Failed to send abort request to worker, did it race?') + def abort_grading__kill_worker(self) -> None: + if self.worker_process and self.worker_process.is_alive(): + self._aborted = True + self.worker_process.kill() + self.worker_process.join(timeout=1) def _worker_process_main( self, @@ -384,15 +393,12 @@ def _ipc_recv_thread_main() -> None: while True: try: ipc_type, data = judge_process_conn.recv() - except: # noqa: E722, whatever happened, we have to abort now. + except: # noqa: E722, whatever happened, we have to exit now. logger.exception('Judge unexpectedly hung up!') - self._do_abort() return if ipc_type == IPC.BYE: return - elif ipc_type == IPC.REQUEST_ABORT: - self._do_abort() else: raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),)) @@ -402,9 +408,12 @@ def _report_unhandled_exception() -> None: judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,))) judge_process_conn.send((IPC.BYE, ())) + tempdir = tempfile.mkdtemp('dmoj-judge-worker') + tempfile.tempdir = tempdir + ipc_recv_thread = None try: - judge_process_conn.send((IPC.HELLO, ())) + judge_process_conn.send((IPC.HELLO, (tempdir,))) ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True) ipc_recv_thread.start() @@ -439,15 +448,6 @@ def _report_unhandled_exception() -> None: if ipc_recv_thread.is_alive(): logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!') - # FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which - # won't get called if we exit the process right now (so we'd leak all files created by the grader). This - # should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting - # working out. - from dmoj.executors.compiled_executor import _CompiledExecutorMeta - - for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values(): - cached_executor.is_cached = False - cached_executor.cleanup() self.grader = None def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: @@ -505,11 +505,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: assert isinstance(case, TestCase) result = self.grader.grade(case) - # If the submission was killed due to a user-initiated abort, any result is meaningless. - if self._abort_requested: - yield IPC.GRADING_ABORTED, () - return - if result.result_flag & Result.WA: # If we failed a 0-point case, we will short-circuit every case after this. is_short_circuiting_enabled |= not case.points @@ -534,11 +529,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: yield IPC.GRADING_END, () - def _do_abort(self) -> None: - self._abort_requested = True - if self.grader: - self.grader.abort_grading() - class ClassicJudge(Judge): def __init__(self, host, port, **kwargs) -> None: