Skip to content

Commit

Permalink
WIP: fix timeouts user-after-free
Browse files Browse the repository at this point in the history
  • Loading branch information
yorickpeterse committed Jan 6, 2025
1 parent ddf2eeb commit e5e0462
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 116 deletions.
57 changes: 13 additions & 44 deletions rt/src/process.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::arc_without_weak::ArcWithoutWeak;
use crate::mem::{allocate, header_of, Header, TypePointer};
use crate::scheduler::process::Thread;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::TimeoutProcess;
use crate::stack::Stack;
use crate::state::State;
use std::alloc::dealloc;
Expand Down Expand Up @@ -223,7 +222,7 @@ pub(crate) struct ProcessState {
///
/// If missing and the process is suspended, it means the process is
/// suspended indefinitely.
timeout: Option<ArcWithoutWeak<Timeout>>,
timeout: Option<TimeoutProcess>,
}

impl ProcessState {
Expand All @@ -235,17 +234,7 @@ impl ProcessState {
}
}

pub(crate) fn has_same_timeout(
&self,
timeout: &ArcWithoutWeak<Timeout>,
) -> bool {
self.timeout
.as_ref()
.map(|t| t.as_ptr() == timeout.as_ptr())
.unwrap_or(false)
}

pub(crate) fn suspend(&mut self, timeout: ArcWithoutWeak<Timeout>) {
pub(crate) fn suspend(&mut self, timeout: TimeoutProcess) {
self.timeout = Some(timeout);
self.status.set_sleeping(true);
}
Expand Down Expand Up @@ -274,16 +263,13 @@ impl ProcessState {

pub(crate) fn waiting_for_value(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
timeout: Option<TimeoutProcess>,
) {
self.timeout = timeout;
self.status.set_waiting_for_value(true);
}

pub(crate) fn waiting_for_io(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
) {
pub(crate) fn waiting_for_io(&mut self, timeout: Option<TimeoutProcess>) {
self.timeout = timeout;
self.status.set_waiting_for_io(true);
}
Expand Down Expand Up @@ -719,10 +705,9 @@ impl DerefMut for ProcessPointer {
#[cfg(test)]
mod tests {
use super::*;
use crate::test::{empty_process_type, setup, OwnedProcess};
use crate::test::{empty_process_type, OwnedProcess};
use rustix::param::page_size;
use std::mem::size_of;
use std::time::Duration;

macro_rules! offset_of {
($value: expr, $field: ident) => {{
Expand Down Expand Up @@ -860,22 +845,8 @@ mod tests {
assert!(RescheduleRights::AcquiredWithTimeout.are_acquired());
}

#[test]
fn test_process_state_has_same_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

assert!(!proc_state.has_same_timeout(&timeout));

proc_state.timeout = Some(timeout.clone());

assert!(proc_state.has_same_timeout(&timeout));
}

#[test]
fn test_process_state_try_reschedule_after_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -893,7 +864,8 @@ mod tests {
assert!(!proc_state.status.is_waiting_for_value());
assert!(!proc_state.status.is_waiting());

let timeout = Timeout::duration(&state, Duration::from_secs(0));
let timeout =
TimeoutProcess::new(unsafe { ProcessPointer::new(0x4 as _) });

proc_state.waiting_for_value(Some(timeout));

Expand All @@ -908,9 +880,9 @@ mod tests {

#[test]
fn test_process_state_waiting_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));
let timeout =
TimeoutProcess::new(unsafe { ProcessPointer::new(0x4 as _) });

proc_state.waiting_for_value(None);

Expand Down Expand Up @@ -943,7 +915,6 @@ mod tests {

#[test]
fn test_process_state_try_reschedule_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -960,7 +931,7 @@ mod tests {

proc_state.status.set_waiting_for_value(true);
proc_state.timeout =
Some(Timeout::duration(&state, Duration::from_secs(0)));
Some(TimeoutProcess::new(unsafe { ProcessPointer::new(0x4 as _) }));

assert_eq!(
proc_state.try_reschedule_for_value(),
Expand Down Expand Up @@ -1004,11 +975,10 @@ mod tests {

#[test]
fn test_process_state_suspend() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));
let timeout = TimeoutProcess::new(*process);

process.state().suspend(timeout);

Expand All @@ -1018,11 +988,10 @@ mod tests {

#[test]
fn test_process_timeout_expired() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));
let timeout = TimeoutProcess::new(*process);

assert!(!process.timeout_expired());

Expand Down
12 changes: 4 additions & 8 deletions rt/src/runtime/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ pub unsafe extern "system" fn inko_process_suspend(
let timeout = Timeout::duration(state, Duration::from_nanos(nanos as _));

{
let mut proc_state = process.state();

proc_state.suspend(timeout.clone());
state.timeout_worker.suspend(process, timeout);
process.state().suspend(state.timeout_worker.suspend(process, timeout));
}

// Safety: the current thread is holding on to the run lock
Expand Down Expand Up @@ -245,20 +242,19 @@ pub unsafe extern "system" fn inko_process_wait_for_value_until(
let state = &*state;
let deadline = Timeout::until(nanos);
let mut proc_state = process.state();

proc_state.waiting_for_value(Some(deadline.clone()));

let _ = (*lock).compare_exchange(
current,
new,
Ordering::AcqRel,
Ordering::Acquire,
);

proc_state.waiting_for_value(Some(
state.timeout_worker.suspend(process, deadline),
));
drop(proc_state);

// Safety: the current thread is holding on to the run lock
state.timeout_worker.suspend(process, deadline);
context::switch(process);
process.timeout_expired()
}
Expand Down
5 changes: 3 additions & 2 deletions rt/src/runtime/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ pub(crate) unsafe extern "system" fn inko_socket_poll(
if deadline >= 0 {
let time = Timeout::until(deadline as u64);

proc_state.waiting_for_io(Some(time.clone()));
state.timeout_worker.suspend(process, time);
proc_state.waiting_for_io(Some(
state.timeout_worker.suspend(process, time),
));
} else {
proc_state.waiting_for_io(None);
}
Expand Down
45 changes: 27 additions & 18 deletions rt/src/scheduler/timeout_worker.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Rescheduling of processes with expired timeouts.
use crate::arc_without_weak::ArcWithoutWeak;
use crate::process::ProcessPointer;
use crate::scheduler::process::Scheduler;
use crate::scheduler::timeouts::{Timeout, Timeouts};
use crate::scheduler::timeouts::{Timeout, TimeoutProcess, Timeouts};
use crate::state::State;
use std::cell::UnsafeCell;
use std::collections::VecDeque;
Expand All @@ -23,8 +22,9 @@ const QUEUE_START_CAPACITY: usize = 1024 / size_of::<Message>();
const FRAGMENTATION_THRESHOLD: f64 = 0.1;

struct Message {
process: ProcessPointer,
timeout: ArcWithoutWeak<Timeout>,
id: usize,
process: TimeoutProcess,
timeout: Timeout,
}

/// The inner part of a worker, only accessible by the owning thread.
Expand Down Expand Up @@ -100,12 +100,16 @@ impl TimeoutWorker {
pub(crate) fn suspend(
&self,
process: ProcessPointer,
timeout: ArcWithoutWeak<Timeout>,
) {
timeout: Timeout,
) -> TimeoutProcess {
let id = process.identifier();
let ours = TimeoutProcess::new(process);
let theirs = ours.clone();
let mut queue = self.queue.lock().unwrap();

queue.push_back(Message { process, timeout });
queue.push_back(Message { id, process: ours, timeout });
self.cvar.notify_one();
theirs
}

fn run_iteration(&self, state: &State) -> Option<Duration> {
Expand Down Expand Up @@ -153,7 +157,7 @@ impl TimeoutWorker {

fn handle_pending_messages(&self) {
while let Some(msg) = self.inner_mut().queue.pop_front() {
self.inner_mut().timeouts.insert(msg.process, msg.timeout);
self.inner_mut().timeouts.insert(msg.id, msg.process, msg.timeout);
}
}

Expand Down Expand Up @@ -234,8 +238,9 @@ mod tests {
for time in &[10_u64, 5_u64] {
let timeout = Timeout::duration(&state, Duration::from_secs(*time));

process.state().waiting_for_value(Some(timeout.clone()));
worker.suspend(process, timeout);
process
.state()
.waiting_for_value(Some(worker.suspend(process, timeout)));
}

worker.increase_expired_timeouts();
Expand All @@ -258,8 +263,9 @@ mod tests {
let worker = TimeoutWorker::new();
let timeout = Timeout::duration(&state, Duration::from_secs(10));

process.state().waiting_for_value(Some(timeout.clone()));
worker.suspend(process, timeout);
process
.state()
.waiting_for_value(Some(worker.suspend(process, timeout)));
worker.run_iteration(&state);

assert_eq!(worker.inner().timeouts.len(), 1);
Expand All @@ -273,8 +279,9 @@ mod tests {
let worker = TimeoutWorker::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

process.state().waiting_for_value(Some(timeout.clone()));
worker.suspend(process, timeout);
process
.state()
.waiting_for_value(Some(worker.suspend(process, timeout)));
worker.run_iteration(&state);

assert_eq!(worker.inner().timeouts.len(), 0);
Expand All @@ -288,8 +295,9 @@ mod tests {
let worker = TimeoutWorker::new();
let timeout = Timeout::duration(&state, Duration::from_secs(1));

process.state().waiting_for_value(Some(timeout.clone()));
worker.suspend(process, timeout);
process
.state()
.waiting_for_value(Some(worker.suspend(process, timeout)));
worker.move_messages();
worker.handle_pending_messages();
worker.defragment_heap();
Expand All @@ -308,8 +316,9 @@ mod tests {
for time in &[1_u64, 1_u64] {
let timeout = Timeout::duration(&state, Duration::from_secs(*time));

process.state().waiting_for_value(Some(timeout.clone()));
worker.suspend(process, timeout);
process
.state()
.waiting_for_value(Some(worker.suspend(process, timeout)));
}

worker.increase_expired_timeouts();
Expand Down
Loading

0 comments on commit e5e0462

Please sign in to comment.