Skip to content

Commit

Permalink
Fix use after free when dealing with timeouts
Browse files Browse the repository at this point in the history
This refactors the timeouts handling code to ultimately resolve its
soundness issues. That is, it was possible for a process to be dropped
while the timeouts thread still had one or more entries in the timeouts
heap. This would then result in this thread potentially trying to
reschedule a process that no longer exists. The result would either be a
crash or a hang, depending on how lucky (or not) you are.

This fixes #796.

Changelog: fixed
  • Loading branch information
yorickpeterse committed Jan 7, 2025
1 parent d13fd2f commit 342c777
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 568 deletions.
13 changes: 7 additions & 6 deletions rt/src/network_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,19 @@ impl Worker {
let mut processes = poller.poll(&mut events);

processes.retain(|proc| {
let mut state = proc.state();
let rights = state.try_reschedule_for_io();
// Acquiring the rights first _then_ matching on them ensures we
// don't deadlock with the timeout worker.
let rights = proc.state().try_reschedule_for_io();

// A process may have also been registered with the timeout
// thread (e.g. when using a timeout). As such we should only
// reschedule the process if the timout thread didn't already do
// this for us.
// reschedule the process if the timeout thread didn't already
// do this for us.
match rights {
RescheduleRights::Failed => false,
RescheduleRights::Acquired => true,
RescheduleRights::AcquiredWithTimeout => {
self.state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
self.state.timeout_worker.expire(id);
true
}
}
Expand Down
93 changes: 30 additions & 63 deletions rt/src/process.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::arc_without_weak::ArcWithoutWeak;
use crate::mem::{allocate, header_of, Header, TypePointer};
use crate::scheduler::process::Thread;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Id as TimeoutId;
use crate::stack::Stack;
use crate::state::State;
use std::alloc::dealloc;
Expand Down Expand Up @@ -200,7 +199,7 @@ pub(crate) enum RescheduleRights {

/// The rescheduling rights were obtained, and the process was using a
/// timeout.
AcquiredWithTimeout,
AcquiredWithTimeout(TimeoutId),
}

impl RescheduleRights {
Expand All @@ -219,11 +218,11 @@ pub(crate) struct ProcessState {
/// The status of the process.
status: ProcessStatus,

/// The timeout this process is suspended with, if any.
/// The ID of the timeout this process is suspended with, if any.
///
/// If missing and the process is suspended, it means the process is
/// suspended indefinitely.
timeout: Option<ArcWithoutWeak<Timeout>>,
timeout: Option<TimeoutId>,
}

impl ProcessState {
Expand All @@ -235,17 +234,7 @@ impl ProcessState {
}
}

pub(crate) fn has_same_timeout(
&self,
timeout: &ArcWithoutWeak<Timeout>,
) -> bool {
self.timeout
.as_ref()
.map(|t| t.as_ptr() == timeout.as_ptr())
.unwrap_or(false)
}

pub(crate) fn suspend(&mut self, timeout: ArcWithoutWeak<Timeout>) {
pub(crate) fn suspend(&mut self, timeout: TimeoutId) {
self.timeout = Some(timeout);
self.status.set_sleeping(true);
}
Expand All @@ -265,25 +254,19 @@ impl ProcessState {

self.status.no_longer_waiting();

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
}

pub(crate) fn waiting_for_value(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
) {
pub(crate) fn waiting_for_value(&mut self, timeout: Option<TimeoutId>) {
self.timeout = timeout;
self.status.set_waiting_for_value(true);
}

pub(crate) fn waiting_for_io(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
) {
pub(crate) fn waiting_for_io(&mut self, timeout: Option<TimeoutId>) {
self.timeout = timeout;
self.status.set_waiting_for_io(true);
}
Expand All @@ -304,8 +287,8 @@ impl ProcessState {

self.status.set_waiting_for_value(false);

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
Expand All @@ -318,8 +301,8 @@ impl ProcessState {

self.status.set_waiting_for_io(false);

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
Expand Down Expand Up @@ -719,10 +702,10 @@ impl DerefMut for ProcessPointer {
#[cfg(test)]
mod tests {
use super::*;
use crate::test::{empty_process_type, setup, OwnedProcess};
use crate::test::{empty_process_type, OwnedProcess};
use rustix::param::page_size;
use std::mem::size_of;
use std::time::Duration;
use std::num::NonZeroU64;

macro_rules! offset_of {
($value: expr, $field: ident) => {{
Expand Down Expand Up @@ -857,25 +840,14 @@ mod tests {
fn test_reschedule_rights_are_acquired() {
assert!(!RescheduleRights::Failed.are_acquired());
assert!(RescheduleRights::Acquired.are_acquired());
assert!(RescheduleRights::AcquiredWithTimeout.are_acquired());
}

#[test]
fn test_process_state_has_same_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

assert!(!proc_state.has_same_timeout(&timeout));

proc_state.timeout = Some(timeout.clone());

assert!(proc_state.has_same_timeout(&timeout));
assert!(RescheduleRights::AcquiredWithTimeout(TimeoutId(
NonZeroU64::new(1).unwrap()
))
.are_acquired());
}

#[test]
fn test_process_state_try_reschedule_after_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -893,13 +865,13 @@ mod tests {
assert!(!proc_state.status.is_waiting_for_value());
assert!(!proc_state.status.is_waiting());

let timeout = Timeout::duration(&state, Duration::from_secs(0));
let id = TimeoutId(NonZeroU64::new(1).unwrap());

proc_state.waiting_for_value(Some(timeout));
proc_state.waiting_for_value(Some(id));

assert_eq!(
proc_state.try_reschedule_after_timeout(),
RescheduleRights::AcquiredWithTimeout
RescheduleRights::AcquiredWithTimeout(id)
);

assert!(!proc_state.status.is_waiting_for_value());
Expand All @@ -908,16 +880,15 @@ mod tests {

#[test]
fn test_process_state_waiting_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

proc_state.waiting_for_value(None);

assert!(proc_state.status.is_waiting_for_value());
assert!(proc_state.timeout.is_none());

proc_state.waiting_for_value(Some(timeout));
proc_state
.waiting_for_value(Some(TimeoutId(NonZeroU64::new(1).unwrap())));

assert!(proc_state.status.is_waiting_for_value());
assert!(proc_state.timeout.is_some());
Expand All @@ -943,7 +914,6 @@ mod tests {

#[test]
fn test_process_state_try_reschedule_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -958,13 +928,14 @@ mod tests {
);
assert!(!proc_state.status.is_waiting_for_value());

let id = TimeoutId(NonZeroU64::new(1).unwrap());

proc_state.status.set_waiting_for_value(true);
proc_state.timeout =
Some(Timeout::duration(&state, Duration::from_secs(0)));
proc_state.timeout = Some(id);

assert_eq!(
proc_state.try_reschedule_for_value(),
RescheduleRights::AcquiredWithTimeout
RescheduleRights::AcquiredWithTimeout(id)
);
assert!(!proc_state.status.is_waiting_for_value());
}
Expand Down Expand Up @@ -1004,29 +975,25 @@ mod tests {

#[test]
fn test_process_state_suspend() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));

process.state().suspend(timeout);
process.state().suspend(TimeoutId(NonZeroU64::new(1).unwrap()));

assert!(process.state().timeout.is_some());
assert!(process.state().status.is_waiting());
}

#[test]
fn test_process_timeout_expired() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));

assert!(!process.timeout_expired());

process.state().suspend(timeout);
process.state().suspend(TimeoutId(NonZeroU64::new(1).unwrap()));

assert!(!process.timeout_expired());
assert!(!process.state().status.timeout_expired());
Expand Down
35 changes: 20 additions & 15 deletions rt/src/runtime/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::process::{
StackFrame,
};
use crate::scheduler::process::Action;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Deadline;
use crate::state::State;
use std::fmt::Write as _;
use std::process::exit;
Expand Down Expand Up @@ -83,8 +83,8 @@ pub unsafe extern "system" fn inko_process_send_message(
let message = Message { method, data };
let state = &*state;
let reschedule = match receiver.send_message(message) {
RescheduleRights::AcquiredWithTimeout => {
state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
state.timeout_worker.expire(id);
true
}
RescheduleRights::Acquired => true,
Expand Down Expand Up @@ -129,13 +129,16 @@ pub unsafe extern "system" fn inko_process_suspend(
nanos: i64,
) {
let state = &*state;
let timeout = Timeout::duration(state, Duration::from_nanos(nanos as _));
let timeout = Deadline::duration(state, Duration::from_nanos(nanos as _));

{
// We need to hold on to the lock until the end so as to ensure the process
// is rescheduled if the timeout happens to expire before we finish the
// work here.
let mut proc_state = process.state();
let timeout_id = state.timeout_worker.suspend(process, timeout);

proc_state.suspend(timeout.clone());
state.timeout_worker.suspend(process, timeout);
proc_state.suspend(timeout_id);
}

// Safety: the current thread is holding on to the run lock
Expand Down Expand Up @@ -243,22 +246,21 @@ pub unsafe extern "system" fn inko_process_wait_for_value_until(
nanos: u64,
) -> bool {
let state = &*state;
let deadline = Timeout::until(nanos);
let deadline = Deadline::until(nanos);
let mut proc_state = process.state();

proc_state.waiting_for_value(Some(deadline.clone()));

let _ = (*lock).compare_exchange(
current,
new,
Ordering::AcqRel,
Ordering::Acquire,
);

let timeout_id = state.timeout_worker.suspend(process, deadline);

proc_state.waiting_for_value(Some(timeout_id));
drop(proc_state);

// Safety: the current thread is holding on to the run lock
state.timeout_worker.suspend(process, deadline);
context::switch(process);
process.timeout_expired()
}
Expand All @@ -270,12 +272,15 @@ pub unsafe extern "system" fn inko_process_reschedule_for_value(
waiter: ProcessPointer,
) {
let state = &*state;
let mut waiter_state = waiter.state();
let reschedule = match waiter_state.try_reschedule_for_value() {

// Acquiring the rights first _then_ matching on them ensures we don't
// deadlock with the timeout worker.
let rights = waiter.state().try_reschedule_for_value();
let reschedule = match rights {
RescheduleRights::Failed => false,
RescheduleRights::Acquired => true,
RescheduleRights::AcquiredWithTimeout => {
state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
state.timeout_worker.expire(id);
true
}
};
Expand Down
8 changes: 4 additions & 4 deletions rt/src/runtime/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::context;
use crate::network_poller::Interest;
use crate::process::ProcessPointer;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Deadline;
use crate::socket::Socket;
use crate::state::State;

Expand All @@ -26,10 +26,10 @@ pub(crate) unsafe extern "system" fn inko_socket_poll(

// A deadline of -1 signals that we should wait indefinitely.
if deadline >= 0 {
let time = Timeout::until(deadline as u64);
let time = Deadline::until(deadline as u64);
let timeout_id = state.timeout_worker.suspend(process, time);

proc_state.waiting_for_io(Some(time.clone()));
state.timeout_worker.suspend(process, time);
proc_state.waiting_for_io(Some(timeout_id));
} else {
proc_state.waiting_for_io(None);
}
Expand Down
1 change: 0 additions & 1 deletion rt/src/scheduler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
pub mod process;
pub mod signal;
pub mod timeout_worker;
pub mod timeouts;

#[cfg(target_os = "linux")]
Expand Down
Loading

0 comments on commit 342c777

Please sign in to comment.