Skip to content

Commit

Permalink
use set_read_timeout instead of sleep
Browse files Browse the repository at this point in the history
  • Loading branch information
jprochazk committed Oct 24, 2024
1 parent 7ab9734 commit 934de81
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 92 deletions.
10 changes: 7 additions & 3 deletions ewebsock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,12 @@ pub struct Options {
/// Currently only supported on native.
pub subprotocols: Vec<String>,

/// Delay blocking in ms - default 10ms
pub delay_blocking: std::time::Duration,
/// Socket read timeout.
///
/// Reads will block forever if this is set to `None` or `Some(Duration::ZERO)`.
///
/// Defaults to 10ms.
pub read_timeout: Option<std::time::Duration>,
}

impl Default for Options {
Expand All @@ -159,7 +163,7 @@ impl Default for Options {
subprotocols: vec![],
// let the OS schedule something else, otherwise busy-loop
// TODO: use polling on native instead
delay_blocking: std::time::Duration::from_millis(0),
read_timeout: Some(std::time::Duration::from_millis(10)),
}
}
}
Expand Down
169 changes: 80 additions & 89 deletions ewebsock/src/native_tungstenite.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
//! Native implementation of the WebSocket client using the `tungstenite` crate.

use std::net::TcpStream;
use std::{
ops::ControlFlow,
sync::mpsc::{Receiver, TryRecvError},
};

use tungstenite::stream::MaybeTlsStream;
use tungstenite::WebSocket;

use crate::tungstenite_common::into_requester;
use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

Expand Down Expand Up @@ -70,13 +74,13 @@ pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHand
/// # Errors
/// All errors are returned to the caller, and NOT reported via `on_event`.
pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler) -> Result<()> {
let delay = options.delay_blocking;
let uri: tungstenite::http::Uri = url
.parse()
.map_err(|err| format!("Failed to parse URL {url:?}: {err}"))?;
let config = tungstenite::protocol::WebSocketConfig::from(options.clone());
let max_redirects = 3; // tungstenite default

let read_timeout = options.read_timeout;
let (mut socket, response) = match tungstenite::client::connect_with_config(
into_requester(uri, options),
Some(config),
Expand All @@ -88,6 +92,8 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
}
};

set_read_timeout(&mut socket, read_timeout)?;

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
"WebSocket response contains the following headers: {:?}",
Expand All @@ -103,31 +109,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
}

loop {
let control = match socket.read() {
Ok(incoming_msg) => match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("WebSocket close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
Err(err) => {
return Err(format!("read: {err}"));
}
};
let control = read_from_socket(&mut socket, on_event)?;

if control.is_break() {
log::trace!("Closing connection due to Break");
Expand All @@ -136,12 +118,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
.map_err(|err| format!("Failed to close connection: {err}"));
}

// without the check we wouldn't yield at all on some platforms
if delay == std::time::Duration::ZERO {
std::thread::yield_now();
} else {
std::thread::sleep(delay);
}
std::thread::yield_now();
}
}

Expand Down Expand Up @@ -178,12 +155,13 @@ pub fn ws_connect_blocking(
on_event: &EventHandler,
rx: &Receiver<WsMessage>,
) -> Result<()> {
let delay = options.delay_blocking;
let config = tungstenite::protocol::WebSocketConfig::from(options.clone());
let max_redirects = 3; // tungstenite default
let uri: tungstenite::http::Uri = url
.parse()
.map_err(|err| format!("Failed to parse URL {url:?}: {err}"))?;

let read_timeout = options.read_timeout;
let (mut socket, response) = match tungstenite::client::connect_with_config(
into_requester(uri, options),
Some(config),
Expand All @@ -195,6 +173,8 @@ pub fn ws_connect_blocking(
}
};

set_read_timeout(&mut socket, read_timeout)?;

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
"WebSocket response contains the following headers: {:?}",
Expand All @@ -209,26 +189,9 @@ pub fn ws_connect_blocking(
.map_err(|err| format!("Failed to close connection: {err}"));
}

match socket.get_mut() {
tungstenite::stream::MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true),

// tungstenite::stream::MaybeTlsStream::NativeTls(stream) => {
// stream.get_mut().set_nonblocking(true)
// }
#[cfg(feature = "tls")]
tungstenite::stream::MaybeTlsStream::Rustls(stream) => {
stream.get_mut().set_nonblocking(true)
}
_ => return Err(format!("Unknown tungstenite stream {:?}", socket.get_mut())),
}
.map_err(|err| format!("Failed to make WebSocket non-blocking: {err}"))?;

loop {
let mut did_work = false;

match rx.try_recv() {
Ok(outgoing_message) => {
did_work = true;
let outgoing_message = match outgoing_message {
WsMessage::Text(text) => tungstenite::protocol::Message::Text(text),
WsMessage::Binary(data) => tungstenite::protocol::Message::Binary(data),
Expand All @@ -251,39 +214,7 @@ pub fn ws_connect_blocking(
Err(TryRecvError::Empty) => {}
};

let control = match socket.read() {
Ok(incoming_msg) => {
did_work = true;
match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("Close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
}
}
Err(tungstenite::Error::Io(io_err))
if io_err.kind() == std::io::ErrorKind::WouldBlock =>
{
ControlFlow::Continue(()) // Ignore
}
Err(err) => {
return Err(format!("read: {err}"));
}
};
let control = read_from_socket(&mut socket, on_event)?;

if control.is_break() {
log::trace!("Closing connection due to Break");
Expand All @@ -292,15 +223,75 @@ pub fn ws_connect_blocking(
.map_err(|err| format!("Failed to close connection: {err}"));
}

if !did_work {
// without the check we wouldn't yield at all on some platforms
if delay == std::time::Duration::ZERO {
std::thread::yield_now();
} else {
std::thread::sleep(delay);
std::thread::yield_now();
}
}

fn read_from_socket(
socket: &mut WebSocket<MaybeTlsStream<TcpStream>>,
on_event: &EventHandler,
) -> Result<ControlFlow<()>> {
let control = match socket.read() {
Ok(incoming_msg) => match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("WebSocket close received: {close:?}");
ControlFlow::Break(())
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
// If we get `WouldBlock`, then the read timed out.
// Windows may emit `TimedOut` instead.
Err(tungstenite::Error::Io(io_err))
if io_err.kind() == std::io::ErrorKind::WouldBlock
|| io_err.kind() == std::io::ErrorKind::TimedOut =>
{
ControlFlow::Continue(()) // Ignore
}
Err(err) => {
return Err(format!("read: {err}"));
}
};

Ok(control)
}

fn set_read_timeout(
s: &mut WebSocket<MaybeTlsStream<TcpStream>>,
value: Option<std::time::Duration>,
) -> Result<()> {
// zero timeout is the same as no timeout
if value.is_none() || value.is_some_and(|value| value.is_zero()) {
return Ok(());
}

match s.get_mut() {
MaybeTlsStream::Plain(s) => {
s.set_read_timeout(value)
.map_err(|err| format!("failed to set read timeout: {err}"))?;
}
#[cfg(feature = "tls")]
MaybeTlsStream::Rustls(s) => {
s.get_mut()
.set_read_timeout(value)
.map_err(|err| format!("failed to set read timeout: {err}"))?;
}
_ => {}
};

Ok(())
}

#[test]
Expand Down

0 comments on commit 934de81

Please sign in to comment.