diff --git a/dmoj/graders/base.py b/dmoj/graders/base.py index 4f60f7bfa..96e32fc59 100644 --- a/dmoj/graders/base.py +++ b/dmoj/graders/base.py @@ -21,14 +21,6 @@ def grade(self, case): def _generate_binary(self): raise NotImplementedError - def abort_grading(self): - self._abort_requested = True - if self._current_proc: - try: - self._current_proc.kill() - except OSError: - pass - def _resolve_testcases(self, cfg, batch_no=0): cases = [] for case_config in cfg: diff --git a/dmoj/judge.py b/dmoj/judge.py index cc615c31f..72d5e040a 100644 --- a/dmoj/judge.py +++ b/dmoj/judge.py @@ -2,8 +2,10 @@ import logging import multiprocessing import os +import shutil import signal import sys +import tempfile import threading import traceback from enum import Enum @@ -41,14 +43,14 @@ class IPC(Enum): BATCH_END = 'BATCH-END' GRADING_BEGIN = 'GRADING-BEGIN' GRADING_END = 'GRADING-END' - GRADING_ABORTED = 'GRADING-ABORTED' UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION' - REQUEST_ABORT = 'REQUEST-ABORT' + + +class JudgeWorkerAborted(Exception): + pass # This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here. -# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of -# a `CompileError`.) IPC_TIMEOUT = 60 # seconds @@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal ) ) - # FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send - # an abort request down the pipe, possibly messing up the handshake. self.current_judge_worker = JudgeWorker(submission) ipc_ready_signal = threading.Event() @@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non assert self.current_judge_worker is not None try: + worker_tempdir = None + + def _ipc_hello(_report, tempdir: str): + nonlocal worker_tempdir + ipc_ready_signal.set() + worker_tempdir = tempdir + ipc_handler_dispatch: Dict[IPC, Callable] = { - IPC.HELLO: lambda _report: ipc_ready_signal.set(), + IPC.HELLO: _ipc_hello, IPC.COMPILE_ERROR: self._ipc_compile_error, IPC.COMPILE_MESSAGE: self._ipc_compile_message, IPC.GRADING_BEGIN: self._ipc_grading_begin, IPC.GRADING_END: self._ipc_grading_end, - IPC.GRADING_ABORTED: self._ipc_grading_aborted, IPC.BATCH_BEGIN: self._ipc_batch_begin, IPC.BATCH_END: self._ipc_batch_end, IPC.RESULT: self._ipc_result, @@ -176,12 +182,22 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non % (self.current_submission.problem_id, self.current_submission.id) ) ) + except JudgeWorkerAborted: + self.packet_manager.submission_aborted_packet() except Exception: # noqa: E722, we want to catch everything self.log_internal_error() finally: self.current_judge_worker.wait_with_timeout() self.current_judge_worker = None + print('cleaning up', worker_tempdir) + os.system('ls -al %s' % worker_tempdir) + if worker_tempdir: + try: + shutil.rmtree(worker_tempdir) + except: # noqa: E722 + pass + # Might not have been set if an exception was encountered before HELLO message, so signal here to keep the # other side from waiting forever. ipc_ready_signal.set() @@ -232,10 +248,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None: def _ipc_batch_end(self, _report, _batch_number: int) -> None: self.packet_manager.batch_end_packet() - def _ipc_grading_aborted(self, report) -> None: - self.packet_manager.submission_aborted_packet() - report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)')) - def _ipc_unhandled_exception(self, _report, message: str) -> None: logger.error('Unhandled exception in worker process') self.log_internal_error(message=message) @@ -254,10 +266,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None: 'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id ) else: - logger.info('Received abortion request for %d', worker.submission.id) - # These calls are idempotent, so it doesn't matter if we raced and the worker has exited already. - worker.request_abort_grading() - worker.wait_with_timeout() + logger.info('Received abortion request for %d, killing worker', worker.submission.id) + # This call is idempotent, so it doesn't matter if we raced and the worker has exited already. + worker.abort_grading__kill_worker() def listen(self) -> None: """ @@ -270,7 +281,8 @@ def murder(self) -> None: """ End any submission currently executing, and exit the judge. """ - self.abort_grading() + if self.current_judge_worker: + self.current_judge_worker.abort_grading__kill_worker() self.updater_exit = True self.updater_signal.set() if self.packet_manager: @@ -304,8 +316,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio class JudgeWorker: def __init__(self, submission: Submission) -> None: self.submission = submission - self._abort_requested = False - self._sent_sigkill_to_worker_process = False + self._aborted = False + self._timed_out = False # FIXME(tbrindus): marked Any pending grader cleanups. self.grader: Any = None @@ -331,8 +343,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]: self.worker_process.kill() raise except EOFError: - if self._sent_sigkill_to_worker_process: - raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) + if self._aborted: + raise JudgeWorkerAborted() from None + + if self._timed_out: + raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None + raise except Exception: logger.error('Failed to read IPC message from worker!') @@ -354,16 +370,14 @@ def wait_with_timeout(self) -> None: finally: if self.worker_process.is_alive(): logger.error('Worker is still alive, sending SIGKILL!') - self._sent_sigkill_to_worker_process = True + self._timed_out = True self.worker_process.kill() - def request_abort_grading(self) -> None: - assert self.worker_process_conn - - try: - self.worker_process_conn.send((IPC.REQUEST_ABORT, ())) - except Exception: - logger.exception('Failed to send abort request to worker, did it race?') + def abort_grading__kill_worker(self) -> None: + if self.worker_process and self.worker_process.is_alive(): + self._aborted = True + self.worker_process.kill() + self.worker_process.join(timeout=1) def _worker_process_main( self, @@ -384,15 +398,12 @@ def _ipc_recv_thread_main() -> None: while True: try: ipc_type, data = judge_process_conn.recv() - except: # noqa: E722, whatever happened, we have to abort now. + except: # noqa: E722, whatever happened, we have to exit now. logger.exception('Judge unexpectedly hung up!') - self._do_abort() return if ipc_type == IPC.BYE: return - elif ipc_type == IPC.REQUEST_ABORT: - self._do_abort() else: raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),)) @@ -402,9 +413,12 @@ def _report_unhandled_exception() -> None: judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,))) judge_process_conn.send((IPC.BYE, ())) + tempdir = tempfile.mkdtemp('dmoj-judge-worker') + tempfile.tempdir = tempdir + ipc_recv_thread = None try: - judge_process_conn.send((IPC.HELLO, ())) + judge_process_conn.send((IPC.HELLO, (tempdir,))) ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True) ipc_recv_thread.start() @@ -439,15 +453,6 @@ def _report_unhandled_exception() -> None: if ipc_recv_thread.is_alive(): logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!') - # FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which - # won't get called if we exit the process right now (so we'd leak all files created by the grader). This - # should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting - # working out. - from dmoj.executors.compiled_executor import _CompiledExecutorMeta - - for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values(): - cached_executor.is_cached = False - cached_executor.cleanup() self.grader = None def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: @@ -503,11 +508,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: else: result = self.grader.grade(case) - # If the submission was killed due to a user-initiated abort, any result is meaningless. - if self._abort_requested: - yield IPC.GRADING_ABORTED, () - return - if result.result_flag & Result.WA: # If we failed a 0-point case, we will short-circuit every case after this. is_short_circuiting_enabled |= not case.points @@ -532,11 +532,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]: yield IPC.GRADING_END, () - def _do_abort(self) -> None: - self._abort_requested = True - if self.grader: - self.grader.abort_grading() - class ClassicJudge(Judge): def __init__(self, host, port, **kwargs) -> None: