Skip to content

Commit

Permalink
Implement graceful memfd fallback for FreeBSD
Browse files Browse the repository at this point in the history
  • Loading branch information
quantum5 committed Dec 30, 2024
1 parent 8f4ba33 commit 52948c7
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 43 deletions.
4 changes: 2 additions & 2 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ AT_FDCWD: int
bsd_get_proc_cwd: Callable[[int], str]
bsd_get_proc_fdno: Callable[[int, int], str]

memory_fd_create: Callable[[], int]
memory_fd_seal: Callable[[int], None]
memfd_create: Callable[[], int]
memfd_seal: Callable[[int], None]

class BufferProxy:
def _get_real_buffer(self): ...
12 changes: 6 additions & 6 deletions dmoj/cptbox/_cptbox.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ cdef extern from 'helper.h' nogil:
PTBOX_SPAWN_FAIL_EXECVE
PTBOX_SPAWN_FAIL_SETAFFINITY

int _memory_fd_create "memory_fd_create"()
int _memory_fd_seal "memory_fd_seal"(int fd)
int cptbox_memfd_create()
int cptbox_memfd_seal(int fd)


cdef extern from 'fcntl.h' nogil:
Expand Down Expand Up @@ -215,14 +215,14 @@ def bsd_get_proc_fdno(pid_t pid, int fd):
free(buf)
return res

def memory_fd_create():
cdef int fd = _memory_fd_create()
def memfd_create():
cdef int fd = cptbox_memfd_create()
if fd < 0:
PyErr_SetFromErrno(OSError)
return fd

def memory_fd_seal(int fd):
cdef int result = _memory_fd_seal(fd)
def memfd_seal(int fd):
cdef int result = cptbox_memfd_seal(fd)
if result == -1:
PyErr_SetFromErrno(OSError)

Expand Down
11 changes: 4 additions & 7 deletions dmoj/cptbox/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,16 @@ char *bsd_get_proc_fdno(pid_t pid, int fdno) {
return bsd_get_proc_fd(pid, 0, fdno);
}

int memory_fd_create(void) {
int cptbox_memfd_create(void) {
#ifdef __FreeBSD__
char filename[] = "/tmp/cptbox-memoryfd-XXXXXXXX";
int fd = mkstemp(filename);
if (fd >= 0)
unlink(filename);
return fd;
errno = ENOSYS;
return -1;
#else
return memfd_create("cptbox memory_fd", MFD_ALLOW_SEALING);
#endif
}

int memory_fd_seal(int fd) {
int cptbox_memfd_seal(int fd) {
#ifdef __FreeBSD__
errno = ENOSYS;
return -1;
Expand Down
4 changes: 2 additions & 2 deletions dmoj/cptbox/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ int cptbox_child_run(const struct child_config *config);
char *bsd_get_proc_cwd(pid_t pid);
char *bsd_get_proc_fdno(pid_t pid, int fdno);

int memory_fd_create(void);
int memory_fd_seal(int fd);
int cptbox_memfd_create(void);
int cptbox_memfd_seal(int fd);

#endif
113 changes: 92 additions & 21 deletions dmoj/cptbox/utils.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,42 @@
import errno
import io
import mmap
import os
from abc import ABCMeta, abstractmethod
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Optional

from dmoj.cptbox._cptbox import memory_fd_create, memory_fd_seal
from dmoj.cptbox._cptbox import memfd_create, memfd_seal


class MemoryIO(io.FileIO):
def __init__(self, prefill: Optional[bytes] = None, seal=False) -> None:
super().__init__(memory_fd_create(), 'r+')
def _make_fd_readonly(fd):
new_fd = os.open(f'/proc/self/fd/{fd}', os.O_RDONLY)
try:
os.dup2(new_fd, fd)
finally:
os.close(new_fd)


class MmapableIO(io.FileIO, metaclass=ABCMeta):
def __init__(self, fd, *, prefill: Optional[bytes] = None, seal=False) -> None:
super().__init__(fd, 'r+')

if prefill:
self.write(prefill)
if seal:
self.seal()

def seal(self) -> None:
fd = self.fileno()
try:
memory_fd_seal(fd)
except OSError as e:
if e.errno == errno.ENOSYS:
# FreeBSD
self.seek(0, os.SEEK_SET)
return
raise
@classmethod
@abstractmethod
def usable_with_name(cls) -> bool:
...

new_fd = os.open(f'/proc/self/fd/{fd}', os.O_RDONLY)
try:
os.dup2(new_fd, fd)
finally:
os.close(new_fd)
@abstractmethod
def seal(self) -> None:
...

@abstractmethod
def to_path(self) -> str:
return f'/proc/{os.getpid()}/fd/{self.fileno()}'
...

def to_bytes(self) -> bytes:
try:
Expand All @@ -43,3 +46,71 @@ def to_bytes(self) -> bytes:
if e.args[0] == 'cannot mmap an empty file':
return b''
raise


class NamedFileIO(MmapableIO):
_name: str

def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
with NamedTemporaryFile(delete=False) as f:
self._name = f.name
super().__init__(os.dup(f.fileno()), prefill=prefill, seal=seal)

def seal(self) -> None:
self.seek(0, os.SEEK_SET)

def close(self) -> None:
super().close()
os.unlink(self._name)

def to_path(self) -> str:
return self._name

@classmethod
def usable_with_name(cls):
return True


class UnnamedFileIO(MmapableIO):
def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
with TemporaryFile() as f:
super().__init__(os.dup(f.fileno()), prefill=prefill, seal=seal)

def seal(self) -> None:
self.seek(0, os.SEEK_SET)
_make_fd_readonly(self.fileno())

def to_path(self) -> str:
return f'/proc/{os.getpid()}/fd/{self.fileno()}'

@classmethod
def usable_with_name(cls):
with cls() as f:
return os.path.exists(f.to_path())


class MemfdIO(MmapableIO):
def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
super().__init__(memfd_create(), prefill=prefill, seal=seal)

def seal(self) -> None:
fd = self.fileno()
memfd_seal(fd)
_make_fd_readonly(fd)

def to_path(self) -> str:
return f'/proc/{os.getpid()}/fd/{self.fileno()}'

@classmethod
def usable_with_name(cls):
try:
with cls() as f:
return os.path.exists(f.to_path())
except OSError:
return False


# Try to use memfd if possible, otherwise fallback to unlinked temporary files
# (UnnamedFileIO). On FreeBSD and some other systems, /proc/[pid]/fd doesn't
# exist, so to_path() will not work. We fall back to NamedFileIO in that case.
MemoryIO = next((i for i in (MemfdIO, UnnamedFileIO, NamedFileIO) if i.usable_with_name()))
10 changes: 5 additions & 5 deletions dmoj/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dmoj import checkers
from dmoj.checkers import Checker
from dmoj.config import ConfigNode, InvalidInitException
from dmoj.cptbox.utils import MemoryIO
from dmoj.cptbox.utils import MemoryIO, MmapableIO
from dmoj.judgeenv import env, get_problem_root
from dmoj.utils.helper_files import compile_with_auxiliary_files, parse_helper_file_error
from dmoj.utils.module import load_module_from_file
Expand Down Expand Up @@ -275,7 +275,7 @@ def open(self, key: str):
return self.archive.open(zipinfo)
raise KeyError('file "%s" could not be found in "%s"' % (key, self.problem_root_dir))

def as_fd(self, key: str, normalize: bool = False) -> MemoryIO:
def as_fd(self, key: str, normalize: bool = False) -> MmapableIO:
memory = MemoryIO()
with self.open(key) as f:
if normalize:
Expand Down Expand Up @@ -344,7 +344,7 @@ class TestCase(BaseTestCase):
batch: int
output_prefix_length: int
has_binary_data: bool
_input_data_fd: Optional[MemoryIO]
_input_data_fd: Optional[MmapableIO]
_generated: Optional[Tuple[bytes, bytes]]

def __init__(self, count: int, batch_no: int, config: ConfigNode, problem: Problem):
Expand Down Expand Up @@ -451,14 +451,14 @@ def _run_generator(self, gen: Union[str, ConfigNode], args: Optional[Iterable[st
def input_data(self) -> bytes:
return self.input_data_fd().to_bytes()

def input_data_fd(self) -> MemoryIO:
def input_data_fd(self) -> MmapableIO:
if self._input_data_fd:
return self._input_data_fd

result = self._input_data_fd = self._make_input_data_fd()
return result

def _make_input_data_fd(self) -> MemoryIO:
def _make_input_data_fd(self) -> MmapableIO:
gen = self.config.generator

# don't try running the generator if we specify an output file explicitly,
Expand Down

0 comments on commit 52948c7

Please sign in to comment.