diff --git a/Cargo.lock b/Cargo.lock index a17342a27..f256c1c2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -335,9 +335,17 @@ dependencies = [ name = "chia-client" version = "0.10.0" dependencies = [ + "anyhow", "chia-protocol", + "chia-ssl", "chia-traits 0.10.0", + "env_logger", "futures-util", + "hex", + "hex-literal", + "log", + "native-tls", + "sha2", "thiserror", "tokio", "tokio-tungstenite", @@ -699,6 +707,22 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -920,12 +944,45 @@ dependencies = [ "zeroize", ] +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -938,6 +995,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + [[package]] name = "ff" version = "0.13.0" @@ -1221,6 +1284,12 @@ version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "idna" version = "0.5.0" @@ -1370,6 +1439,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "lock_api" version = "0.4.12" @@ -1437,6 +1512,23 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nom" version = "7.1.3" @@ -1576,6 +1668,12 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "openssl-src" version = "300.3.1+3.3.1" @@ -2060,6 +2158,19 @@ dependencies = [ "nom", ] +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustls-pki-types" version = "1.7.0" @@ -2081,6 +2192,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2101,6 +2221,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.23" @@ -2161,6 +2304,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2263,6 +2415,19 @@ version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +[[package]] +name = "tempfile" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "term" version = "0.2.14" @@ -2378,11 +2543,35 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", + "tokio-macros", "windows-sys", ] +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.21.0" @@ -2391,7 +2580,9 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" dependencies = [ "futures-util", "log", + "native-tls", "tokio", + "tokio-native-tls", "tungstenite", ] @@ -2424,6 +2615,7 @@ dependencies = [ "http", "httparse", "log", + "native-tls", "rand", "sha1", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 0256f0729..f27f83cd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -138,3 +138,6 @@ zstd = "0.13.2" blocking-threadpool = "1.0.1" libfuzzer-sys = "0.4" wasm-bindgen = "0.2.92" +log = "0.4.22" +native-tls = "0.2.12" +env_logger = "0.11.3" diff --git a/crates/chia-client/Cargo.toml b/crates/chia-client/Cargo.toml index 880be8777..c96e410e6 100644 --- a/crates/chia-client/Cargo.toml +++ b/crates/chia-client/Cargo.toml @@ -14,8 +14,19 @@ workspace = true [dependencies] chia-protocol = { workspace = true } chia-traits = { workspace = true } -tokio = { workspace = true, features = ["rt", "sync"] } -tokio-tungstenite = { workspace = true } +tokio = { workspace = true, features = ["rt", "sync", "time"] } +tokio-tungstenite = { workspace = true, features = ["native-tls"] } futures-util = { workspace = true } tungstenite = { workspace = true } thiserror = { workspace = true } +sha2 = { workspace = true } +log = { workspace = true } +native-tls = { workspace = true } +hex-literal = { workspace = true } +hex = { workspace = true } + +[dev-dependencies] +chia-ssl = { path = "../chia-ssl" } +tokio = { workspace = true, features = ["full"] } +anyhow = { workspace = true } +env_logger = { workspace = true } diff --git a/crates/chia-client/examples/peer_connection.rs b/crates/chia-client/examples/peer_connection.rs new file mode 100644 index 000000000..d18069903 --- /dev/null +++ b/crates/chia-client/examples/peer_connection.rs @@ -0,0 +1,41 @@ +use std::{env, net::SocketAddr}; + +use chia_client::{create_tls_connector, Peer}; +use chia_protocol::{Handshake, NodeType}; +use chia_ssl::ChiaCertificate; +use chia_traits::Streamable; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ssl = ChiaCertificate::generate()?; + let tls_connector = create_tls_connector(ssl.cert_pem.as_bytes(), ssl.key_pem.as_bytes())?; + let (peer, mut receiver) = Peer::connect( + SocketAddr::new(env::var("PEER")?.parse()?, 58444), + tls_connector, + ) + .await?; + + peer.send(Handshake { + network_id: "testnet11".to_string(), + protocol_version: "0.0.34".to_string(), + software_version: "0.0.0".to_string(), + server_port: 0, + node_type: NodeType::Wallet, + capabilities: vec![ + (1, "1".to_string()), + (2, "1".to_string()), + (3, "1".to_string()), + ], + }) + .await?; + + let message = receiver.recv().await.unwrap(); + let handshake = Handshake::from_bytes(&message.data)?; + println!("{handshake:#?}"); + + while let Some(message) = receiver.recv().await { + println!("{message:?}"); + } + + Ok(()) +} diff --git a/crates/chia-client/src/error.rs b/crates/chia-client/src/error.rs index 108290e2c..7c3cc3052 100644 --- a/crates/chia-client/src/error.rs +++ b/crates/chia-client/src/error.rs @@ -1,21 +1,35 @@ -use chia_protocol::Message; -use chia_traits::chia_error; +use chia_protocol::ProtocolMessageTypes; use thiserror::Error; +use tokio::sync::oneshot::error::RecvError; #[derive(Debug, Error)] -pub enum Error { - #[error("{0:?}")] - Chia(#[from] chia_error::Error), +pub enum Error { + #[error("Peer is missing certificate")] + MissingCertificate, - #[error("{0}")] + #[error("Streamable error: {0}")] + Streamable(#[from] chia_traits::Error), + + #[error("WebSocket error: {0}")] WebSocket(#[from] tungstenite::Error), - #[error("{0:?}")] - InvalidResponse(Message), + #[error("TLS error: {0}")] + Tls(#[from] native_tls::Error), + + #[error("Unexpected message received with type {0:?}")] + UnexpectedMessage(ProtocolMessageTypes), + + #[error("Expected response with type {0:?}, found {1:?}")] + InvalidResponse(Vec, ProtocolMessageTypes), - #[error("missing response")] - MissingResponse, + #[error("Failed to send event")] + EventNotSent, - #[error("rejection")] - Rejection(R), + #[error("Failed to receive message")] + Recv(#[from] RecvError), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), } + +pub type Result = std::result::Result; diff --git a/crates/chia-client/src/lib.rs b/crates/chia-client/src/lib.rs index 1cfd95c93..363affdf5 100644 --- a/crates/chia-client/src/lib.rs +++ b/crates/chia-client/src/lib.rs @@ -1,6 +1,10 @@ mod error; +mod network; mod peer; -mod utils; +mod request_map; +mod tls; pub use error::*; +pub use network::*; pub use peer::*; +pub use tls::*; diff --git a/crates/chia-client/src/network.rs b/crates/chia-client/src/network.rs new file mode 100644 index 000000000..af229e184 --- /dev/null +++ b/crates/chia-client/src/network.rs @@ -0,0 +1,39 @@ +use chia_protocol::Bytes32; +use hex_literal::hex; + +#[derive(Debug, Clone)] +pub struct Network { + pub network_id: String, + pub default_port: u16, + pub genesis_challenge: Bytes32, + pub dns_introducers: Vec, +} + +impl Network { + pub fn mainnet() -> Self { + Self { + network_id: "mainnet".to_string(), + default_port: 8444, + genesis_challenge: Bytes32::new(hex!( + "ccd5bb71183532bff220ba46c268991a3ff07eb358e8255a65c30a2dce0e5fbb" + )), + dns_introducers: vec![ + "dns-introducer.chia.net".to_string(), + "chia.ctrlaltdel.ch".to_string(), + "seeder.dexie.space".to_string(), + "chia.hoffmang.com".to_string(), + ], + } + } + + pub fn testnet11() -> Self { + Self { + network_id: "testnet11".to_string(), + default_port: 58444, + genesis_challenge: Bytes32::new(hex!( + "37a90eb5185a9c4439a91ddc98bbadce7b4feba060d50116a067de66bf236615" + )), + dns_introducers: vec!["dns-introducer-testnet11.chia.net".to_string()], + } + } +} diff --git a/crates/chia-client/src/peer.rs b/crates/chia-client/src/peer.rs index a7db5cc62..bd45172cf 100644 --- a/crates/chia-client/src/peer.rs +++ b/crates/chia-client/src/peer.rs @@ -1,363 +1,341 @@ -use std::sync::atomic::{AtomicU16, Ordering}; -use std::{collections::HashMap, sync::Arc}; - -use chia_protocol::*; +use std::{fmt, net::SocketAddr, sync::Arc}; + +use chia_protocol::{ + Bytes32, ChiaProtocolMessage, CoinStateFilters, Message, PuzzleSolutionResponse, + RegisterForCoinUpdates, RegisterForPhUpdates, RejectCoinState, RejectPuzzleSolution, + RejectPuzzleState, RequestChildren, RequestCoinState, RequestPeers, RequestPuzzleSolution, + RequestPuzzleState, RequestTransaction, RespondChildren, RespondCoinState, RespondPeers, + RespondPuzzleSolution, RespondPuzzleState, RespondToCoinUpdates, RespondToPhUpdates, + RespondTransaction, SendTransaction, SpendBundle, TransactionAck, +}; use chia_traits::Streamable; -use futures_util::stream::SplitSink; -use futures_util::{SinkExt, StreamExt}; -use tokio::sync::{broadcast, oneshot, Mutex}; -use tokio::{net::TcpStream, task::JoinHandle}; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; -use tungstenite::Message as WsMessage; - -use crate::utils::stream; -use crate::Error; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use native_tls::TlsConnector; +use sha2::{digest::FixedOutput, Digest, Sha256}; +use tokio::{ + net::TcpStream, + sync::{mpsc, oneshot, Mutex}, + task::JoinHandle, +}; +use tokio_tungstenite::{Connector, MaybeTlsStream, WebSocketStream}; + +use crate::{request_map::RequestMap, Error, Result}; type WebSocket = WebSocketStream>; -type Requests = Arc>>>; +type Sink = SplitSink; +type Stream = SplitStream; +type Response = std::result::Result; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PeerId([u8; 32]); -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PeerEvent { - CoinStateUpdate(CoinStateUpdate), - NewPeakWallet(NewPeakWallet), +impl fmt::Display for PeerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } } -pub struct Peer { - sink: Mutex>, - inbound_task: JoinHandle<()>, - event_receiver: broadcast::Receiver, - requests: Requests, +#[derive(Debug, Clone)] +pub struct Peer(Arc); - // TODO: This does not currently prevent multiple requests with the same id at the same time. - // If one of them is still running while all other ids are being iterated through. - nonce: AtomicU16, +#[derive(Debug)] +struct PeerInner { + sink: Mutex, + inbound_handle: JoinHandle<()>, + requests: Arc, + peer_id: PeerId, + socket_addr: SocketAddr, } impl Peer { - pub fn new(ws: WebSocket) -> Self { - let (sink, mut stream) = ws.split(); - let (event_sender, event_receiver) = broadcast::channel(32); - - let requests = Requests::default(); - let requests_clone = Arc::clone(&requests); - - let inbound_task = tokio::spawn(async move { - while let Some(message) = stream.next().await { - if let Ok(message) = message { - Self::handle_inbound(message, &requests_clone, &event_sender) - .await - .ok(); - } - } - }); + /// Connects to a peer using its IP address and port. + pub async fn connect( + socket_addr: SocketAddr, + tls_connector: TlsConnector, + ) -> Result<(Self, mpsc::Receiver)> { + Self::connect_full_uri(&format!("wss://{socket_addr}/ws"), tls_connector).await + } - Self { - sink: Mutex::new(sink), - inbound_task, - event_receiver, - requests, - nonce: AtomicU16::new(0), - } + /// Connects to a peer using its full WebSocket URI. + /// For example, `wss://127.0.0.1:8444/ws`. + pub async fn connect_full_uri( + uri: &str, + tls_connector: TlsConnector, + ) -> Result<(Self, mpsc::Receiver)> { + let (ws, _) = tokio_tungstenite::connect_async_tls_with_config( + uri, + None, + false, + Some(Connector::NativeTls(tls_connector)), + ) + .await?; + Self::from_websocket(ws) } - pub async fn send_handshake( - &self, - network_id: String, - node_type: NodeType, - ) -> Result<(), Error<()>> { - let body = Handshake { - network_id, - protocol_version: "0.0.34".to_string(), - software_version: "0.0.0".to_string(), - server_port: 0, - node_type, - capabilities: vec![ - (1, "1".to_string()), - (2, "1".to_string()), - (3, "1".to_string()), - ], + /// Creates a peer from an existing WebSocket connection. + /// The connection must be secured with TLS, so that the certificate can be hashed in a peer id. + pub fn from_websocket(ws: WebSocket) -> Result<(Self, mpsc::Receiver)> { + let (socket_addr, cert) = match ws.get_ref() { + MaybeTlsStream::NativeTls(tls) => { + let tls_stream = tls.get_ref(); + let tcp_stream = tls_stream.get_ref().get_ref(); + (tcp_stream.peer_addr()?, tls_stream.peer_certificate()?) + } + _ => return Err(Error::MissingCertificate), }; - self.send(body).await - } - pub async fn request_puzzle_and_solution( - &self, - coin_id: Bytes32, - height: u32, - ) -> Result> { - let body = RequestPuzzleSolution { - coin_name: coin_id, - height, + let Some(cert) = cert else { + return Err(Error::MissingCertificate); }; - let response: RespondPuzzleSolution = self.request_or_reject(body).await?; - Ok(response.response) + + let mut hasher = Sha256::new(); + hasher.update(cert.to_der()?); + + let peer_id = PeerId(hasher.finalize_fixed().into()); + let (sink, stream) = ws.split(); + let (sender, receiver) = mpsc::channel(32); + + let requests = Arc::new(RequestMap::new()); + let requests_clone = requests.clone(); + + let inbound_handle = tokio::spawn(async move { + if let Err(error) = handle_inbound_messages(stream, sender, requests_clone).await { + log::warn!("Error handling message: {error}"); + } + }); + + let peer = Self(Arc::new(PeerInner { + sink: Mutex::new(sink), + inbound_handle, + requests, + peer_id, + socket_addr, + })); + + Ok((peer, receiver)) } - pub async fn send_transaction( - &self, - spend_bundle: SpendBundle, - ) -> Result> { - let body = SendTransaction { - transaction: spend_bundle, - }; - self.request(body).await + /// The hash of the TLS certificate used by the peer. + pub fn peer_id(&self) -> PeerId { + self.0.peer_id } - pub async fn request_block_header( - &self, - height: u32, - ) -> Result> { - let body = RequestBlockHeader { height }; - let response: RespondBlockHeader = self.request_or_reject(body).await?; - Ok(response.header_block) + /// The IP address and port of the peer connection. + pub fn socket_addr(&self) -> SocketAddr { + self.0.socket_addr } - pub async fn request_block_headers( - &self, - start_height: u32, - end_height: u32, - return_filter: bool, - ) -> Result, Error<()>> { - let body = RequestBlockHeaders { - start_height, - end_height, - return_filter, - }; - let response: RespondBlockHeaders = - self.request_or_reject(body) - .await - .map_err(|error: Error| match error { - Error::Rejection(_rejection) => Error::Rejection(()), - Error::Chia(error) => Error::Chia(error), - Error::WebSocket(error) => Error::WebSocket(error), - Error::InvalidResponse(error) => Error::InvalidResponse(error), - Error::MissingResponse => Error::MissingResponse, - })?; - Ok(response.header_blocks) + pub async fn send_transaction(&self, spend_bundle: SpendBundle) -> Result { + self.request_infallible(SendTransaction::new(spend_bundle)) + .await } - pub async fn request_removals( + pub async fn request_puzzle_state( &self, - height: u32, + puzzle_hashes: Vec, + previous_height: Option, header_hash: Bytes32, - coin_ids: Option>, - ) -> Result> { - let body = RequestRemovals { - height, + filters: CoinStateFilters, + subscribe_when_finished: bool, + ) -> Result> { + self.request_fallible(RequestPuzzleState::new( + puzzle_hashes, + previous_height, header_hash, - coin_names: coin_ids, - }; - self.request_or_reject(body).await + filters, + subscribe_when_finished, + )) + .await } - pub async fn request_additions( + pub async fn request_coin_state( &self, - height: u32, - header_hash: Option, - puzzle_hashes: Option>, - ) -> Result> { - let body = RequestAdditions { - height, + coin_ids: Vec, + previous_height: Option, + header_hash: Bytes32, + subscribe: bool, + ) -> Result> { + self.request_fallible(RequestCoinState::new( + coin_ids, + previous_height, header_hash, - puzzle_hashes, - }; - self.request_or_reject(body).await + subscribe, + )) + .await } pub async fn register_for_ph_updates( &self, puzzle_hashes: Vec, min_height: u32, - ) -> Result, Error<()>> { - let body = RegisterForPhUpdates { - puzzle_hashes, - min_height, - }; - let response: RespondToPhUpdates = self.request(body).await?; - Ok(response.coin_states) + ) -> Result { + self.request_infallible(RegisterForPhUpdates::new(puzzle_hashes, min_height)) + .await } pub async fn register_for_coin_updates( &self, coin_ids: Vec, min_height: u32, - ) -> Result, Error<()>> { - let body = RegisterForCoinUpdates { - coin_ids, - min_height, - }; - let response: RespondToCoinUpdates = self.request(body).await?; - Ok(response.coin_states) + ) -> Result { + self.request_infallible(RegisterForCoinUpdates::new(coin_ids, min_height)) + .await } - pub async fn request_children(&self, coin_id: Bytes32) -> Result, Error<()>> { - let body = RequestChildren { coin_name: coin_id }; - let response: RespondChildren = self.request(body).await?; - Ok(response.coin_states) + pub async fn request_transaction(&self, transaction_id: Bytes32) -> Result { + self.request_infallible(RequestTransaction::new(transaction_id)) + .await } - pub async fn request_ses_info( + pub async fn request_puzzle_and_solution( &self, - start_height: u32, - end_height: u32, - ) -> Result> { - let body = RequestSesInfo { - start_height, - end_height, - }; - self.request(body).await + coin_id: Bytes32, + height: u32, + ) -> Result> { + match self + .request_fallible::(RequestPuzzleSolution::new( + coin_id, height, + )) + .await? + { + Ok(response) => Ok(Ok(response.response)), + Err(rejection) => Ok(Err(rejection)), + } } - pub async fn request_fee_estimates( - &self, - time_targets: Vec, - ) -> Result> { - let body = RequestFeeEstimates { time_targets }; - let response: RespondFeeEstimates = self.request(body).await?; - Ok(response.estimates) + pub async fn request_children(&self, coin_id: Bytes32) -> Result { + self.request_infallible(RequestChildren::new(coin_id)).await + } + + pub async fn request_peers(&self) -> Result { + self.request_infallible(RequestPeers::new()).await } - pub async fn send(&self, body: T) -> Result<(), Error<()>> + /// Sends a message to the peer, but does not expect any response. + pub async fn send(&self, body: T) -> Result<()> where T: Streamable + ChiaProtocolMessage, { - // Create the message. - let message = Message { - msg_type: T::msg_type(), - id: None, - data: stream(&body)?.into(), - }; + let message = Message::new(T::msg_type(), None, body.to_bytes()?.into()) + .to_bytes()? + .into(); - // Send the message through the websocket. - let mut sink = self.sink.lock().await; - sink.send(stream(&message)?.into()).await?; + self.0.sink.lock().await.send(message).await?; Ok(()) } - pub async fn request_or_reject(&self, body: B) -> Result> + /// Sends a message to the peer and expects a message that's either a response or a rejection. + pub async fn request_fallible(&self, body: B) -> Result> where T: Streamable + ChiaProtocolMessage, - R: Streamable + ChiaProtocolMessage, + E: Streamable + ChiaProtocolMessage, B: Streamable + ChiaProtocolMessage, { let message = self.request_raw(body).await?; - let data = message.data.as_ref(); - + if message.msg_type != T::msg_type() && message.msg_type != E::msg_type() { + return Err(Error::InvalidResponse( + vec![T::msg_type(), E::msg_type()], + message.msg_type, + )); + } if message.msg_type == T::msg_type() { - T::from_bytes(data).or(Err(Error::InvalidResponse(message))) - } else if message.msg_type == R::msg_type() { - let rejection = R::from_bytes(data).or(Err(Error::InvalidResponse(message)))?; - Err(Error::Rejection(rejection)) + Ok(Ok(T::from_bytes(&message.data)?)) } else { - Err(Error::InvalidResponse(message)) + Ok(Err(E::from_bytes(&message.data)?)) } } - pub async fn request(&self, body: T) -> Result> + /// Sends a message to the peer and expects a specific response message. + pub async fn request_infallible(&self, body: B) -> Result where - Response: Streamable + ChiaProtocolMessage, T: Streamable + ChiaProtocolMessage, + B: Streamable + ChiaProtocolMessage, { let message = self.request_raw(body).await?; - let data = message.data.as_ref(); - - if message.msg_type == Response::msg_type() { - Response::from_bytes(data).or(Err(Error::InvalidResponse(message))) - } else { - Err(Error::InvalidResponse(message)) + if message.msg_type != T::msg_type() { + return Err(Error::InvalidResponse( + vec![T::msg_type()], + message.msg_type, + )); } + Ok(T::from_bytes(&message.data)?) } - pub async fn request_raw(&self, body: T) -> Result> + /// Sends a message to the peer and expects any arbitrary protocol message without parsing it. + pub async fn request_raw(&self, body: T) -> Result where T: Streamable + ChiaProtocolMessage, { - // Get the current nonce and increment. - let message_id = self.nonce.fetch_add(1, Ordering::SeqCst); + let (sender, receiver) = oneshot::channel(); - // Create the message. let message = Message { msg_type: T::msg_type(), - id: Some(message_id), - data: stream(&body)?.into(), - }; - - // Create a saved oneshot channel to receive the response. - let (sender, receiver) = oneshot::channel::(); - self.requests.lock().await.insert(message_id, sender); - - // Send the message. - let bytes = match stream(&message) { - Ok(bytes) => bytes.into(), - Err(error) => { - self.requests.lock().await.remove(&message_id); - return Err(error.into()); - } - }; - let send_result = self.sink.lock().await.send(bytes).await; - - if let Err(error) = send_result { - self.requests.lock().await.remove(&message_id); - return Err(error.into()); + id: Some(self.0.requests.insert(sender).await), + data: body.to_bytes()?.into(), } + .to_bytes()? + .into(); - // Wait for the response. - let response = receiver.await; - - // Remove the one shot channel. - self.requests.lock().await.remove(&message_id); - - // Handle the response, if present. - response.or(Err(Error::MissingResponse)) + self.0.sink.lock().await.send(message).await?; + Ok(receiver.await?) } +} - pub fn receiver(&self) -> &broadcast::Receiver { - &self.event_receiver +impl Drop for PeerInner { + fn drop(&mut self) { + self.inbound_handle.abort(); } +} - pub fn receiver_mut(&mut self) -> &mut broadcast::Receiver { - &mut self.event_receiver - } +async fn handle_inbound_messages( + mut stream: Stream, + sender: mpsc::Sender, + requests: Arc, +) -> Result<()> { + use tungstenite::Message::{Binary, Close, Frame, Ping, Pong, Text}; - async fn handle_inbound( - message: WsMessage, - requests: &Requests, - event_sender: &broadcast::Sender, - ) -> Result<(), Error<()>> { - // Parse the message. - let message = Message::from_bytes(message.into_data().as_ref())?; - - if let Some(id) = message.id { - // Send response through oneshot channel if present. - if let Some(request) = requests.lock().await.remove(&id) { - request.send(message).ok(); - } - return Ok(()); - } + while let Some(message) = stream.next().await { + let message = message?; - macro_rules! events { - ( $( $event:ident ),+ $(,)? ) => { - match message.msg_type { - $( ProtocolMessageTypes::$event => { - event_sender - .send(PeerEvent::$event($event::from_bytes(message.data.as_ref())?)) - .ok(); - } )+ - _ => {} - } - }; + match message { + Text(text) => { + log::warn!("Received unexpected text message: {text}"); + } + Close(close) => { + log::warn!("Received close: {close:?}"); + break; + } + Ping(_ping) => {} + Pong(_pong) => {} + Binary(binary) => { + let message = Message::from_bytes(&binary)?; + + let Some(id) = message.id else { + sender.send(message).await.map_err(|error| { + log::warn!("Failed to send peer message event: {error}"); + Error::EventNotSent + })?; + continue; + }; + + let Some(request) = requests.remove(id).await else { + log::warn!( + "Received {:?} message with untracked id {id}", + message.msg_type + ); + return Err(Error::UnexpectedMessage(message.msg_type)); + }; + + request.send(message); + } + Frame(frame) => { + log::warn!("Received frame: {frame}"); + } } - - // TODO: Handle unexpected messages. - events!(CoinStateUpdate, NewPeakWallet); - - Ok(()) - } -} - -impl Drop for Peer { - fn drop(&mut self) { - self.inbound_task.abort(); } + Ok(()) } diff --git a/crates/chia-client/src/request_map.rs b/crates/chia-client/src/request_map.rs new file mode 100644 index 000000000..597ef17c6 --- /dev/null +++ b/crates/chia-client/src/request_map.rs @@ -0,0 +1,65 @@ +use std::{collections::HashMap, sync::Arc}; + +use chia_protocol::Message; +use tokio::sync::{oneshot, Mutex, OwnedSemaphorePermit, Semaphore}; + +#[derive(Debug)] +pub struct Request { + sender: oneshot::Sender, + _permit: OwnedSemaphorePermit, +} + +impl Request { + pub fn send(self, message: Message) { + self.sender.send(message).ok(); + } +} + +#[derive(Debug)] +pub struct RequestMap { + items: Mutex>, + semaphore: Arc, +} + +impl RequestMap { + pub fn new() -> Self { + Self { + items: Mutex::new(HashMap::new()), + semaphore: Arc::new(Semaphore::new(u16::MAX as usize)), + } + } + + pub async fn insert(&self, sender: oneshot::Sender) -> u16 { + let permit = self + .semaphore + .clone() + .acquire_owned() + .await + .expect("semaphore closed"); + + let mut items = self.items.lock().await; + + let mut index = None; + + for i in 0..=u16::MAX { + if !items.contains_key(&i) { + index = Some(i); + break; + } + } + + let index = index.expect("exceeded expected number of requests"); + items.insert( + index, + Request { + sender, + _permit: permit, + }, + ); + index + } + + pub async fn remove(&self, id: u16) -> Option { + self.items.lock().await.remove(&id) + } +} diff --git a/crates/chia-client/src/tls.rs b/crates/chia-client/src/tls.rs new file mode 100644 index 000000000..efbd9fc10 --- /dev/null +++ b/crates/chia-client/src/tls.rs @@ -0,0 +1,11 @@ +use native_tls::{Identity, TlsConnector}; + +pub fn create_tls_connector( + cert_pem: &[u8], + key_pem: &[u8], +) -> Result { + TlsConnector::builder() + .identity(Identity::from_pkcs8(cert_pem, key_pem)?) + .danger_accept_invalid_certs(true) + .build() +} diff --git a/crates/chia-client/src/utils.rs b/crates/chia-client/src/utils.rs deleted file mode 100644 index cddc7d40a..000000000 --- a/crates/chia-client/src/utils.rs +++ /dev/null @@ -1,7 +0,0 @@ -use chia_traits::{chia_error::Result, Streamable}; - -pub fn stream(value: &T) -> Result> { - let mut bytes = Vec::new(); - value.stream(&mut bytes)?; - Ok(bytes) -}