Skip to content

Commit

Permalink
judge: implement instant aborts
Browse files (browse the repository at this point in the history)
The way this works is:

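- The worker creates a tempdir and sets `tempfile.tempdir` to this directory, so
  that temporary files it creates through `tempfile` land inside it by default.
- The worker sends the tempdir path back to the parent process in its `HELLO`
  message. The parent process is responsible for cleaning it up when the worker
  exits.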

Aborts are then implemented by sending `SIGKILL` to the worker; because the
parent owns the tempdir cleanup, the worker needs no chance to clean up after
itself.

As a side benefit of this implementation, we also get to drop the hacky
`CompiledExecutor` cache deletion.
  • Loading branch information
Xyene committed Dec 26, 2023
1 parent 5eb5f59 commit 041136c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 61 deletions.
8 changes: 0 additions & 8 deletions dmoj/graders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ def grade(self, case):
def _generate_binary(self):
raise NotImplementedError

def abort_grading(self):
self._abort_requested = True
if self._current_proc:
try:
self._current_proc.kill()
except OSError:
pass

def _resolve_testcases(self, cfg, batch_no=0):
cases = []
for case_config in cfg:
Expand Down
101 changes: 48 additions & 53 deletions dmoj/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import logging
import multiprocessing
import os
import shutil
import signal
import sys
import tempfile
import threading
import traceback
from enum import Enum
Expand Down Expand Up @@ -41,14 +43,14 @@ class IPC(Enum):
BATCH_END = 'BATCH-END'
GRADING_BEGIN = 'GRADING-BEGIN'
GRADING_END = 'GRADING-END'
GRADING_ABORTED = 'GRADING-ABORTED'
UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION'
REQUEST_ABORT = 'REQUEST-ABORT'


class JudgeWorkerAborted(Exception):
pass


# This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here.
# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of
# a `CompileError`.)
IPC_TIMEOUT = 60 # seconds


Expand Down Expand Up @@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal
)
)

# FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send
# an abort request down the pipe, possibly messing up the handshake.
self.current_judge_worker = JudgeWorker(submission)

ipc_ready_signal = threading.Event()
Expand All @@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
assert self.current_judge_worker is not None

try:
worker_tempdir = None

def _ipc_hello(_report, tempdir: str):
nonlocal worker_tempdir
ipc_ready_signal.set()
worker_tempdir = tempdir

ipc_handler_dispatch: Dict[IPC, Callable] = {
IPC.HELLO: lambda _report: ipc_ready_signal.set(),
IPC.HELLO: _ipc_hello,
IPC.COMPILE_ERROR: self._ipc_compile_error,
IPC.COMPILE_MESSAGE: self._ipc_compile_message,
IPC.GRADING_BEGIN: self._ipc_grading_begin,
IPC.GRADING_END: self._ipc_grading_end,
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
IPC.BATCH_BEGIN: self._ipc_batch_begin,
IPC.BATCH_END: self._ipc_batch_end,
IPC.RESULT: self._ipc_result,
Expand All @@ -176,12 +182,22 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
% (self.current_submission.problem_id, self.current_submission.id)
)
)
except JudgeWorkerAborted:
self.packet_manager.submission_aborted_packet()
except Exception: # noqa: E722, we want to catch everything
self.log_internal_error()
finally:
self.current_judge_worker.wait_with_timeout()
self.current_judge_worker = None

print('cleaning up', worker_tempdir)
os.system('ls -al %s' % worker_tempdir)
if worker_tempdir:
try:
shutil.rmtree(worker_tempdir)
except: # noqa: E722
pass

# Might not have been set if an exception was encountered before HELLO message, so signal here to keep the
# other side from waiting forever.
ipc_ready_signal.set()
Expand Down Expand Up @@ -232,10 +248,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None:
def _ipc_batch_end(self, _report, _batch_number: int) -> None:
self.packet_manager.batch_end_packet()

def _ipc_grading_aborted(self, report) -> None:
self.packet_manager.submission_aborted_packet()
report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)'))

def _ipc_unhandled_exception(self, _report, message: str) -> None:
logger.error('Unhandled exception in worker process')
self.log_internal_error(message=message)
Expand All @@ -254,10 +266,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None:
'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id
)
else:
logger.info('Received abortion request for %d', worker.submission.id)
# These calls are idempotent, so it doesn't matter if we raced and the worker has exited already.
worker.request_abort_grading()
worker.wait_with_timeout()
logger.info('Received abortion request for %d, killing worker', worker.submission.id)
# This call is idempotent, so it doesn't matter if we raced and the worker has exited already.
worker.abort_grading__kill_worker()

def listen(self) -> None:
"""
Expand All @@ -270,7 +281,8 @@ def murder(self) -> None:
"""
End any submission currently executing, and exit the judge.
"""
self.abort_grading()
if self.current_judge_worker:
self.current_judge_worker.abort_grading__kill_worker()
self.updater_exit = True
self.updater_signal.set()
if self.packet_manager:
Expand Down Expand Up @@ -304,8 +316,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio
class JudgeWorker:
def __init__(self, submission: Submission) -> None:
self.submission = submission
self._abort_requested = False
self._sent_sigkill_to_worker_process = False
self._aborted = False
self._timed_out = False
# FIXME(tbrindus): marked Any pending grader cleanups.
self.grader: Any = None

Expand All @@ -331,8 +343,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]:
self.worker_process.kill()
raise
except EOFError:
if self._sent_sigkill_to_worker_process:
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT)
if self._aborted:
raise JudgeWorkerAborted() from None

if self._timed_out:
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None

raise
except Exception:
logger.error('Failed to read IPC message from worker!')
Expand All @@ -354,16 +370,14 @@ def wait_with_timeout(self) -> None:
finally:
if self.worker_process.is_alive():
logger.error('Worker is still alive, sending SIGKILL!')
self._sent_sigkill_to_worker_process = True
self._timed_out = True
self.worker_process.kill()

def request_abort_grading(self) -> None:
assert self.worker_process_conn

try:
self.worker_process_conn.send((IPC.REQUEST_ABORT, ()))
except Exception:
logger.exception('Failed to send abort request to worker, did it race?')
def abort_grading__kill_worker(self) -> None:
if self.worker_process and self.worker_process.is_alive():
self._aborted = True
self.worker_process.kill()
self.worker_process.join(timeout=1)

def _worker_process_main(
self,
Expand All @@ -384,15 +398,12 @@ def _ipc_recv_thread_main() -> None:
while True:
try:
ipc_type, data = judge_process_conn.recv()
except: # noqa: E722, whatever happened, we have to abort now.
except: # noqa: E722, whatever happened, we have to exit now.
logger.exception('Judge unexpectedly hung up!')
self._do_abort()
return

if ipc_type == IPC.BYE:
return
elif ipc_type == IPC.REQUEST_ABORT:
self._do_abort()
else:
raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),))

Expand All @@ -402,9 +413,12 @@ def _report_unhandled_exception() -> None:
judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,)))
judge_process_conn.send((IPC.BYE, ()))

tempdir = tempfile.mkdtemp('dmoj-judge-worker')
tempfile.tempdir = tempdir

ipc_recv_thread = None
try:
judge_process_conn.send((IPC.HELLO, ()))
judge_process_conn.send((IPC.HELLO, (tempdir,)))

ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True)
ipc_recv_thread.start()
Expand Down Expand Up @@ -439,15 +453,6 @@ def _report_unhandled_exception() -> None:
if ipc_recv_thread.is_alive():
logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!')

# FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which
# won't get called if we exit the process right now (so we'd leak all files created by the grader). This
# should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting
# working out.
from dmoj.executors.compiled_executor import _CompiledExecutorMeta

for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values():
cached_executor.is_cached = False
cached_executor.cleanup()
self.grader = None

def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
Expand Down Expand Up @@ -503,11 +508,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
else:
result = self.grader.grade(case)

# If the submission was killed due to a user-initiated abort, any result is meaningless.
if self._abort_requested:
yield IPC.GRADING_ABORTED, ()
return

if result.result_flag & Result.WA:
# If we failed a 0-point case, we will short-circuit every case after this.
is_short_circuiting_enabled |= not case.points
Expand All @@ -532,11 +532,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:

yield IPC.GRADING_END, ()

def _do_abort(self) -> None:
self._abort_requested = True
if self.grader:
self.grader.abort_grading()


class ClassicJudge(Judge):
def __init__(self, host, port, **kwargs) -> None:
Expand Down

0 comments on commit 041136c

Please sign in to comment.