Skip to content

Commit

Permalink
Update tests to account for shared memory
Browse files Browse the repository at this point in the history
  • Loading branch information
jb3 authored and MarkKoz committed Aug 28, 2023
1 parent abbf2e9 commit 8028c08
Showing 1 changed file with 64 additions and 17 deletions.
81 changes: 64 additions & 17 deletions tests/test_nsjail.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def setUp(self):
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)

# Hard-coded because it's non-trivial to parse the mount options.
self.shm_mount_size = 40 * Size.MiB

def eval_code(self, code: str):
return self.nsjail.python3(["-c", code])

Expand Down Expand Up @@ -125,6 +128,25 @@ def f():
self.assertIn("-9", exit_codes)
self.assertEqual(result.stderr, None)

def test_multiprocessing_pool(self):
# Validates that shm is working as expected
code = dedent(
"""
from multiprocessing import Pool
def f(x):
return x*x
with Pool(2) as p:
print(p.map(f, [1, 2, 3]))
"""
)

result = self.eval_file(code)

self.assertEqual(result.stdout, "[1, 4, 9]\n")
self.assertEqual(result.returncode, 0)

def test_read_only_file_system(self):
for path in ("/", "/etc", "/lib", "/lib64", "/snekbox", "/usr"):
with self.subTest(path=path):
Expand Down Expand Up @@ -335,35 +357,60 @@ def test_log_parser(self):
log.output,
)

def test_shm_and_tmp_not_mounted(self):
for path in ("/dev/shm", "/run/shm", "/tmp"):
with self.subTest(path=path):
def test_tmp_not_mounted(self):
code = dedent(
"""
with open('/tmp/test', 'wb') as file:
file.write(bytes([255]))
"""
).strip()

result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
self.assertEqual(result.stderr, None)

def test_multiprocessing_shared_memory(self):
cases = (
(self.shm_mount_size, self.shm_mount_size, 0),
# Even if the shared memory object is larger than the mount,
# writing data within the size of the mount should succeed.
(self.shm_mount_size + 1, self.shm_mount_size, 0),
(self.shm_mount_size + 1, self.shm_mount_size + 1, 135),
)

for shm_size, buffer_size, return_code in cases:
with self.subTest(shm_size=shm_size, buffer_size=buffer_size):
# Need enough memory for buffer and bytearray plus some overhead.
mem_max = (buffer_size * 2) + (400 * Size.MiB)
code = dedent(
f"""
with open('{path}/test', 'wb') as file:
file.write(bytes([255]))
"""
from multiprocessing.shared_memory import SharedMemory
shm = SharedMemory(create=True, size={shm_size})
shm.buf[:{buffer_size}] = bytearray([1] * {buffer_size})
"""
).strip()

result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
result = self.eval_file(code, nsjail_args=("--cgroup_mem_max", str(mem_max)))

self.assertEqual(result.returncode, return_code)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)

def test_multiprocessing_shared_memory_disabled(self):
def test_multiprocessing_shared_memory_mmap_limited(self):
"""The mmap call should be OOM trying to map a large & sparse shared memory object."""
code = dedent(
"""
f"""
from multiprocessing.shared_memory import SharedMemory
try:
SharedMemory('test', create=True, size=16)
except FileExistsError:
pass
"""
SharedMemory(create=True, size={self.nsjail.config.cgroup_mem_max + Size.GiB})
"""
).strip()

result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Function not implemented", result.stdout)
self.assertIn("[Errno 12] Cannot allocate memory", result.stdout)
self.assertEqual(result.stderr, None)

def test_numpy_import(self):
Expand Down

0 comments on commit 8028c08

Please sign in to comment.